How to use the tf2rl.algos.ddpg.DDPG class in tf2rl

To help you get started, we’ve selected a few tf2rl examples based on how it is commonly used in public projects.

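Before the project-specific snippets below, here is a minimal, hedged sketch of the typical pattern: construct DDPG from the environment's observation and action spaces and hand it to a Trainer. The environment name (Pendulum-v1) and the use of the parsers' default arguments are illustrative assumptions, not part of the excerpts that follow.

import gym

from tf2rl.algos.ddpg import DDPG
from tf2rl.experiments.trainer import Trainer

# Minimal sketch: any continuous-control Gym environment should work here;
# Pendulum-v1 is only an example choice.
parser = Trainer.get_argument()
parser = DDPG.get_argument(parser)
args = parser.parse_args()

env = gym.make("Pendulum-v1")
test_env = gym.make("Pendulum-v1")
policy = DDPG(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.high.size,
    max_action=env.action_space.high[0],
    gpu=args.gpu,  # -1 runs on CPU
    batch_size=args.batch_size,
    n_warmup=args.n_warmup)
trainer = Trainer(policy, env, args, test_env=test_env)
trainer()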

keiohta/tf2rl · tests/algos/test_ddpg.py (view on GitHub)
def setUpClass(cls):
    super().setUpClass()
    cls.agent = DDPG(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        batch_size=cls.batch_size,
        sigma=0.5,  # use larger exploration noise to make the test easier
        gpu=-1)
keiohta/tf2rl · tests/algos/test_apex.py (view on GitHub)
def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    return DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        n_warmup=500,
        gpu=-1)
keiohta/tf2rl · examples/run_vail_ddpg.py (view on GitHub)
if __name__ == '__main__':
    parser = IRLTrainer.get_argument()
    parser = VAIL.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolReacher-v1")
    args = parser.parse_args()

    if args.expert_path_dir is None:
        print("Plaese generate demonstrations first")
        print("python examples/run_sac.py --env-name=RoboschoolReacher-v1 --save-test-path --test-interval=50000")
        exit()

    units = [400, 300]

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        max_action=env.action_space.high[0],
        gpu=args.gpu,
        actor_units=units,
        critic_units=units,
        n_warmup=10000,
        batch_size=100)
    irl = VAIL(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        units=units,
        enable_sn=args.enable_sn,
        batch_size=32,
        gpu=args.gpu)
    expert_trajs = restore_latest_n_traj(
keiohta/tf2rl · examples/run_ddpg.py (view on GitHub)
import gym

from tf2rl.algos.ddpg import DDPG
from tf2rl.experiments.trainer import Trainer


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DDPG.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolAnt-v1")
    parser.set_defaults(batch_size=100)
    parser.set_defaults(n_warmup=10000)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        max_action=env.action_space.high[0],
        batch_size=args.batch_size,
        n_warmup=args.n_warmup)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
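To launch the full script from the repository root, an invocation along these lines should work; the flag spellings follow the argument parsers shown above, and --gpu=-1 forces CPU:

python examples/run_ddpg.py --env-name=RoboschoolAnt-v1 --gpu=-1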
keiohta/tf2rl · examples/run_apex_shmem_ddpg.py (view on GitHub)
def make_policy(env, name, tf2rl=False):
    dim_state, dim_action = env.observation_space.shape[0], env.action_space.shape[0]
    max_action = env.action_space.high

    with tf.device("/gpu:0"):
        if name == "DDPG":
            if tf2rl:
                policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                    max_action=max_action[0], max_grad=1.)
                saved_policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                          max_action=max_action[0], max_grad=1.)
            else:
                policy = DDPG.DDPG(dim_state, dim_action, max_action, training=True)
                saved_policy = DDPG.DDPG(dim_state, dim_action, max_action, training=False)
        elif name == "SAC":
            if tf2rl:
                policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
                saved_policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
            else:
                policy = SAC.SAC(dim_state, dim_action, max_action, training=True)
                saved_policy = SAC.SAC(dim_state, dim_action, max_action, training=False)
        else:
            raise ValueError("invalid policy")

    return policy, saved_policy