How to use the trfl.indexing_ops.batched_index function in trfl

To help you get started, we've selected a few examples of trfl.indexing_ops.batched_index, based on popular ways the function is used in public projects.

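As a primer, here is a minimal sketch of what batched_index does: given values of shape [B, num_actions] and integer indices of shape [B], it returns values[b, indices[b]] for every batch element b (a leading time dimension, i.e. [T, B, num_actions] values with [T, B] indices, is also supported). The tensor values below are illustrative, and the snippet assumes TensorFlow 1.x, which trfl's graph-mode ops target.

import tensorflow as tf
from trfl import indexing_ops

# q_values: [B, A] action values; actions: [B] integer action indices.
q_values = tf.constant([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
actions = tf.constant([2, 0], dtype=tf.int32)

# Gathers q_values[b, actions[b]] for each batch element b.
chosen_q = indexing_ops.batched_index(q_values, actions)

with tf.Session() as sess:
    print(sess.run(chosen_q))  # [3. 4.]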

github fomorians-oss/pyoneer: pyoneer/rl/agents/advantage_actor_critic_agent_impl.py
        multi_baseline_values = self.value(states, training=True) * array_ops.expand_dims(weights, axis=-1)

        base_ops.assert_rank_and_shape_compatibility(
            [rewards, multi_baseline_values], 3)
        multi_baseline_values = array_ops.unstack(multi_baseline_values, axis=-1)
        num_values = len(multi_baseline_values)

        base_shape = rewards.shape
        decay = self._least_fit(decay, base_shape)
        lambda_ = self._least_fit(lambda_, base_shape)
        baseline_scale = self._least_fit(baseline_scale, base_shape)
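        multi_advantages = []  # per-value-head advantages; initialized earlier in the full source, repeated here so the excerpt is self-consistent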

        for i in range(num_values):
            pcontinues = decay[..., i] * weights
            lambdas = lambda_[..., i] * weights
            bootstrap_values = indexing_ops.batched_index(
                multi_baseline_values[i], math_ops.cast(sequence_lengths - 1, dtypes.int32))
            baseline_loss, td_lambda = value_ops.td_lambda(
                parray_ops.swap_time_major(multi_baseline_values[i]), 
                parray_ops.swap_time_major(rewards[..., i]), 
                parray_ops.swap_time_major(pcontinues), 
                bootstrap_values, 
                parray_ops.swap_time_major(lambdas))
            value_loss = pmath_ops.safe_divide(
                baseline_scale[i] * math_ops.reduce_sum(baseline_loss), total_num)
            self.value_loss.append(
                gen_array_ops.check_numerics(value_loss, 'value_loss'))
            advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
            multi_advantages.append(advantages)

        advantages = math_ops.add_n(multi_advantages) # A = A[0] + A[1] + ...
        if normalize_advantages:
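In this excerpt (truncated mid-function), batched_index is doing bootstrap-value extraction rather than action selection: each [B, T] tensor of baseline values is indexed with sequence_lengths - 1, which picks out the value at the last valid timestep of every (possibly padded) sequence to use as the bootstrap for value_ops.td_lambda. The same pattern in isolation, with hypothetical tensor names:

import tensorflow as tf
from trfl import indexing_ops

# values: [B, T] per-timestep baseline values; lengths: [B] valid lengths.
values = tf.constant([[0.1, 0.2, 0.3],
                      [0.4, 0.5, 0.0]])  # second row padded after t=1
lengths = tf.constant([3, 2])

# Value at the last valid step of each sequence: [0.3, 0.5].
bootstrap_values = indexing_ops.batched_index(
    values, tf.cast(lengths - 1, tf.int32))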
github deepmind/trfl: trfl/action_value_ops.py
    * `loss`: a tensor containing the batch of losses, shape `[B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1], [a_tm1, r_t, pcont_t, v_t]], [2, 1], name)

  # QV op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, v_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      target = tf.stop_gradient(r_t + pcont_t * v_t)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
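This excerpt is the body of trfl's QV-learning loss: batched_index selects q_tm1[a_tm1], the Q-value of the action actually taken, which is then regressed toward a target built from the separate state-value estimate v_t. Assuming the public trfl API, it can be driven roughly like this (shapes are illustrative):

import tensorflow as tf
import trfl

B, A = 4, 3  # batch size, number of actions
q_tm1 = tf.placeholder(tf.float32, [B, A])  # Q-values at time t-1
a_tm1 = tf.placeholder(tf.int32, [B])       # actions taken at t-1
r_t = tf.placeholder(tf.float32, [B])       # rewards
pcont_t = tf.placeholder(tf.float32, [B])   # discounts
v_t = tf.placeholder(tf.float32, [B])       # state values at time t

loss, extra = trfl.qv_learning(q_tm1, a_tm1, r_t, pcont_t, v_t)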
github deepmind/trfl: trfl/action_value_ops.py
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
        * `td_error`: batch of temporal difference errors, shape `[B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_tm1, r_t, pcont_t]], [2, 1], name)
  base_ops.assert_arg_bounded(action_gap_scale, 0, 1, name, "action_gap_scale")

  # persistent Q-learning op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      max_q_t = tf.reduce_max(q_t, axis=1)
      qa_t = indexing_ops.batched_index(q_t, a_tm1)
      corrected_q_t = (1 - action_gap_scale) * max_q_t + action_gap_scale * qa_t
      target = tf.stop_gradient(r_t + pcont_t * corrected_q_t)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
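This is the persistent Q-learning loss (from "Increasing the Action Gap" by Bellemare et al.). Note the two distinct uses of batched_index: q_t is indexed with the previous action a_tm1 to build the action-gap-corrected bootstrap, and q_tm1 is indexed with a_tm1 to select the head being updated. A minimal usage sketch, assuming the public trfl API:

import tensorflow as tf
import trfl

B, A = 4, 3
q_tm1 = tf.placeholder(tf.float32, [B, A])
a_tm1 = tf.placeholder(tf.int32, [B])
r_t = tf.placeholder(tf.float32, [B])
pcont_t = tf.placeholder(tf.float32, [B])
q_t = tf.placeholder(tf.float32, [B, A])

# action_gap_scale interpolates between the max bootstrap (0.0, plain
# Q-learning) and bootstrapping on the previous action (1.0).
loss, extra = trfl.persistent_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t, action_gap_scale=0.5)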
github deepmind/trfl: trfl/action_value_ops.py
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t]):

    # Debug ops.
    deps = []
    if debug:
      cumulative_prob = tf.reduce_sum(probs_a_t, axis=1)
      almost_prob = tf.less(tf.abs(tf.subtract(cumulative_prob, 1.0)), 1e-6)
      deps.append(tf.Assert(
          tf.reduce_all(almost_prob),
          ["probs_a_t tensor does not sum to 1", probs_a_t]))

    # With dependency on possible debug ops.
    with tf.control_dependencies(deps):

      # Select head to update and build target.
      qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
      target = tf.stop_gradient(
          r_t + pcont_t * tf.reduce_sum(tf.multiply(q_t, probs_a_t), axis=1))

      # Temporal difference error and loss.
      # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
      td_error = target - qa_tm1
      loss = 0.5 * tf.square(td_error)
      return base_ops.LossOutput(loss, QExtra(target, td_error))
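Here batched_index again selects the head to update, while the target bootstraps on the expectation of q_t under the policy probabilities probs_a_t; the debug assertion checks that those probabilities sum to one for every batch element. This is the core of trfl's expected-SARSA op, trfl.sarse, which can be called roughly like this:

import tensorflow as tf
import trfl

B, A = 4, 3
q_tm1 = tf.placeholder(tf.float32, [B, A])
a_tm1 = tf.placeholder(tf.int32, [B])
r_t = tf.placeholder(tf.float32, [B])
pcont_t = tf.placeholder(tf.float32, [B])
q_t = tf.placeholder(tf.float32, [B, A])
probs_a_t = tf.placeholder(tf.float32, [B, A])  # policy probs at time t

loss, extra = trfl.sarse(q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t)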
github deepmind/trfl: trfl/action_value_ops.py
    A namedtuple with fields:

    * `loss`: a tensor containing the batch of losses, shape `[T, B]`.
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[T, B]`.
        * `td_error`: batch of temporal difference errors, shape `[T, B]`.
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t], [a_tm1, r_t, pcont_t, a_t]], [3, 2], name)

  # SARSALambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):

    # Select head to update and build target.
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    qa_t = indexing_ops.batched_index(q_t, a_t)
    target = sequence_ops.multistep_forward_view(
        r_t, pcont_t, qa_t, lambda_, back_prop=False)
    target = tf.stop_gradient(target)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
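This is the sequence variant, SARSA(lambda): q_tm1 and q_t have shape [T, B, num_actions] and the actions have shape [T, B], so batched_index gathers along the action dimension at every timestep and batch element before sequence_ops.multistep_forward_view mixes the multi-step returns. A usage sketch, assuming the public trfl API:

import tensorflow as tf
import trfl

T, B, A = 5, 4, 3  # sequence length, batch size, number of actions
q_tm1 = tf.placeholder(tf.float32, [T, B, A])
a_tm1 = tf.placeholder(tf.int32, [T, B])
r_t = tf.placeholder(tf.float32, [T, B])
pcont_t = tf.placeholder(tf.float32, [T, B])
q_t = tf.placeholder(tf.float32, [T, B, A])
a_t = tf.placeholder(tf.int32, [T, B])

loss, extra = trfl.sarsa_lambda(
    q_tm1, a_tm1, r_t, pcont_t, q_t, a_t, lambda_=0.9)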
github deepmind/trfl: trfl/action_value_ops.py
    * `extra`: a namedtuple with fields:
        * `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`
        * `td_error`: batch of temporal difference errors, shape `[B]`
        * `best_action`: batch of greedy actions wrt `q_t_selector`, shape `[B]`
  """
  # Rank and compatibility checks.
  base_ops.wrap_rank_shape_assert(
      [[q_tm1, q_t_value, q_t_selector], [a_tm1, r_t, pcont_t]], [2, 1], name)

  # double Q-learning op.
  with tf.name_scope(
      name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector]):

    # Build target and select head to update.
    best_action = tf.argmax(q_t_selector, 1, output_type=tf.int32)
    double_q_bootstrapped = indexing_ops.batched_index(q_t_value, best_action)
    target = tf.stop_gradient(r_t + pcont_t * double_q_bootstrapped)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(
        loss, DoubleQExtra(target, td_error, best_action))
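In double Q-learning the two calls to batched_index play different roles: the first reads the bootstrap value q_t_value[best_action] using the argmax of a separate selector network, decoupling action selection from action evaluation, and the second selects the head q_tm1[a_tm1] being updated. A usage sketch, assuming the public trfl API:

import tensorflow as tf
import trfl

B, A = 4, 3
q_tm1 = tf.placeholder(tf.float32, [B, A])
a_tm1 = tf.placeholder(tf.int32, [B])
r_t = tf.placeholder(tf.float32, [B])
pcont_t = tf.placeholder(tf.float32, [B])
q_t_value = tf.placeholder(tf.float32, [B, A])     # e.g. target network
q_t_selector = tf.placeholder(tf.float32, [B, A])  # e.g. online network

loss, extra = trfl.double_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector)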
github deepmind/trfl: trfl/action_value_ops.py
  if isinstance(
      lambda_, tf.Tensor
  ) and lambda_.get_shape().ndims is not None and lambda_.get_shape().ndims > 0:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t, lambda_]], [2], name)
  else:
    base_ops.wrap_rank_shape_assert([[a_tm1, r_t, pcont_t]], [2], name)

  # QLambda op.
  with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):

    # Build target and select head to update.
    with tf.name_scope("target"):
      state_values = tf.reduce_max(q_t, axis=2)
      target = sequence_ops.multistep_forward_view(
          r_t, pcont_t, state_values, lambda_, back_prop=False)
      target = tf.stop_gradient(target)
    qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)

    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
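This last excerpt is the Q(lambda) op: the greedy state values tf.reduce_max(q_t, axis=2) form the bootstrap for multistep_forward_view, lambda_ may be either a scalar or a [T, B] tensor (hence the rank check at the top of the excerpt), and batched_index once more selects q_tm1[a_tm1]. A usage sketch, assuming the public trfl API:

import tensorflow as tf
import trfl

T, B, A = 5, 4, 3
q_tm1 = tf.placeholder(tf.float32, [T, B, A])
a_tm1 = tf.placeholder(tf.int32, [T, B])
r_t = tf.placeholder(tf.float32, [T, B])
pcont_t = tf.placeholder(tf.float32, [T, B])
q_t = tf.placeholder(tf.float32, [T, B, A])

# lambda_ may be a Python float or a [T, B] tensor.
loss, extra = trfl.qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_=0.9)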