How to use the trfl.td_lambda function in trfl

To help you get started, we’ve selected a few trfl examples, based on popular ways it is used in public projects.

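Before the project excerpt below, here is a minimal, self-contained sketch of the trfl.td_lambda call itself. The shapes and input values are illustrative assumptions (time-major [T, B] tensors, lambda_=0.95, constant rewards and discounts), not taken from any of the projects shown; trfl targets TensorFlow 1.x graph mode.

    import numpy as np
    import tensorflow as tf  # TF1-style graph mode, as required by trfl
    import trfl

    T, B = 4, 2  # sequence length and batch size (time-major layout)

    # Illustrative inputs; in practice these come from a rollout.
    state_values = tf.constant(np.zeros([T, B]), dtype=tf.float32)        # V(s_t) estimates
    rewards = tf.constant(np.ones([T, B]), dtype=tf.float32)              # r_t
    pcontinues = tf.constant(0.99 * np.ones([T, B]), dtype=tf.float32)    # gamma * (1 - done)
    bootstrap_value = tf.constant(np.zeros([B]), dtype=tf.float32)        # V(s_T) used to bootstrap

    loss_output = trfl.td_lambda(
        state_values=state_values,
        rewards=rewards,
        pcontinues=pcontinues,
        bootstrap_value=bootstrap_value,
        lambda_=0.95)

    # loss_output is a LossOutput namedtuple; loss has shape [B]
    # (squared TD(lambda) errors summed over the time dimension).
    value_loss = tf.reduce_mean(loss_output.loss)

    with tf.Session() as sess:
        print(sess.run(value_loss))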

github vladfi1 / DeepSmash / dsmash / rllib / imitation_trainer.py (View on GitHub)
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      tf.get_variable_scope().name)

    # actual loss computation
    imitation_loss = -tf.reduce_mean(actions_logp)
    
    tm_values = self.model.values
    baseline_values = tm_values[:-1]  # V(s_t) for every step except the last
    
    if config.get("soft_horizon"):
      discounts = config["gamma"]
    else:
      discounts = tf.to_float(~dones[:-1]) * config["gamma"]
    
    td_lambda = trfl.td_lambda(
        state_values=baseline_values,
        rewards=rewards[:-1],
        pcontinues=discounts,
        bootstrap_value=tm_values[-1],
        lambda_=config.get("lambda", 1.))

    # td_lambda.loss has shape [B]: squared TD(lambda) errors summed over the time dimension
    vf_loss = tf.reduce_mean(td_lambda.loss) / T  # T (rollout length) is defined earlier in the file
    
    self.total_loss = imitation_loss + self.config["vf_loss_coeff"] * vf_loss

    # Initialize TFPolicyGraph
    loss_in = [
      (SampleBatch.ACTIONS, actions),
      (SampleBatch.DONES, dones),
      # (BEHAVIOUR_LOGITS, behaviour_logits),