How to use the `trfl.base_ops.assert_rank_and_shape_compatibility` function from the trfl library

To help you get started, we’ve selected a few trfl examples based on popular ways it is used in public projects.

Secure your code as it’s written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github fomorians-oss / pyoneer / pyoneer / rl / agents / advantage_actor_critic_agent_impl.py View on Github external
baseline_scale: scalar or Tensor of shape `[B, T]` containing the baseline loss scale.
            **kwargs: positional arguments (unused)

        Returns:
            the total loss Tensor of shape [].
        """
        del kwargs
        base_ops.assert_rank_and_shape_compatibility([weights], 2)
        sequence_lengths = math_ops.reduce_sum(weights, axis=1)
        total_num = math_ops.reduce_sum(sequence_lengths)

        multi_advantages = []
        self.value_loss = []
        multi_baseline_values = self.value(states, training=True) * array_ops.expand_dims(weights, axis=-1)

        base_ops.assert_rank_and_shape_compatibility(
            [rewards, multi_baseline_values], 3)
        multi_baseline_values = array_ops.unstack(multi_baseline_values, axis=-1)
        num_values = len(multi_baseline_values)

        base_shape = rewards.shape
        decay = self._least_fit(decay, base_shape)
        lambda_ = self._least_fit(lambda_, base_shape)
        baseline_scale = self._least_fit(baseline_scale, base_shape)

        for i in range(num_values):
            pcontinues = decay[..., i] * weights
            lambdas = lambda_[..., i] * weights
            bootstrap_values = indexing_ops.batched_index(
                multi_baseline_values[i], math_ops.cast(sequence_lengths - 1, dtypes.int32))
            baseline_loss, td_lambda = value_ops.td_lambda(
                parray_ops.swap_time_major(multi_baseline_values[i]), 
github fomorians-oss / pyoneer / pyoneer / rl / agents / advantage_actor_critic_agent_impl.py View on Github external
actions: Tensor of `[B, T, ...]` containing actions.
            rewards: Tensor of `[B, T, V]` containing rewards.
            weights: Tensor of shape `[B, T]` containing weights (1. or 0.).
            decay: scalar, 1-D Tensor of shape [V], or Tensor of shape 
                `[B, T]` or `[B, T, V]` containing decays/discounts.
            lambda_: scalar, 1-D Tensor of shape [V], or Tensor of shape 
                `[B, T]` or `[B, T, V]` containing generalized lambda parameter.
            entropy_scale: scalar or Tensor of shape `[B, T]` containing the entropy loss scale.
            baseline_scale: scalar or Tensor of shape `[B, T]` containing the baseline loss scale.
            **kwargs: positional arguments (unused)

        Returns:
            the total loss Tensor of shape [].
        """
        del kwargs
        base_ops.assert_rank_and_shape_compatibility([weights], 2)
        sequence_lengths = math_ops.reduce_sum(weights, axis=1)
        total_num = math_ops.reduce_sum(sequence_lengths)

        multi_advantages = []
        self.value_loss = []
        multi_baseline_values = self.value(states, training=True) * array_ops.expand_dims(weights, axis=-1)

        base_ops.assert_rank_and_shape_compatibility(
            [rewards, multi_baseline_values], 3)
        multi_baseline_values = array_ops.unstack(multi_baseline_values, axis=-1)
        num_values = len(multi_baseline_values)

        base_shape = rewards.shape
        decay = self._least_fit(decay, base_shape)
        lambda_ = self._least_fit(lambda_, base_shape)
        baseline_scale = self._least_fit(baseline_scale, base_shape)
github fomorians-oss / pyoneer / pyoneer / rl / agents / advantage_actor_critic_agent_impl.py View on Github external
Returns:
            the total loss Tensor of shape [].

        Raises:
            ValueError: If tensors are empty or fail the rank and mutual
                compatibility asserts.
        """
        del kwargs
        base_ops.assert_rank_and_shape_compatibility([weights], 2)
        sequence_lengths = math_ops.reduce_sum(weights, axis=1)
        total_num = math_ops.reduce_sum(sequence_lengths)

        baseline_values = array_ops.squeeze(
            self.value(states, training=True), 
            axis=-1) * weights
        base_ops.assert_rank_and_shape_compatibility([rewards, baseline_values], 2)

        pcontinues = decay * weights
        lambda_ = lambda_ * weights
        bootstrap_values = indexing_ops.batched_index(
            baseline_values, math_ops.cast(sequence_lengths - 1, dtypes.int32))

        baseline_loss, td_lambda = value_ops.td_lambda(
            parray_ops.swap_time_major(baseline_values), 
            parray_ops.swap_time_major(rewards), 
            parray_ops.swap_time_major(pcontinues), 
            bootstrap_values, 
            parray_ops.swap_time_major(lambda_))

        advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
        if normalize_advantages:
            advantages = normalization_ops.normalize_by_moments(advantages, weights)
github deepmind / trfl / trfl / retrace_ops.py View on Github external
def check_rank(tensors, ranks):
      """Check each tensor's static rank, logging (not raising) when unknown.

      For every tensor whose static shape is available at graph-construction
      time, delegates the check to
      ``base_ops.assert_rank_and_shape_compatibility``. Tensors with no
      usable static shape are only reported via ``tf.logging.error`` — no
      exception is raised for them here, so graph construction can proceed.

      Args:
        tensors: iterable of tensor-like objects; assumed to expose
          ``get_shape()`` and ``name`` (TODO confirm — not visible here).
        ranks: iterable of expected ranks, paired positionally with
          ``tensors``.
      """
      for i, (tensor, rank) in enumerate(zip(tensors, ranks)):
        # NOTE(review): relies on the truthiness of the static shape object —
        # presumably falsy when the rank is fully unknown; confirm against
        # tf.TensorShape semantics.
        if tensor.get_shape():
          base_ops.assert_rank_and_shape_compatibility([tensor], rank)
        else:
          # Rank unknown at construction time: report the problem (parameter
          # index is 1-based in the message) instead of failing hard.
          tf.logging.error(
              'Tensor "%s", which was offered as Retrace parameter %d, has '
              'no rank at construction time, so Retrace can\'t verify that '
              'it has the necessary rank of %d', tensor.name, i + 1, rank)