def setUpClass(cls):
    cls.batch_size = 32
    cls.memory_capacity = 32
    cls.on_policy_agent = OnPolicyAgent(
        name="OnPolicyAgent",
        batch_size=cls.batch_size)
    cls.off_policy_agent = OffPolicyAgent(
        name="OffPolicyAgent",
        memory_capacity=cls.memory_capacity)
    cls.discrete_env = gym.make("CartPole-v0")
    cls.continuous_env = gym.make("Pendulum-v0")
dummy_state = tf.constant(
    np.zeros(shape=(1,) + state_shape, dtype=np.float32))
dummy_action = tf.constant(
    np.zeros(shape=[1, action_dim], dtype=np.float32))
with tf.device("/cpu:0"):
    # Run one forward pass with zero inputs so the layer variables are built.
    self([dummy_state, dummy_action])
def call(self, inputs):
    states, actions = inputs
    features = tf.concat([states, actions], axis=1)
    features = tf.nn.relu(self.l1(features))
    features = tf.nn.relu(self.l2(features))
    features = self.l3(features)
    return features
class DDPG(OffPolicyAgent):
    def __init__(
            self,
            state_shape,
            action_dim,
            name="DDPG",
            max_action=1.,
            lr_actor=0.001,
            lr_critic=0.001,
            actor_units=[400, 300],
            critic_units=[400, 300],
            sigma=0.1,
            tau=0.005,
            n_warmup=int(1e4),
            memory_capacity=int(1e6),
            **kwargs):
        super().__init__(name=name, memory_capacity=memory_capacity, n_warmup=n_warmup, **kwargs)
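A rough usage sketch for the constructor above; the import path and the Pendulum-v0 environment are assumptions for illustration, mirroring the test fixture earlier in this section:

import gym
from tf2rl.algos.ddpg import DDPG  # assumed import path

env = gym.make("Pendulum-v0")
agent = DDPG(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.high.size,
    max_action=env.action_space.high[0])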
def get_argument(parser=None):
    parser = OffPolicyAgent.get_argument(parser)
    parser.add_argument('--enable-double-dqn', action='store_true')
    parser.add_argument('--enable-dueling-dqn', action='store_true')
    parser.add_argument('--enable-categorical-dqn', action='store_true')
    parser.add_argument('--enable-noisy-dqn', action='store_true')
    return parser
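A hedged sketch of how these flags might be consumed, assuming the method above is exposed as a staticmethod on the DQN class and that the base OffPolicyAgent.get_argument creates a fresh argparse parser when none is passed:

from tf2rl.algos.dqn import DQN  # assumed import path

parser = DQN.get_argument()
args = parser.parse_args(["--enable-double-dqn", "--enable-dueling-dqn"])
# argparse converts the dashes in flag names to underscores
assert args.enable_double_dqn and args.enable_dueling_dqn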
def get_replay_buffer(policy, env, use_prioritized_rb=False,
                      use_nstep_rb=False, n_step=1, size=None):
    if policy is None or env is None:
        return None

    obs_shape = get_space_size(env.observation_space)
    kwargs = get_default_rb_dict(policy.memory_capacity, env)

    if size is not None:
        kwargs["size"] = size

    # On-policy agent: store one horizon of transitions together with
    # log-probabilities, returns and advantages instead of a large replay memory.
    if not issubclass(type(policy), OffPolicyAgent):
        kwargs["size"] = policy.horizon
        kwargs["env_dict"].pop("next_obs")
        kwargs["env_dict"].pop("rew")
        # TODO: Remove done. Currently cannot remove because of cpprb implementation
        # kwargs["env_dict"].pop("done")
        kwargs["env_dict"]["logp"] = {}
        kwargs["env_dict"]["ret"] = {}
        kwargs["env_dict"]["adv"] = {}
        if is_discrete(env.action_space):
            kwargs["env_dict"]["act"]["dtype"] = np.int32
        return ReplayBuffer(**kwargs)
    # N-step prioritized
    if use_prioritized_rb and use_nstep_rb:
        kwargs["Nstep"] = {"size": n_step,
                           "gamma": policy.discount,
                           "next": ["next_obs"]}
        return PrioritizedReplayBuffer(**kwargs)
dummy_state = tf.constant(
    np.zeros(shape=(1,) + state_shape, dtype=np.float32))
dummy_action = tf.constant(
    np.zeros(shape=[1, action_dim], dtype=np.float32))
self([dummy_state, dummy_action])
def call(self, inputs):
    states, actions = inputs
    features = tf.concat([states, actions], axis=1)
    features = self.l1(features)
    features = self.l2(features)
    values = self.l3(features)
    return tf.squeeze(values, axis=1)
class SAC(OffPolicyAgent):
    def __init__(
            self,
            state_shape,
            action_dim,
            name="SAC",
            max_action=1.,
            lr=3e-4,
            actor_units=[256, 256],
            critic_units=[256, 256],
            tau=0.005,
            alpha=.2,
            auto_alpha=False,
            n_warmup=int(1e4),
            memory_capacity=int(1e6),
            **kwargs):
        super().__init__(
            name=name, memory_capacity=memory_capacity, n_warmup=n_warmup, **kwargs)
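A minimal construction sketch for the class above, assuming an import path of tf2rl.algos.sac.SAC and a continuous-control environment:

import gym
from tf2rl.algos.sac import SAC  # assumed import path

env = gym.make("Pendulum-v0")
agent = SAC(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.high.size,
    auto_alpha=True)  # learn the entropy temperature instead of keeping alpha fixed at 0.2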
if self._enable_categorical_dqn:
    # Categorical DQN: reshape to [batch_size, action_dim, n_atoms] and
    # normalize each action's atom logits into a probability distribution.
    features = tf.reshape(
        features, (-1, self._action_dim, self._n_atoms))
    q_dist = tf.keras.activations.softmax(features, axis=2)
    return tf.clip_by_value(q_dist, 1e-8, 1.0 - 1e-8)
else:
    if self._enable_dueling_dqn:
        # Dueling DQN: combine the state value with mean-centered advantages.
        advantages = self.l3(features)
        v_values = self.l4(features)
        q_values = v_values + \
            (advantages - tf.reduce_mean(advantages, axis=1, keepdims=True))
    else:
        q_values = self.l3(features)
    return q_values
class DQN(OffPolicyAgent):
    def __init__(
            self,
            state_shape,
            action_dim,
            discrete_input,
            q_func=None,
            name="DQN",
            lr=0.001,
            units=[32, 32],
            epsilon=0.1,
            epsilon_min=None,
            epsilon_decay_step=int(1e6),
            n_warmup=int(1e4),
            target_replace_interval=int(5e3),
            memory_capacity=int(1e6),
            optimizer=None,
            **kwargs):
        super().__init__(name=name, memory_capacity=memory_capacity, n_warmup=n_warmup, **kwargs)
def get_argument(parser=None):
    parser = OffPolicyAgent.get_argument(parser)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--auto-alpha', action="store_true")
    return parser