How to use the tianshou.core.policy.stochastic.OnehotCategorical class in tianshou

To help you get started, we’ve selected a few tianshou examples, based on popular ways it is used in public projects.
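In the excerpts below, OnehotCategorical is constructed either from a policy callable that returns a pair (action_logits, value_head) or directly from a logits tensor, together with the observation placeholder it depends on. Here is a minimal, hedged sketch of the callable-based pattern, assuming a Gym CartPole-style environment, the TF 1.x API used throughout these excerpts, and the import alias `policy` for the module named in the title:

import tensorflow as tf
import gym
import tianshou.core.policy.stochastic as policy

env = gym.make('CartPole-v0')
observation_dim = env.observation_space.shape
action_dim = env.action_space.n

observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

def my_policy():
    # two tanh hidden layers, then unnormalized logits over the discrete actions;
    # the second return value is an optional value head (None when unused)
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_logits = tf.layers.dense(net, action_dim, activation=None)
    return action_logits, None

# weight_update=0 keeps a network_old copy that is refreshed via pi.sync_weights()
pi = policy.OnehotCategorical(my_policy, observation_placeholder=observation_ph,
                              weight_update=0)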


github thu-ml / tianshou / tianshou / core / policy / stochastic.py
        """
        if self.weight_update_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.weight_update_ops)

    def sync_weights(self):
        """
        Sync the weights of network_old by directly copying the weights of network.
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)


OnehotDiscrete = OnehotCategorical


class Normal(StochasticPolicy):
    """
        The :class:`Normal' class is the Normal policy

        :param mean:
        :param std:
        :param group_ndims
        :param observation_placeholder
    """
    def __init__(self,
                 policy_callable,
                 observation_placeholder,
                 weight_update=1,
                 group_ndims=1,
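sync_weights() above runs against tf.get_default_session(), so it has to be called while a session is active. A hedged usage sketch, assuming a policy object `pi` constructed as in the examples further down:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    pi.sync_weights()  # hard-copies the current network weights into network_old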
github thu-ml / tianshou / tianshou / core / policy / stochastic.py
            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]

            if weight_update == 0:
                # weight_update_ops then performs a full hard copy into network_old
                self.weight_update_ops = self.sync_weights_ops
            elif 0 < weight_update < 1:  # soft (interpolated) target update, as in DDPG
                pass
            else:
                # hard update every `weight_update` interactions, as in DQN target networks
                self.interaction_count = 0
                import math
                self.weight_update = math.ceil(weight_update)

        tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank, e.g. RNN
        self._n_categories = self._logits.get_shape()[-1].value

        super(OnehotCategorical, self).__init__(
            act_dtype=tf.int32,
            param_dtype=self._logits.dtype,
            is_continuous=False,
            observation_placeholder=observation_placeholder,
            group_ndims=group_ndims,
            **kwargs)
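The sync_weights_ops above are plain tf.assign ops pairing each network_old variable with its counterpart in network. A self-contained, hedged illustration of the same mechanism in bare TF 1.x (the scope names here are arbitrary, not tianshou's):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 4))
with tf.variable_scope('net'):
    y = tf.layers.dense(x, 8)
with tf.variable_scope('net_old'):
    y_old = tf.layers.dense(x, 8)

# 'net/' (with the slash) avoids also matching the 'net_old' scope
net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net/')
old_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old/')
# one assign op per variable pair, exactly as in sync_weights_ops above
sync_ops = [tf.assign(v_old, v) for v_old, v in zip(old_vars, net_vars)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sync_ops)  # hard copy: the 'net_old' weights now equal 'net'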
github thu-ml / tianshou / examples / actor_critic_fail_cartpole.py
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_logits = tf.layers.dense(net, action_dim, activation=None)

        return action_logits, None

    def my_critic():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
        value = tf.layers.dense(net, 1, activation=None)

        return None, value

    ### 2. build policy, critic, loss, optimizer
    print('actor and critic will share the first two layers in this case, and the third layer will cause error')
    actor = policy.OnehotCategorical(my_actor, observation_placeholder=observation_ph, weight_update=1)
    critic = value_function.StateValue(my_critic, observation_placeholder=observation_ph)



    actor_loss = losses.vanilla_policy_gradient(actor)
    critic_loss = losses.value_mse(critic)
    total_loss = actor_loss + critic_loss

    optimizer = tf.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(total_loss, var_list=actor.trainable_variables)

    ### 3. define data collection
    training_data = Batch(env, actor, advantage_estimation.full_return)

    ### 4. start training
    config = tf.ConfigProto()
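The printed message above is the point of this deliberately failing example: the actor and critic callables share their first two layers, but their third layers differ and trigger an error. The comment in ppo_cartpole_gym.py further down notes that callables are expected to return a value head precisely so that a single callable can serve both networks. A hedged sketch of such a shared-trunk callable (not part of this example file; whether the two wrappers actually reuse the same variables depends on tianshou internals not shown here):

def my_shared_network():
    # one trunk, two heads: the same callable could be passed to both
    # policy.OnehotCategorical and value_function.StateValue
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_logits = tf.layers.dense(net, action_dim, activation=None)
    value = tf.layers.dense(net, 1, activation=None)
    return action_logits, value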
github thu-ml / tianshou / examples / ppo_example.py
    action_dim = env.action_space.n

    clip_param = 0.2
    num_batches = 2

    # 1. build network with pure tf
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input

    with tf.variable_scope('pi'):
        action_logits = policy_net(observation, action_dim, 'pi')
        train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
    with tf.variable_scope('pi_old'):
        action_logits_old = policy_net(observation, action_dim, 'pi_old')

    # 2. build losses, optimizers
    pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
    # for continuous action space, you may need to change an environment to run
    pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation)

    action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
    advantage = tf.placeholder(dtype=tf.float32, shape=[None]) # advantage values used in the gradient estimate

    ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict

    total_loss = ppo_loss_clip
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list)

    # 3. define data collection
    training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
                                                             # ShihongSong: Replay(), see dqn_example.py
    # maybe a dict to manage the elements to be collected
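policy_net is called above but not defined in this excerpt. Below is a hedged sketch of a helper with the expected signature (observation tensor, action dimension, and scope name in; action logits out); it is a hypothetical stand-in, not the project's own definition:

def policy_net(observation, action_dim, scope):
    # the enclosing tf.variable_scope ('pi' / 'pi_old') already separates the
    # two copies, so `scope` is only carried through here
    net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_logits = tf.layers.dense(net, action_dim, activation=None)
    return action_logits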
github thu-ml / tianshou / examples / ppo_cartpole_gym.py
    def my_policy():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_logits = tf.layers.dense(net, action_dim, activation=None)

        return action_logits, None  # None value head

    # TODO: current implementation of passing function or overriding function has to return a value head
    # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
    # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
    # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
    # not based on passing function or overriding function.

    ### 2. build policy, loss, optimizer
    pi = policy.OnehotCategorical(my_policy, observation_placeholder=observation_ph, weight_update=0)

    ppo_loss_clip = losses.ppo_clip(pi, clip_param)

    total_loss = ppo_loss_clip
    optimizer = tf.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

    ### 3. define data collection
    training_data = Batch(env, pi, advantage_estimation.full_return)

    ### 4. start training
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
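Because this policy is built with weight_update=0, the stochastic.py excerpt above wires weight_update_ops to sync_weights_ops, so the internal network_old (which presumably supplies the old-policy probabilities for the clipped ratio) is refreshed by an explicit copy. A hedged sketch of the step that would typically precede each update inside the session block above:

        pi.sync_weights()  # copy the current policy weights into network_old
        # ... then collect data with training_data and run train_op ...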
github thu-ml / tianshou / examples / actor_critic_separate_cartpole.py
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

    def my_network():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_logits = tf.layers.dense(net, action_dim, activation=None)

        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
        value = tf.layers.dense(net, 1, activation=None)

        return action_logits, value

    ### 2. build policy, critic, loss, optimizer
    actor = policy.OnehotCategorical(my_network, observation_placeholder=observation_ph, weight_update=1)
    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)

    actor_loss = losses.REINFORCE(actor)
    critic_loss = losses.value_mse(critic)

    actor_optimizer = tf.train.AdamOptimizer(1e-4)
    actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables)

    critic_optimizer = tf.train.RMSPropOptimizer(1e-4)
    critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)

    ### 3. define data collection
    data_collector = Batch(env, actor,
                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
                           [actor, critic])