def test_sn_dense(self):
    # Round-trip the spectral-normalized Dense layer through Keras' layer_test helper.
    layer_test(
        SNDense, kwargs={'units': 3}, input_shape=(3, 2),
        custom_objects={'SNDense': SNDense})
def set_weights_fn(policy, weights):
    # Restore actor, critic, and target-critic weights; tau=1. turns the
    # soft update into an exact (hard) copy.
    actor_weights, critic_weights, critic_target_weights = weights
    update_target_variables(
        policy.actor.weights, actor_weights, tau=1.)
    update_target_variables(
        policy.critic.weights, critic_weights, tau=1.)
    update_target_variables(
        policy.critic_target.weights, critic_target_weights, tau=1.)
@classmethod
def setUpClass(cls):
    super().setUpClass()
    cls.agent = DDPG(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        batch_size=cls.batch_size,
        sigma=0.5,  # Use larger exploration noise so the test can detect it more easily
        gpu=-1)
def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    # Build a DDPG agent for the given environment, forwarding the buffer
    # size and GPU selection instead of silently ignoring them.
    return DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        n_warmup=500,
        memory_capacity=memory_capacity,
        gpu=gpu)
def test_huber_loss(self):
    # Quadratic region: errors of [1, 1] -> [0.5, 0.5]
    y_target = np.array([0., 0.])
    y_pred = np.array([1., 1.])
    expected = np.array([0.5, 0.5])
    loss = huber_loss(y_target, y_pred)
    print(loss)
    # self.assertEqual(expected, loss.numpy())

    # Linear region: large errors should grow linearly rather than quadratically
    y_target = np.array([0., 0.])
    y_pred = np.array([10., 10.])
    expected = np.array([10., 10.])
    loss = huber_loss(y_target, y_pred)
    print(loss)
    # self.assertEqual(expected, loss.numpy())
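# For reference, the standard Huber loss with delta=1 is quadratic for small
# errors and linear beyond: 0.5 * err**2 if |err| <= delta, else
# delta * (|err| - 0.5 * delta). Under that formula an error of 1 gives 0.5,
# while an error of 10 gives 9.5 rather than 10, which may explain why the
# assertions above are left commented out. A minimal NumPy sketch under that
# assumption (not necessarily the implementation under test):
import numpy as np

def huber_loss_reference(y_target, y_pred, delta=1.0):
    err = np.abs(y_target - y_pred)
    quadratic = 0.5 * err ** 2            # used where |err| <= delta
    linear = delta * (err - 0.5 * delta)  # used where |err| > delta
    return np.where(err <= delta, quadratic, linear)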
@classmethod
def setUpClass(cls):
    cls.env = gym.make("CartPole-v0")
    policy = DQN(
        state_shape=cls.env.observation_space.shape,
        action_dim=cls.env.action_space.n,
        memory_capacity=2**4)
    cls.replay_buffer = get_replay_buffer(
        policy, cls.env)
    cls.output_dir = os.path.join(
        os.path.dirname(__file__),
        "tests")
    if not os.path.isdir(cls.output_dir):
        os.makedirs(cls.output_dir)
@classmethod
def setUpClass(cls):
    super().setUpClass()
    cls.agent = DQN(
        state_shape=cls.discrete_env.observation_space.shape,
        action_dim=cls.discrete_env.action_space.n,
        batch_size=cls.batch_size,
        enable_categorical_dqn=True,
        epsilon=0.,  # Disable exploration so the greedy action is deterministic
        gpu=-1)
@classmethod
def setUpClass(cls):
    super().setUpClass()
    cls.agent = DQN(
        state_shape=cls.discrete_env.observation_space.shape,
        action_dim=cls.discrete_env.action_space.n,
        batch_size=cls.batch_size,
        enable_double_dqn=True,
        enable_dueling_dqn=True,
        epsilon=0.,  # Disable exploration so the greedy action is deterministic
        gpu=-1)
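# A minimal sketch of how a fixture like the one above is typically exercised
# in a follow-up test method (get_action and its test flag are assumptions
# about the agent's API, not taken from the snippet itself):
def test_get_action(self):
    state = self.discrete_env.reset()
    action = self.agent.get_action(state, test=True)
    self.assertIn(action, range(self.discrete_env.action_space.n))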