How to use the tf2rl.algos.sac.SAC class in tf2rl

To help you get started, we’ve selected a few tf2rl examples based on popular ways the SAC class is used in public projects.
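
As a quick orientation before the project snippets below, here is a minimal sketch of the typical pattern: SAC is constructed from an environment's observation and action spaces. The environment name and settings are illustrative, not taken from the projects that follow.

import gym

from tf2rl.algos.sac import SAC

# Any continuous-action (Box) Gym env works the same way; Pendulum is just an example.
env = gym.make("Pendulum-v0")
policy = SAC(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.high.size,
    max_action=env.action_space.high[0],
    gpu=-1)  # -1 = run on CPU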


github keiohta/tf2rl: tests/algos/test_sac.py (view on GitHub)
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # One shared CPU-only agent (gpu=-1); continuous_env and batch_size are
        # defined on the common base test class.
        cls.agent = SAC(
            state_shape=cls.continuous_env.observation_space.shape,
            action_dim=cls.continuous_env.action_space.low.size,
            batch_size=cls.batch_size,
            gpu=-1)
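
The snippet above is the class-level setup of a unittest suite: one agent is built once, on CPU, with shapes taken from a shared test environment. A self-contained sketch of the same pattern could look like the following; the environment name and the shape check are illustrative, and get_action is assumed to behave as tf2rl's usual action-sampling method.

import unittest

import gym

from tf2rl.algos.sac import SAC


class TestSAC(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Build one CPU-only agent shared by all tests in this class.
        cls.env = gym.make("Pendulum-v0")
        cls.agent = SAC(
            state_shape=cls.env.observation_space.shape,
            action_dim=cls.env.action_space.low.size,
            batch_size=32,
            gpu=-1)

    def test_get_action_shape(self):
        obs = self.env.reset()
        action = self.agent.get_action(obs)
        self.assertEqual(action.shape, self.env.action_space.shape)


if __name__ == '__main__':
    unittest.main()
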
github keiohta/tf2rl: examples/run_apex_shmem_ddpg.py (view on GitHub)
    # Fragment of a policy-factory helper: env, name, and the tf2rl flag are
    # defined by the enclosing script, and SAC_tf2rl / DDPG_tf2rl are the
    # tf2rl implementations imported under aliases (SAC_tf2rl is tf2rl.algos.sac.SAC).
    dim_state, dim_action = env.observation_space.shape[0], env.action_space.shape[0]
    max_action = env.action_space.high

    with tf.device("/gpu:0"):
        if name == "DDPG":
            if tf2rl:
                policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                    max_action=max_action[0], max_grad=1.)
                saved_policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                          max_action=max_action[0], max_grad=1.)
            else:
                policy = DDPG.DDPG(dim_state, dim_action, max_action, training=True)
                saved_policy = DDPG.DDPG(dim_state, dim_action, max_action, training=False)
        elif name == "SAC":
            if tf2rl:
                policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
                saved_policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
            else:
                policy = SAC.SAC(dim_state, dim_action, max_action, training=True)
                saved_policy = SAC.SAC(dim_state, dim_action, max_action, training=False)
        else:
            raise ValueError("invalid policy")

    return policy, saved_policy
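
Stripped down to just the SAC branch, the selection logic above reduces to the sketch below. The environment and device string are placeholders; in an Ape-X style setup one copy acts in the environment while the other holds weights periodically synced from the learner.

import gym
import tensorflow as tf

from tf2rl.algos.sac import SAC

env = gym.make("Pendulum-v0")  # placeholder continuous-control env
dim_state = env.observation_space.shape[0]
dim_action = env.action_space.shape[0]
max_action = env.action_space.high

with tf.device("/cpu:0"):  # the original pins to "/gpu:0"
    policy = SAC(state_shape=(dim_state,), action_dim=dim_action,
                 max_action=max_action[0])
    saved_policy = SAC(state_shape=(dim_state,), action_dim=dim_action,
                       max_action=max_action[0])
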
github keiohta/tf2rl: tf2rl/algos/sac_discrete.py (view on GitHub)
    @staticmethod
    def get_argument(parser=None):
        # Reuse SAC's argparse options, then add the discrete variant's flag
        # that enables periodic hard target-network updates.
        parser = SAC.get_argument(parser)
        parser.add_argument('--target-update-interval', type=int, default=None)
        return parser
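
A brief usage sketch of the helper above: get_argument attaches the algorithm's command-line flags to an argparse parser and, in tf2rl's usual pattern, creates one when None is passed. The flag value here is only an example.

from tf2rl.algos.sac_discrete import SACDiscrete

parser = SACDiscrete.get_argument()
args = parser.parse_args(["--target-update-interval", "1000"])
print(args.target_update_interval)  # 1000 -> periodic hard target updates
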
github keiohta/tf2rl: examples/run_sac.py (view on GitHub)
import roboschool
import gym

from tf2rl.algos.sac import SAC
from tf2rl.experiments.trainer import Trainer


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = SAC.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolAnt-v1")
    parser.set_defaults(batch_size=100)
    parser.set_defaults(n_warmup=10000)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = SAC(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        max_action=env.action_space.high[0],
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
        auto_alpha=args.auto_alpha)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
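
Once Trainer() has finished, the same policy object can be reused for greedy evaluation. The loop below is a sketch that builds on the script above; it assumes the classic 4-tuple gym step API used in these examples and that test=True makes tf2rl's SAC act on the mean of its Gaussian policy rather than sampling.

obs = test_env.reset()
done = False
episode_return = 0.
while not done:
    action = policy.get_action(obs, test=True)  # deterministic evaluation action
    obs, reward, done, _ = test_env.step(action)
    episode_return += reward
print("evaluation return:", episode_return)
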
github keiohta/tf2rl: tf2rl/algos/sac_discrete.py (view on GitHub)
        # Tail of CriticQ.__init__: two hidden layers plus a linear output head
        # that emits one Q-value per discrete action.
        self.l2 = Dense(critic_units[1], name="L2", activation='relu')
        self.l3 = Dense(action_dim, name="L3", activation='linear')

        dummy_state = tf.constant(
            np.zeros(shape=(1,) + state_shape, dtype=np.float32))
        self(dummy_state)  # build the network once with a dummy forward pass

    def call(self, states):
        features = self.l1(states)
        features = self.l2(features)
        values = self.l3(features)

        return values


class SACDiscrete(SAC):
    def __init__(
            self,
            state_shape,
            action_dim,
            *args,
            actor_fn=None,
            critic_fn=None,
            target_update_interval=None,
            **kwargs):
        kwargs["name"] = "SAC_discrete"
        self.actor_fn = actor_fn if actor_fn is not None else CategoricalActor
        self.critic_fn = critic_fn if critic_fn is not None else CriticQ
        self.target_hard_update = target_update_interval is not None
        self.target_update_interval = target_update_interval
        self.n_training = tf.Variable(0, dtype=tf.int32)
        super().__init__(state_shape, action_dim, *args, **kwargs)
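
SACDiscrete keeps the parent SAC interface but expects the number of discrete actions as action_dim. The instantiation below is a sketch on a toy discrete environment; the env name and interval value are illustrative.

import gym

from tf2rl.algos.sac_discrete import SACDiscrete

env = gym.make("CartPole-v0")
agent = SACDiscrete(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.n,
    target_update_interval=1000,  # non-None switches to periodic hard target updates
    gpu=-1)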