Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@DeveloperAPI
def get_initial_state(self):
    """Return the initial RNN state for the current policy.

    Returns:
        list: A new, empty list on every call.
    """
    initial_state = list()
    return initial_state
@DeveloperAPI
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy
evaluation. Each fragment is guaranteed to be only from one episode.
Arguments:
sample_batch (SampleBatch): batch of experiences for the policy,
which will contain at most one episode trajectory.
other_agent_batches (dict): In a multi-agent env, this contains a
mapping of agent ids to (policy, agent_batch) tuples
containing the policy and experiences of the other agents.
episode (MultiAgentEpisode): this provides access to all of the
@DeveloperAPI
def foreach_policy(self, func):
    """Call ``func(policy, policy_id)`` for every entry of ``self.policy_map``.

    Arguments:
        func (callable): Invoked as ``func(policy, policy_id)``.

    Returns:
        list: The values returned by ``func``, in the iteration order of
            ``self.policy_map.items()``.
    """
    results = []
    for policy_id, policy in self.policy_map.items():
        # The map is keyed by policy id, but ``func`` takes the policy first.
        results.append(func(policy, policy_id))
    return results
@DeveloperAPI
def __init__(self,
env_creator,
policy,
policy_mapping_fn=None,
policies_to_train=None,
tf_session_creator=None,
batch_steps=100,
batch_mode="truncate_episodes",
episode_horizon=None,
preprocessor_pref="deepmind",
sample_async=False,
compress_observations=False,
num_envs=1,
observation_filter="NoFilter",
clip_rewards=None,
clip_actions=True,
@DeveloperAPI
def get_action_placeholder(action_space):
    """Build a placeholder tensor matching the given action space.

    Args:
        action_space (Space): Action space of the target gym env.

    Returns:
        action_placeholder (Tensor): A placeholder named "action" whose dtype
            and shape come from ``ModelCatalog.get_action_shape``.
    """
    # The catalog yields a (dtype, shape) pair for the action space.
    action_dtype, action_shape = ModelCatalog.get_action_shape(action_space)
    action_placeholder = tf.placeholder(
        action_dtype, shape=action_shape, name="action")
    return action_placeholder
@DeveloperAPI
def kl(self, other):
    """Return the KL-divergence between this action distribution and ``other``.

    This implementation provides no computation and always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@DeveloperAPI
def get_state(self):
    """Return all local state.

    The state is exactly what ``self.get_weights()`` returns.

    Returns:
        state (obj): Serialized local state.
    """
    state = self.get_weights()
    return state
@DeveloperAPI
def foreach_worker_with_index(self, func):
    """Run ``func`` on every worker instance, together with its index.

    The call is forwarded unchanged to
    ``self.workers.foreach_worker_with_index``; the index is handed to
    ``func`` as its second argument.

    Returns:
        Whatever the underlying worker collection returns.
    """
    worker_set = self.workers
    return worker_set.foreach_worker_with_index(func)
@DeveloperAPI
def sample_with_idxes(self, idxes, beta):
    """Return the encoded samples at ``idxes`` plus their normalized weights.

    For each index the weight is ``(p * N) ** (-beta)``, where ``p`` is
    ``self._it_sum[idx] / self._it_sum.sum()`` and ``N`` is
    ``len(self._storage)``. Each weight is divided by the same expression
    evaluated with ``self._it_min.min()`` in place of ``self._it_sum[idx]``.

    Side effect: ``self._num_sampled`` is increased by ``len(idxes)``.

    Arguments:
        idxes (list): Indices of the stored items to return.
        beta (float): Exponent applied to the weights; must be > 0.

    Returns:
        tuple: The fields of ``self._encode_sample(idxes)`` followed by the
            weights (a numpy array) and ``idxes`` itself.
    """
    assert beta > 0
    self._num_sampled += len(idxes)

    # These values do not change inside the loop; compute them once instead
    # of once per index. The operand order matches the per-index expression
    # so the floating-point results are unchanged.
    min_priority = self._it_min.min()
    total_priority = self._it_sum.sum()
    storage_size = len(self._storage)
    max_weight = (min_priority / total_priority * storage_size)**(-beta)

    weights = np.array([
        (self._it_sum[idx] / total_priority * storage_size)**(-beta) /
        max_weight for idx in idxes
    ])
    encoded_sample = self._encode_sample(idxes)
    return tuple(list(encoded_sample) + [weights, idxes])
@DeveloperAPI
def logp(self, x):
    """Return the log-likelihood of ``x`` under the action distribution.

    This implementation provides no computation and always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()