  if lambda_ == 1:
    # This is equivalent to the branch below, just an optimisation
    # to avoid unnecessary work in this case:
    return sequence_ops.scan_discounted_sum(
        rewards,
        pcontinues,
        initial_value=bootstrap_value,
        reverse=True,
        back_prop=False,
        name="multistep_returns")
  else:
    v_tp1 = tf.concat(
        axis=0, values=[values[1:, :],
                        tf.expand_dims(bootstrap_value, 0)])
    # `back_prop=False` prevents gradients flowing into values and
    # bootstrap_value, which is what you want when using the bootstrapped
    # lambda-returns in an update as targets for values.
    return sequence_ops.multistep_forward_view(
        rewards,
        pcontinues,
        v_tp1,
        lambda_,
        back_prop=False,
        name="generalized_lambda_returns")
  Returns:
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[T, B]`.
        * `td_error`: batch of temporal difference errors, shape `[T, B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_tm1, r_t, pcont_t, a_t]], [3, 2], name)

  # SARSALambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):

    # Select head to update and build target.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    target = sequence_ops.multistep_forward_view(
        r_t, pcont_t, qa_t, lambda_, back_prop=False)
    target = tf.stop_gradient(target)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
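
# Usage sketch (illustrative only): SARSA(lambda) on a rollout of T=3 steps,
# batch B=2 and 4 actions. Shapes follow the rank checks above: `q_tm1`/`q_t`
# are `[T, B, num_actions]` and `a_tm1`/`r_t`/`pcont_t`/`a_t` are `[T, B]`.
# The call assumes the full trfl-style `sarsa_lambda` definition whose body
# is shown above.
import tensorflow as tf

q_tm1 = tf.random.normal([3, 2, 4])             # Q(s_t, .)
q_t = tf.random.normal([3, 2, 4])               # Q(s_{t+1}, .)
a_tm1 = tf.constant([[0, 1], [2, 3], [1, 0]])   # actions taken, [T, B]
a_t = tf.constant([[1, 2], [3, 0], [0, 1]])     # next actions (on-policy)
r_t = tf.ones([3, 2])                           # rewards, [T, B]
pcont_t = 0.99 * tf.ones([3, 2])                # discounts, [T, B]

loss, (target, td_error) = sarsa_lambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, a_t, lambda_=0.9)
# Minimising `loss` yields gradient -td_error * dQ/dtheta, because
# d/dq [0.5 * (target - q)**2] = -(target - q) and `target` is
# gradient-stopped.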
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert([[q_tm1, q_t]], [3], name)
  if (isinstance(lambda_, tf.Tensor) and
      lambda_.get_shape().ndims is not None and
      lambda_.get_shape().ndims > 0):
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t, lambda_]], [2], name)
  else:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t]], [2], name)

  # QLambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      state_values = tf.reduce_max(q_t, axis=2)
      target = sequence_ops.multistep_forward_view(
          r_t, pcont_t, state_values, lambda_, back_prop=False)
      target = tf.stop_gradient(target)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
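
# Usage sketch (illustrative only): a Q(lambda) update with the same shapes
# as above: `q_tm1`/`q_t` are `[T, B, num_actions]`, the rest `[T, B]`.
# Unlike SARSA(lambda), no `a_t` is needed, since the target bootstraps from
# max_a q_t. The rank checks above also accept a per-timestep `[T, B]` tensor
# for `lambda_` (e.g. zeroing entries to cut traces), which this sketch uses;
# it assumes the full trfl-style `qlambda` definition this body comes from.
import tensorflow as tf

q_tm1 = tf.random.normal([3, 2, 4])            # Q(s_t, .), T=3, B=2, 4 actions
q_t = tf.random.normal([3, 2, 4])              # Q(s_{t+1}, .)
a_tm1 = tf.constant([[0, 1], [2, 3], [1, 0]])  # actions taken, [T, B]
r_t = tf.ones([3, 2])                          # rewards, [T, B]
pcont_t = 0.99 * tf.ones([3, 2])               # discounts, [T, B]
lambda_ = tf.constant([[0.9, 0.9], [0.0, 0.9], [0.9, 0.0]])  # trace cuts

loss, (target, td_error) = qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_)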