How to use tf2rl - 10 common examples

To help you get started, we've selected a few tf2rl examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: keiohta / tf2rl / tests / networks / test_spectral_norm_dense.py (View on GitHub)
def test_sn_dense(self):
        """Run the generic layer test harness on a 3-unit ``SNDense`` layer.

        The layer is constructed with ``units=3`` and fed inputs of shape
        ``(3, 2)``. ``SNDense`` is also registered under ``custom_objects`` --
        presumably so the harness can rebuild the custom layer by name
        (confirm against ``layer_test``).
        """
        layer_kwargs = {'units': 3}
        name_to_class = {'SNDense': SNDense}
        layer_test(SNDense,
                   kwargs=layer_kwargs,
                   input_shape=(3, 2),
                   custom_objects=name_to_class)
GitHub: keiohta / tf2rl / tests / networks / test_spectral_norm_dense.py (View on GitHub)
def test_sn_dense(self):
        """Check that ``SNDense(units=3)`` passes ``layer_test`` on (3, 2) inputs.

        ``custom_objects`` maps the string ``'SNDense'`` to the class itself;
        presumably the harness needs this to deserialize the custom layer --
        confirm against ``layer_test``.
        """
        layer_test(
            SNDense,
            input_shape=(3, 2),
            kwargs=dict(units=3),
            custom_objects=dict(SNDense=SNDense),
        )
GitHub: keiohta / tf2rl / tests / algos / test_apex.py (View on GitHub)
def set_weights_fn(policy, weights):
            """Overwrite *policy*'s actor, critic and target-critic weights.

            Args:
                policy: object exposing ``actor``, ``critic`` and
                    ``critic_target`` attributes, each of which has a
                    ``.weights`` attribute.
                weights: a 3-item sequence unpacked, in order, as
                    ``(actor, critic, critic_target)`` weights.

            Every update passes ``tau=1.`` to ``update_target_variables`` --
            presumably a full copy rather than a soft (Polyak) blend; confirm
            against that helper. Networks are updated in the order actor,
            critic, target critic.
            """
            actor_src, critic_src, critic_target_src = weights
            for attr, src in (("actor", actor_src),
                              ("critic", critic_src),
                              ("critic_target", critic_target_src)):
                update_target_variables(getattr(policy, attr).weights, src, tau=1.)
GitHub: keiohta / tf2rl / tests / algos / test_apex.py (View on GitHub)
def set_weights_fn(policy, weights):
            """Copy an ``(actor, critic, critic_target)`` weights triple into *policy*.

            Each network's ``.weights`` is updated through
            ``update_target_variables`` with ``tau=1.`` (presumably a hard copy --
            confirm against the helper). The order is actor, then critic, then
            target critic.
            """
            actor_w, critic_w, target_critic_w = weights

            def _copy_into(network, source):
                # Single helper so all three updates use identical arguments.
                update_target_variables(network.weights, source, tau=1.)

            _copy_into(policy.actor, actor_w)
            _copy_into(policy.critic, critic_w)
            _copy_into(policy.critic_target, target_critic_w)
GitHub: keiohta / tf2rl / tests / algos / test_ddpg.py (View on GitHub)
def setUpClass(cls):
        """Build a single DDPG agent shared by all tests in this class.

        Runs the parent ``setUpClass`` first. It then sizes the agent from
        ``cls.continuous_env``: the state shape comes from
        ``observation_space.shape`` and the action dimension from
        ``action_space.low.size``. It also reads ``cls.batch_size``. Both
        class attributes are presumably provided by the base test class --
        confirm.
        """
        super().setUpClass()
        env = cls.continuous_env
        obs_shape = env.observation_space.shape
        act_dim = env.action_space.low.size
        cls.agent = DDPG(
            state_shape=obs_shape,
            action_dim=act_dim,
            batch_size=cls.batch_size,
            sigma=0.5,  # make the noise bigger so the agent is easier to test
            gpu=-1,     # presumably forces CPU execution -- confirm against DDPG
        )
GitHub: keiohta / tf2rl / tests / algos / test_apex.py (View on GitHub)
def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
            """Build a small DDPG agent for the Ape-X test.

            The agent's state shape is ``env.observation_space.shape`` and its
            action dimension is ``env.action_space.high.size``. It always
            uses ``n_warmup=500``.

            ``name``, ``memory_capacity``, ``gpu`` and any extra positional or
            keyword arguments are accepted but not used. In particular, the
            agent is always built with ``gpu=-1`` regardless of the ``gpu``
            argument. Presumably the signature only needs to match what the
            Ape-X runner passes -- confirm before relying on those parameters.
            """
            obs_shape = env.observation_space.shape
            n_actions = env.action_space.high.size
            return DDPG(state_shape=obs_shape, action_dim=n_actions,
                        n_warmup=500, gpu=-1)
GitHub: keiohta / tf2rl / tests / misc / test_huber_loss.py (View on GitHub)
def test_huber_loss(self):
        """Smoke-test ``huber_loss`` on two constant-error inputs.

        Each case records an expected value, but nothing is asserted. The
        computed loss is only printed, so this test fails only if
        ``huber_loss`` raises.
        """
        # NOTE(review): the original equality checks were left commented out
        # (``self.assertEqual(expected, loss.numpy())``). They would not work
        # as written anyway: ``assertEqual`` on multi-element NumPy arrays
        # hits the ambiguous-truth-value error; ``np.testing.assert_allclose``
        # is the usual tool. Also, the second expectation (10.0 for an error
        # of 10) does not match the textbook Huber loss with delta=1, which
        # gives 9.5. Confirm the signature and delta of ``huber_loss`` before
        # re-enabling any assertion.
        cases = (
            # (y_target, y_pred, expected)
            (np.array([0., 0.]), np.array([1., 1.]), np.array([0.5, 0.5])),
            (np.array([0., 0.]), np.array([10., 10.]), np.array([10., 10.])),
        )
        for y_target, y_pred, _expected in cases:
            loss = huber_loss(y_target, y_pred)
            print(loss)
GitHub: keiohta / tf2rl / tests / experiments / test_utils.py (View on GitHub)
def setUpClass(cls):
        """Create the shared fixtures used by the utility tests.

        Sets on the class:
            env: a ``CartPole-v0`` gym environment.
            replay_buffer: the buffer returned by ``get_replay_buffer`` for a
                ``DQN`` policy built on ``env`` with ``memory_capacity=2**4``
                (16).
            output_dir: a ``tests`` directory next to this file, created if
                it does not exist yet.
        """
        cls.env = gym.make("CartPole-v0")
        policy = DQN(
            state_shape=cls.env.observation_space.shape,
            action_dim=cls.env.action_space.n,
            memory_capacity=2**4)
        cls.replay_buffer = get_replay_buffer(policy, cls.env)
        cls.output_dir = os.path.join(os.path.dirname(__file__), "tests")
        # exist_ok=True replaces a separate isdir() check. The directory is
        # created in one call, and there is no failure if another process
        # creates it between the check and the makedirs (a TOCTOU race).
        # A non-directory already at this path still raises, as before.
        os.makedirs(cls.output_dir, exist_ok=True)
GitHub: keiohta / tf2rl / tests / algos / test_dqn.py (View on GitHub)
def setUpClass(cls):
        """Build the shared categorical-DQN agent for this test class.

        Calls the parent ``setUpClass`` first. It then creates a ``DQN`` sized
        from ``cls.discrete_env`` and ``cls.batch_size``, with
        ``enable_categorical_dqn=True``. ``epsilon=0.`` presumably disables
        random exploration so action selection is deterministic, and
        ``gpu=-1`` presumably forces CPU -- confirm both against ``DQN``.
        """
        super().setUpClass()
        env = cls.discrete_env
        variant_flags = {'enable_categorical_dqn': True}
        cls.agent = DQN(
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            batch_size=cls.batch_size,
            epsilon=0.,
            gpu=-1,
            **variant_flags)
GitHub: keiohta / tf2rl / tests / algos / test_dqn.py (View on GitHub)
def setUpClass(cls):
        """Build the shared double + dueling DQN agent for this test class.

        Calls the parent ``setUpClass`` first. It then creates a ``DQN`` sized
        from ``cls.discrete_env`` and ``cls.batch_size``, with both
        ``enable_double_dqn`` and ``enable_dueling_dqn`` turned on.
        ``epsilon=0.`` presumably makes action selection greedy and ``gpu=-1``
        presumably forces CPU -- confirm both against ``DQN``.
        """
        super().setUpClass()
        discrete_env = cls.discrete_env
        agent_config = dict(
            state_shape=discrete_env.observation_space.shape,
            action_dim=discrete_env.action_space.n,
            batch_size=cls.batch_size,
            enable_double_dqn=True,
            enable_dueling_dqn=True,
            epsilon=0.,
            gpu=-1,
        )
        cls.agent = DQN(**agent_config)