How to use the tf2rl.envs.utils.is_discrete function in tf2rl

To help you get started, we’ve selected a few tf2rl examples showing how is_discrete is used in public projects. is_discrete takes a Gym action space and reports whether it is discrete (as in CartPole-v0) or continuous (as in Pendulum-v0).
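
A minimal sketch of calling it directly (assuming gym and tf2rl are installed; the environments below are the ones used in the project’s own tests):

import gym
from tf2rl.envs.utils import is_discrete

# CartPole-v0 exposes a discrete action space, Pendulum-v0 a continuous one.
print(is_discrete(gym.make('CartPole-v0').action_space))  # True
print(is_discrete(gym.make('Pendulum-v0').action_space))  # False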


github keiohta / tf2rl / tests / envs / test_utils.py
def test_is_discrete(self):
        discrete_space = gym.make('CartPole-v0').action_space
        continuous_space = gym.make('Pendulum-v0').action_space
        self.assertTrue(is_discrete(discrete_space))
        self.assertFalse(is_discrete(continuous_space))
github keiohta / tf2rl / tf2rl / experiments / on_policy_trainer.py
def __call__(self):
        total_steps = 0
        n_episode = 0

        # TODO: clean codes
        # Prepare buffer
        self.replay_buffer = get_replay_buffer(
            self._policy, self._env)
        kwargs_local_buf = get_default_rb_dict(
            size=self._episode_max_steps, env=self._env)
        kwargs_local_buf["env_dict"]["logp"] = {}
        kwargs_local_buf["env_dict"]["val"] = {}
        if is_discrete(self._env.action_space):
            kwargs_local_buf["env_dict"]["act"]["dtype"] = np.int32
        self.local_buffer = ReplayBuffer(**kwargs_local_buf)

        tf.summary.experimental.set_step(total_steps)
        while total_steps < self._max_steps:
            # Collect samples
            n_episode, total_rewards = self._collect_sample(n_episode, total_steps)
            total_steps += self._policy.horizon
            tf.summary.experimental.set_step(total_steps)

            if len(total_rewards) > 0:
                avg_training_return = sum(total_rewards) / len(total_rewards)
                tf.summary.scalar(
                    name="Common/training_return", data=avg_training_return)

            # Train actor critic
github keiohta / tf2rl / examples / run_ppo_pendulum.py
parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[32, 32],
        critic_units=[32, 32],
        discount=0.9,
        horizon=args.horizon,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
github keiohta / tf2rl / examples / run_vpg.py
parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.add_argument('--normalize-adv', action='store_true')
    parser.add_argument('--enable-gae', action='store_true')
    parser.set_defaults(test_interval=5000)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=1000)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = VPG(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[32, 32],
        critic_units=[32, 32],
        discount=0.9,
        horizon=args.horizon,
        fix_std=True,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
github keiohta / tf2rl / examples / run_ppo.py
default="Pendulum-v0")
    parser.set_defaults(test_interval=20480)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=2048)
    parser.set_defaults(batch_size=64)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[64, 64],
        critic_units=[64, 64],
        n_epoch=10,
        n_epoch_critic=10,
        lr_actor=3e-4,
        lr_critic=3e-4,
        discount=0.99,
        lam=0.95,
        horizon=args.horizon,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
github keiohta / tf2rl / tf2rl / misc / get_replay_buffer.py
kwargs = get_default_rb_dict(policy.memory_capacity, env)

    if size is not None:
        kwargs["size"] = size

    # on-policy policy
    if not issubclass(type(policy), OffPolicyAgent):
        kwargs["size"] = policy.horizon
        kwargs["env_dict"].pop("next_obs")
        kwargs["env_dict"].pop("rew")
        # TODO: Remove done. Currently cannot remove because of cpprb implementation
        # kwargs["env_dict"].pop("done")
        kwargs["env_dict"]["logp"] = {}
        kwargs["env_dict"]["ret"] = {}
        kwargs["env_dict"]["adv"] = {}
        if is_discrete(env.action_space):
            kwargs["env_dict"]["act"]["dtype"] = np.int32
        return ReplayBuffer(**kwargs)

    # N-step prioritized
    if use_prioritized_rb and use_nstep_rb:
        kwargs["Nstep"] = {"size": n_step,
                           "gamma": policy.discount,
                           "rew": "rew",
                           "next": "next_obs"}
        return PrioritizedReplayBuffer(**kwargs)

    if len(obs_shape) == 3:
        kwargs["env_dict"]["obs"]["dtype"] = np.ubyte
        kwargs["env_dict"]["next_obs"]["dtype"] = np.ubyte

    # prioritized