Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, env, model, n_steps):
    """
    Set up a rollout runner that collects experience from ``env`` for ``model``.

    :param env: (Gym environment) The environment to learn from; its
        ``num_envs``, ``action_space`` and ``observation_space`` are read here
    :param model: (Model) The model to learn
    :param n_steps: (int) The number of steps to run for each environment
    """
    super(_Runner, self).__init__(env=env, model=model, n_steps=n_steps)
    self.env = env
    self.model = model

    num_envs = env.num_envs
    self.n_env = num_envs

    # Width of the action dimension: the number of choices for a Discrete
    # space, otherwise the trailing dimension of the action-space shape.
    action_space = env.action_space
    self.n_act = action_space.n if isinstance(action_space, Discrete) else action_space.shape[-1]

    # Total transitions gathered per rollout across all environments.
    self.n_batch = num_envs * n_steps

    observation_shape = env.observation_space.shape
    # NOTE(review): observation spaces of rank 0/1 fall through without
    # setting raw_pixels / batch_ob_shape / obs_dtype in this code — confirm
    # whether a non-pixel branch is expected here.
    if len(observation_shape) > 1:
        # Multi-dimensional observations are treated as raw images laid out
        # as (height, width, channels); the unpacking below raises ValueError
        # for any rank other than 3.
        self.raw_pixels = True
        height, width, channels = observation_shape
        frame_shape = (height, width, channels)
        # n_steps + 1 observations are reserved per environment (presumably
        # the extra one is the observation after the final step — TODO confirm).
        self.batch_ob_shape = (num_envs * (n_steps + 1),) + frame_shape
        self.obs_dtype = np.uint8
        # NOTE(review): replaces any ``self.obs`` the base-class __init__ may
        # have set with an all-zero buffer — confirm this is intended.
        self.obs = np.zeros((num_envs,) + frame_shape, dtype=self.obs_dtype)
        self.num_channels = channels
def _make_runner(self) -> AbstractEnvRunner:
    """
    Create the rollout runner for this model.

    :return: (AbstractEnvRunner) a ``_Runner`` built from ``self.env`` and
        ``self.n_steps``, with this model passed as ``model``
    """
    runner = _Runner(env=self.env, model=self, n_steps=self.n_steps)
    return runner