How to use the trfl.vtrace_ops.vtrace_from_importance_weights function in trfl

To help you get started, we’ve selected a few examples based on popular ways trfl is used in public projects.

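Before looking at the project examples, here is a minimal, self-contained sketch of the call itself (not taken from either project). The shapes, the placeholder random data and the 0.99 discount are assumptions for illustration only; trfl expects time-major [T, B] tensors and returns a namedtuple whose vs field holds the V-trace value targets and whose pg_advantages field holds the policy-gradient advantages.

import numpy as np
import tensorflow as tf
from trfl import vtrace_ops

T, B = 5, 2  # assumed: 5 timesteps, batch of 2, time-major layout

log_rhos = tf.constant(np.random.randn(T, B), dtype=tf.float32)      # log(pi / mu) per step
discounts = tf.constant(0.99 * np.ones((T, B)), dtype=tf.float32)    # per-step discounts
rewards = tf.constant(np.random.randn(T, B), dtype=tf.float32)
values = tf.constant(np.random.randn(T, B), dtype=tf.float32)        # baseline V(s_t)
bootstrap_value = tf.constant(np.random.randn(B), dtype=tf.float32)  # V(s_T), shape [B]

vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
    log_rhos, discounts, rewards, values, bootstrap_value)

# vtrace_returns.vs            -> V-trace value targets, shape [T, B]
# vtrace_returns.pg_advantages -> policy-gradient advantages, shape [T, B]

Both project examples below follow this pattern: build time-major log importance weights, discounts, rewards and baseline values, call the op, then plug pg_advantages into their respective policy losses.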

Example from fomorians-oss/pyoneer: pyoneer/rl/agents/v_trace_advantage_actor_critic_agent_impl.py (view on GitHub)

        # Excerpt starts partway through the agent's loss computation.
        total_num = math_ops.reduce_sum(sequence_length)

        # Evaluate the learner policy and the behavioural policy on the same
        # states, and compute per-timestep baseline values V(s_t).
        policy = self.policy(states, training=True)
        behavioral_policy = self.behavioral_policy(states)
        baseline_values = array_ops.squeeze(
            self.value(states, training=True),
            axis=-1) * weights
        # Bootstrap from the value at the last valid step of each sequence.
        bootstrap_values = indexing_ops.batched_index(
            baseline_values, math_ops.cast(sequence_length - 1, dtypes.int32))
        # trfl's V-trace op expects time-major [T, B] inputs.
        baseline_values = parray_ops.swap_time_major(baseline_values)

        pcontinues = parray_ops.swap_time_major(decay * weights)
        # Log importance weights log(pi(a|s) / mu(a|s)).
        log_prob = policy.log_prob(actions)
        log_rhos = parray_ops.swap_time_major(log_prob) - parray_ops.swap_time_major(
            behavioral_policy.log_prob(actions))
        vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
            log_rhos,
            pcontinues,
            parray_ops.swap_time_major(rewards),
            baseline_values,
            bootstrap_values)

        # Use the V-trace policy-gradient advantages, swapped back to
        # batch-major, and stop gradients so only log_prob is differentiated.
        advantages = parray_ops.swap_time_major(vtrace_returns.pg_advantages)
        if normalize_advantages:
            advantages = normalization_ops.normalize_by_moments(advantages, weights)
        advantages = gen_array_ops.stop_gradient(advantages)

        policy_gradient_loss = advantages * -log_prob
        self.policy_gradient_loss = losses_impl.compute_weighted_loss(
            policy_gradient_loss,
            weights=weights)
Example from fomorians-oss/pyoneer: pyoneer/rl/agents/v_trace_proximal_policy_optimization_agent_impl.py (view on GitHub)

    def compute_loss(self, rollouts, decay=.999, lambda_=1., entropy_scale=.2, baseline_scale=1., ratio_epsilon=.2):
        # Learner policy, behavioural policy and baseline values for the rollout
        # states; trfl's V-trace op expects time-major [T, B] inputs.
        policy = self.policy(rollouts.states, training=True)
        behavioral_policy = self.behavioral_policy(rollouts.states)
        baseline_values = parray_ops.swap_time_major(
            array_ops.squeeze(self.value(rollouts.states, training=True), axis=-1))

        pcontinues = parray_ops.swap_time_major(decay * rollouts.weights)
        # Bootstrap from the value estimate at the final time step.
        bootstrap_values = baseline_values[-1, :]

        # Log importance weights log(pi / mu); gradients are stopped through the
        # behavioural log-probabilities.
        log_prob = policy.log_prob(rollouts.actions)
        behavioral_log_prob = behavioral_policy.log_prob(rollouts.actions)
        log_rhos = parray_ops.swap_time_major(log_prob - gen_array_ops.stop_gradient(behavioral_log_prob))
        vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
            log_rhos,
            pcontinues,
            parray_ops.swap_time_major(rollouts.rewards),
            baseline_values,
            bootstrap_values)

        # Normalised V-trace advantages, treated as constants in the policy loss.
        advantages = parray_ops.swap_time_major(vtrace_returns.pg_advantages)
        advantages = normalization_ops.weighted_moments_normalize(advantages, rollouts.weights)
        advantages = gen_array_ops.stop_gradient(advantages)

        # PPO-style clipped surrogate objective on the importance ratios.
        ratio = parray_ops.swap_time_major(gen_math_ops.exp(log_rhos))
        clipped_ratio = clip_ops.clip_by_value(ratio, 1. - ratio_epsilon, 1. + ratio_epsilon)

        self.policy_gradient_loss = -losses_impl.compute_weighted_loss(
            gen_math_ops.minimum(advantages * ratio, advantages * clipped_ratio),
            weights=rollouts.weights)