def test_get_replay_buffer(self):
    # Replay Buffer
    rb = get_replay_buffer(
        self.on_policy_agent, self.discrete_env)
    self.assertTrue(isinstance(rb, ReplayBuffer))

    rb = get_replay_buffer(
        self.off_policy_agent, self.discrete_env)
    self.assertTrue(isinstance(rb, ReplayBuffer))

    # Prioritized Replay Buffer
    rb = get_replay_buffer(
        self.off_policy_agent, self.discrete_env,
        use_prioritized_rb=True)
    self.assertTrue(isinstance(rb, PrioritizedReplayBuffer))
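The test above only switches between the plain and prioritized buffers for a discrete-action agent. As a rough illustration of the same calling pattern outside a test class, here is a minimal sketch; the tf2rl import paths and the way the helper sizes the buffer are assumptions and may differ in your version of the library.

import gym
from tf2rl.algos.dqn import DQN                              # assumed module path
from tf2rl.misc.get_replay_buffer import get_replay_buffer   # assumed module path

env = gym.make("CartPole-v0")
policy = DQN(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.n,
    memory_capacity=2**4)

# Uniform replay buffer chosen by the helper for an off-policy agent
rb = get_replay_buffer(policy, env)

# Prioritized variant, selected through the same helper
per_rb = get_replay_buffer(policy, env, use_prioritized_rb=True)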
@classmethod
def setUpClass(cls):
    cls.env = gym.make("CartPole-v0")
    policy = DQN(
        state_shape=cls.env.observation_space.shape,
        action_dim=cls.env.action_space.n,
        memory_capacity=2**4)
    cls.replay_buffer = get_replay_buffer(
        policy, cls.env)
    cls.output_dir = os.path.join(
        os.path.dirname(__file__), "tests")
    if not os.path.isdir(cls.output_dir):
        os.makedirs(cls.output_dir)
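Once setUpClass has built the shared buffer, a companion test could fill it and draw a batch. The sketch below is hypothetical: the add(...) field names mirror the calls shown elsewhere on this page, and sample() returning a dict keyed by those field names is an assumption about the underlying cpprb buffer.

def test_add_and_sample(self):
    # Hypothetical test: push a few transitions, then sample a small batch.
    obs = self.env.reset()
    for _ in range(8):
        act = self.env.action_space.sample()
        next_obs, rew, done, _ = self.env.step(act)
        self.replay_buffer.add(obs=obs, act=act, next_obs=next_obs,
                               rew=rew, done=done)
        obs = self.env.reset() if done else next_obs
    batch = self.replay_buffer.sample(4)   # assumed dict of arrays
    self.assertIn("obs", batch)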
def evaluate_policy(self, total_steps):
    avg_test_return = 0.
    if self._save_test_path:
        # Keep a buffer of the evaluation trajectory so it can be saved later
        replay_buffer = get_replay_buffer(
            self._policy, self._test_env, size=self._episode_max_steps)
    for i in range(self._test_episodes):
        episode_return = 0.
        frames = []
        obs = self._test_env.reset()
        done = False
        for _ in range(self._episode_max_steps):
            act, _ = self._policy.get_action(obs, test=True)
            # Clip continuous actions to the action-space bounds
            act = act if not hasattr(self._env.action_space, "high") else \
                np.clip(act, self._env.action_space.low, self._env.action_space.high)
            next_obs, reward, done, _ = self._test_env.step(act)
            if self._save_test_path:
                replay_buffer.add(
                    obs=obs, act=act, next_obs=next_obs,
                    rew=reward, done=done)
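            # Hedged continuation (not part of the original snippet): accumulate
            # the return and close out the episode loop.
            episode_return += reward
            obs = next_obs
            if done:
                break
        avg_test_return += episode_return
    # Average over the evaluation episodes; the exact return value is an
    # assumption, not the library's confirmed behaviour.
    return avg_test_return / self._test_episodes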
def __call__(self):
    total_steps = 0
    n_episode = 0

    # TODO: clean codes
    # Prepare buffer
    self.replay_buffer = get_replay_buffer(
        self._policy, self._env)
    kwargs_local_buf = get_default_rb_dict(
        size=self._episode_max_steps, env=self._env)
    kwargs_local_buf["env_dict"]["logp"] = {}
    kwargs_local_buf["env_dict"]["val"] = {}
    if is_discrete(self._env.action_space):
        kwargs_local_buf["env_dict"]["act"]["dtype"] = np.int32
    self.local_buffer = ReplayBuffer(**kwargs_local_buf)

    tf.summary.experimental.set_step(total_steps)
    while total_steps < self._max_steps:
        # Collect samples
        n_episode, total_rewards = self._collect_sample(n_episode, total_steps)
        total_steps += self._policy.horizon
        tf.summary.experimental.set_step(total_steps)
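The local buffer above is an ordinary ReplayBuffer whose layout starts from get_default_rb_dict and gains extra "logp" and "val" columns for the on-policy update. The standalone sketch below rebuilds the same configuration outside the trainer; the import paths and the CartPole environment are assumptions for illustration only.

import gym
import numpy as np
from cpprb import ReplayBuffer                                   # assumed import
from tf2rl.misc.get_replay_buffer import get_default_rb_dict     # assumed import

env = gym.make("CartPole-v0")
kwargs = get_default_rb_dict(size=128, env=env)
kwargs["env_dict"]["logp"] = {}                # per-step log-probabilities
kwargs["env_dict"]["val"] = {}                 # per-step value estimates
kwargs["env_dict"]["act"]["dtype"] = np.int32  # discrete actions for CartPole
local_buffer = ReplayBuffer(**kwargs)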
def __call__(self):
    total_steps = 0
    tf.summary.experimental.set_step(total_steps)
    episode_steps = 0
    episode_return = 0
    episode_start_time = time.perf_counter()
    n_episode = 0

    replay_buffer = get_replay_buffer(
        policy=self._policy, env=self._env,
        use_prioritized_rb=self._use_prioritized_rb,
        use_nstep_rb=self._use_nstep_rb, n_step=self._n_step)

    obs = self._env.reset()
    while total_steps < self._max_steps:
        if total_steps < self._policy.n_warmup:
            # Uniform random exploration until enough warmup samples are collected
            action = self._env.action_space.sample()
        else:
            action = self._policy.get_action(obs)
        next_obs, reward, done, _ = self._env.step(action)
        if self._show_progress:
            self._env.render()
        episode_steps += 1
def evaluate_policy(self, total_steps):
    avg_test_return = 0.
    if self._save_test_path:
        replay_buffer = get_replay_buffer(
            self._policy, self._test_env, save_logp=True,
            size=self._episode_max_steps)
    for i in range(self._test_episodes):
        episode_return = 0.
        frames = []
        obs = self._test_env.reset()
        for _ in range(self._episode_max_steps):
            action = self._policy.get_action(obs, test=True)
            next_obs, reward, done, _ = self._test_env.step(action)
            if self._save_test_path:
                data = {"obs": obs, "act": action, "next_obs": next_obs,
                        "rew": reward, "done": done}
                # Store the log-probability only when the policy exposes it
                if hasattr(self._policy, "get_logp"):
                    data["logp"] = self._policy.get_logp(obs)
                else:
                    data["logp"] = None
                replay_buffer.add(**data)
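        # Hedged sketch (not part of the original snippet) of dumping the saved
        # evaluation trajectory after the episode loop; get_all_transitions() and
        # clear() are assumed to come from the underlying cpprb buffer, and the
        # joblib dependency, self._output_dir attribute and filename are
        # illustrative assumptions only.
        if self._save_test_path:
            import os
            import joblib
            path = os.path.join(self._output_dir,
                                "evaluation_episode_{}.pkl".format(i))
            joblib.dump(replay_buffer.get_all_transitions(), path)
            replay_buffer.clear()  # start the next episode with an empty buffer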
def __call__(self):
    total_steps = 0
    tf.summary.experimental.set_step(total_steps)
    episode_steps = 0
    episode_return = 0
    episode_start_time = time.perf_counter()
    n_episode = 0

    replay_buffer = get_replay_buffer(
        self._policy, self._env, self._use_prioritized_rb,
        self._use_nstep_rb, self._n_step)

    obs = self._env.reset()
    while total_steps < self._max_steps:
        if total_steps < self._policy.n_warmup:
            action = self._env.action_space.sample()
        else:
            action = self._policy.get_action(obs)
        next_obs, reward, done, _ = self._env.step(action)
        if self._show_progress:
            self._env.render()
        episode_steps += 1
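        # Hedged continuation (not part of the original snippet): store the
        # transition, reset bookkeeping at episode boundaries, and train once
        # warmup is over. The policy.train(...) signature and batch_size
        # attribute are assumptions, not the library's confirmed API.
        episode_return += reward
        total_steps += 1
        replay_buffer.add(obs=obs, act=action, next_obs=next_obs,
                          rew=reward, done=done)

        if done:
            n_episode += 1
            episode_steps = 0
            episode_return = 0
            episode_start_time = time.perf_counter()
            obs = self._env.reset()
        else:
            obs = next_obs

        if total_steps >= self._policy.n_warmup:
            samples = replay_buffer.sample(self._policy.batch_size)
            self._policy.train(
                samples["obs"], samples["act"], samples["next_obs"],
                samples["rew"], samples["done"])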