How to use the tf2rl.misc.get_replay_buffer.get_replay_buffer function in tf2rl

To help you get started, we’ve selected a few tf2rl examples based on how the function is commonly used in public projects.

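As the examples below show, get_replay_buffer inspects a policy/environment pair and returns either a plain ReplayBuffer or, when use_prioritized_rb=True is passed for an off-policy agent, a PrioritizedReplayBuffer. Here is a minimal usage sketch; the DQN hyperparameters are placeholders rather than values taken from the examples.

import gym
from tf2rl.algos.dqn import DQN
from tf2rl.misc.get_replay_buffer import get_replay_buffer

env = gym.make("CartPole-v0")
policy = DQN(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.n,
    memory_capacity=2**10)  # placeholder capacity

# Default: a plain FIFO ReplayBuffer sized from the policy and environment.
replay_buffer = get_replay_buffer(policy, env)

# Prioritized experience replay instead of uniform sampling.
per_buffer = get_replay_buffer(policy, env, use_prioritized_rb=True)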

From keiohta/tf2rl: tests/misc/test_get_replay_buffer.py
def test_get_replay_buffer(self):
        # Replay Buffer
        rb = get_replay_buffer(
            self.on_policy_agent, self.discrete_env)
        self.assertTrue(isinstance(rb, ReplayBuffer))

        rb = get_replay_buffer(
            self.off_policy_agent, self.discrete_env)
        self.assertTrue(isinstance(rb, ReplayBuffer))

        # Prioritized Replay Buffer
        rb = get_replay_buffer(
            self.off_policy_agent, self.discrete_env,
            use_prioritized_rb=True)
        self.assertTrue(isinstance(rb, PrioritizedReplayBuffer))
From keiohta/tf2rl: tests/experiments/test_utils.py
@classmethod
def setUpClass(cls):
        cls.env = gym.make("CartPole-v0")
        policy = DQN(
            state_shape=cls.env.observation_space.shape,
            action_dim=cls.env.action_space.n,
            memory_capacity=2**4)
        cls.replay_buffer = get_replay_buffer(
            policy, cls.env)
        cls.output_dir = os.path.join(
            os.path.dirname(__file__),
            "tests")
        if not os.path.isdir(cls.output_dir):
            os.makedirs(cls.output_dir)
From keiohta/tf2rl: tf2rl/experiments/on_policy_trainer.py
def evaluate_policy(self, total_steps):
        avg_test_return = 0.
        if self._save_test_path:
            replay_buffer = get_replay_buffer(
                self._policy, self._test_env, size=self._episode_max_steps)
        for i in range(self._test_episodes):
            episode_return = 0.
            frames = []
            obs = self._test_env.reset()
            done = False
            for _ in range(self._episode_max_steps):
                act, _ = self._policy.get_action(obs, test=True)
                act = act if not hasattr(self._env.action_space, "high") else \
                    np.clip(act, self._env.action_space.low, self._env.action_space.high)
                next_obs, reward, done, _ = self._test_env.step(act)
                if self._save_test_path:
                    replay_buffer.add(
                        obs=obs, act=act, next_obs=next_obs,
                        rew=reward, done=done)
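The evaluation snippet above is cut off after the call to replay_buffer.add. Once an episode's transitions have been stored, they are typically read back out and written to disk as the saved test path. The sketch below is a hedged continuation: get_all_transitions() and clear() come from the cpprb buffers that tf2rl builds on, and the output filename is a placeholder.

# Hedged continuation (not from the source file): after the episode loop,
# pull the stored transitions back out and persist them. The output path
# and filename are placeholders.
if self._save_test_path:
    trajectory = replay_buffer.get_all_transitions()  # dict of numpy arrays
    np.savez(os.path.join(self._output_dir, "test_path_{}.npz".format(i)),
             **trajectory)
    replay_buffer.clear()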
From keiohta/tf2rl: tf2rl/experiments/on_policy_trainer.py
def __call__(self):
        total_steps = 0
        n_episode = 0

        # TODO: clean codes
        # Prepare buffer
        self.replay_buffer = get_replay_buffer(
            self._policy, self._env)
        kwargs_local_buf = get_default_rb_dict(
            size=self._episode_max_steps, env=self._env)
        kwargs_local_buf["env_dict"]["logp"] = {}
        kwargs_local_buf["env_dict"]["val"] = {}
        if is_discrete(self._env.action_space):
            kwargs_local_buf["env_dict"]["act"]["dtype"] = np.int32
        self.local_buffer = ReplayBuffer(**kwargs_local_buf)

        tf.summary.experimental.set_step(total_steps)
        while total_steps < self._max_steps:
            # Collect samples
            n_episode, total_rewards = self._collect_sample(n_episode, total_steps)
            total_steps += self._policy.horizon
            tf.summary.experimental.set_step(total_steps)
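The logp and val entries added to kwargs_local_buf above are extra fields that the on-policy sample-collection step records next to each transition. A minimal sketch of such an add call follows, using the keyword-argument interface seen throughout these examples; get_action_and_val is a placeholder name for however the policy returns its log-probability and value estimate, not a confirmed tf2rl method.

# Hypothetical single collection step illustrating the extra logp/val fields
# declared in kwargs_local_buf above. get_action_and_val is a placeholder
# name, not a confirmed tf2rl API.
obs = self._env.reset()
act, logp, val = self._policy.get_action_and_val(obs)
next_obs, reward, done, _ = self._env.step(act)
self.local_buffer.add(obs=obs, act=act, next_obs=next_obs,
                      rew=reward, done=done, logp=logp, val=val)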
From keiohta/tf2rl: tf2rl/experiments/trainer.py
def __call__(self):
        total_steps = 0
        tf.summary.experimental.set_step(total_steps)
        episode_steps = 0
        episode_return = 0
        episode_start_time = time.perf_counter()
        n_episode = 0

        replay_buffer = get_replay_buffer(
            policy=self._policy, env=self._env,
            use_prioritized_rb=self._use_prioritized_rb,
            use_nstep_rb=self._use_nstep_rb, n_step=self._n_step)

        obs = self._env.reset()

        while total_steps < self._max_steps:
            if total_steps < self._policy.n_warmup:
                action = self._env.action_space.sample()
            else:
                action = self._policy.get_action(obs)

            next_obs, reward, done, _ = self._env.step(action)
            if self._show_progress:
                self._env.render()
            episode_steps += 1
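The loop above is truncated before the transition is stored and the policy is updated. A hedged sketch of those two steps follows: add() and sample() follow the cpprb interface that tf2rl's buffers wrap, while the exact signature of the policy's train call is an assumption rather than a quote from the file.

# Hedged continuation (not from the source file): store the transition, then
# train on a sampled mini-batch once the warm-up phase is over. The
# policy.train signature shown here is an assumption.
replay_buffer.add(obs=obs, act=action,
                  next_obs=next_obs, rew=reward, done=done)
obs = next_obs

if total_steps >= self._policy.n_warmup:
    samples = replay_buffer.sample(self._policy.batch_size)
    self._policy.train(
        samples["obs"], samples["act"], samples["next_obs"],
        samples["rew"], samples["done"])
total_steps += 1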
From keiohta/tf2rl: tf2rl/experiments/trainer.py
def evaluate_policy(self, total_steps):
        avg_test_return = 0.
        if self._save_test_path:
            replay_buffer = get_replay_buffer(
                self._policy, self._test_env, save_logp=True, size=self._episode_max_steps)
        for i in range(self._test_episodes):
            episode_return = 0.
            frames = []
            obs = self._test_env.reset()
            for _ in range(self._episode_max_steps):
                action = self._policy.get_action(obs, test=True)
                next_obs, reward, done, _ = self._test_env.step(action)
                if self._save_test_path:
                    data = {"obs": obs, "act": action, "next_obs": next_obs,
                              "rew": reward, "done": done}
                    if hasattr(self._policy, "get_logp"):
                        data["logp"] = self._policy.get_logp(obs)
                    else:
                        data["logp"] = None
                    replay_buffer.add(**data)
From keiohta/tf2rl: tf2rl/experiments/irl_trainer.py
def __call__(self):
        total_steps = 0
        tf.summary.experimental.set_step(total_steps)
        episode_steps = 0
        episode_return = 0
        episode_start_time = time.perf_counter()
        n_episode = 0

        replay_buffer = get_replay_buffer(
            self._policy, self._env, self._use_prioritized_rb,
            self._use_nstep_rb, self._n_step)

        obs = self._env.reset()

        while total_steps < self._max_steps:
            while total_steps < self._max_steps:
                if total_steps < self._policy.n_warmup:
                    action = self._env.action_space.sample()
                else:
                    action = self._policy.get_action(obs)

                next_obs, reward, done, _ = self._env.step(action)
                if self._show_progress:
                    self._env.render()
                episode_steps += 1