def _postprocess_if_needed(self, batch):
if not self.ioctx.config.get("postprocess_inputs"):
return batch
if isinstance(batch, SampleBatch):
out = []
for sub_batch in batch.split_by_episode():
out.append(self.ioctx.worker.policy_map[DEFAULT_POLICY_ID]
.postprocess_trajectory(sub_batch))
return SampleBatch.concat_samples(out)
else:
# TODO(ekl) this is trickier since the alignments between agent
# trajectories in the episode are not available any more.
raise NotImplementedError(
"Postprocessing of multi-agent data not implemented yet.")
if self.workers.remote_workers():
samples = collect_samples(
self.workers.remote_workers(), self.sample_batch_size,
self.num_envs_per_worker, self.train_batch_size)
if samples.count > self.train_batch_size * 2:
logger.info(
"Collected more training samples than expected "
"(actual={}, train_batch_size={}). ".format(
samples.count, self.train_batch_size) +
"This may be because you have many workers or "
"long episodes in 'complete_episodes' batch mode.")
else:
samples = []
while sum(s.count for s in samples) < self.train_batch_size:
samples.append(self.workers.local_worker().sample())
samples = SampleBatch.concat_samples(samples)
# Handle everything as if multiagent
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({
DEFAULT_POLICY_ID: samples
}, samples.count)
for policy_id, policy in self.policies.items():
if policy_id not in samples.policy_batches:
continue
batch = samples.policy_batches[policy_id]
for field in self.standardize_fields:
value = batch[field]
standardized = (value - value.mean()) / max(1e-4, value.std())
batch[field] = standardized
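
# The optimizer code above keeps sampling until at least `train_batch_size`
# timesteps are collected, then standardizes selected fields (typically the
# advantages) to zero mean and unit variance. A small self-contained sketch
# of that standardization step, assuming a plain dict batch of NumPy arrays:
import numpy as np

def standardize_fields(batch, fields):
    """Return a copy of `batch` with each named field standardized."""
    out = dict(batch)
    for field in fields:
        value = batch[field]
        # The 1e-4 floor mirrors the snippet above and avoids division by
        # (near-)zero when a field is constant.
        out[field] = (value - value.mean()) / max(1e-4, value.std())
    return out

batch = {"advantages": np.array([1.0, 3.0, 5.0]),
         "rewards": np.array([0.0, 1.0, 2.0])}
batch = standardize_fields(batch, ["advantages"])
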
def _initialize_loss(self):
def fake_array(tensor):
shape = tensor.shape.as_list()
shape = [s if s is not None else 1 for s in shape]
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
dummy_batch = {
SampleBatch.CUR_OBS: fake_array(self._obs_input),
SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool_),
SampleBatch.ACTIONS: fake_array(
ModelCatalog.get_action_placeholder(self.action_space)),
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
}
if self._obs_include_prev_action_reward:
dummy_batch.update({
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
})
state_init = self.get_initial_state()
state_batches = []
for i, h in enumerate(state_init):
dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
state_batches.append(np.expand_dims(h, 0))
if state_init:
dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
for k, v in self.extra_compute_action_fetches().items():
dummy_batch[k] = fake_array(v)
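
# _initialize_loss above builds a "dummy batch" of zero-valued arrays whose
# shapes match the policy's input placeholders, so the loss graph can be
# traced before any real data exists. A standalone sketch of the same idea,
# working from plain (shape, dtype) specs instead of TF placeholders; the
# helper names and the string column keys are assumptions for the example.
import numpy as np

def fake_array(shape, dtype=np.float32):
    # Replace unknown (None) dimensions, e.g. the batch axis, with 1.
    shape = [s if s is not None else 1 for s in shape]
    return np.zeros(shape, dtype=dtype)

def build_dummy_batch(obs_shape, action_shape, num_state_vars=0):
    batch = {
        "obs": fake_array((None,) + obs_shape),
        "new_obs": fake_array((None,) + obs_shape),
        "actions": fake_array((None,) + action_shape),
        "rewards": np.array([0.0], dtype=np.float32),
        "dones": np.array([False]),
    }
    # RNN policies additionally carry recurrent state and sequence lengths.
    for i in range(num_state_vars):
        batch["state_in_{}".format(i)] = fake_array((None, 1))
        batch["state_out_{}".format(i)] = fake_array((None, 1))
    if num_state_vars:
        batch["seq_lens"] = np.array([1], dtype=np.int32)
    return batch

dummy = build_dummy_batch(obs_shape=(4,), action_shape=(2,), num_state_vars=1)
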
def _from_json(batch):
if isinstance(batch, bytes): # smart_open S3 doesn't respect "r"
batch = batch.decode("utf-8")
data = json.loads(batch)
if "type" in data:
data_type = data.pop("type")
else:
raise ValueError("JSON record missing 'type' field")
if data_type == "SampleBatch":
for k, v in data.items():
data[k] = unpack_if_needed(v)
return SampleBatch(data)
elif data_type == "MultiAgentBatch":
policy_batches = {}
for policy_id, policy_batch in data["policy_batches"].items():
inner = {}
for k, v in policy_batch.items():
inner[k] = unpack_if_needed(v)
policy_batches[policy_id] = SampleBatch(inner)
return MultiAgentBatch(policy_batches, data["count"])
else:
        raise ValueError(
            "Unknown 'type' field: {} (must be one of "
            "['SampleBatch', 'MultiAgentBatch'])".format(data_type))
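
# _from_json above decodes one JSON record per call, dispatching on a "type"
# field to rebuild either a single-agent or a multi-agent batch. A round-trip
# sketch with plain dicts and uncompressed columns (the real reader also
# unpacks compressed columns via unpack_if_needed); the two helper names here
# are made up for the example.
import json
import numpy as np

def to_json_record(batch):
    record = {k: np.asarray(v).tolist() for k, v in batch.items()}
    record["type"] = "SampleBatch"
    return json.dumps(record)

def from_json_record(line):
    if isinstance(line, bytes):  # tolerate byte streams, as above
        line = line.decode("utf-8")
    data = json.loads(line)
    data_type = data.pop("type", None)
    if data_type != "SampleBatch":
        raise ValueError("Unsupported record type: {}".format(data_type))
    return {k: np.asarray(v) for k, v in data.items()}

line = to_json_record({"obs": [[0.1, 0.2]], "actions": [1], "rewards": [0.5]})
restored = from_json_record(line)
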
"Not implemented for train_batch_mode=truncate_episodes"
assert other_agent_batches is not None
[(_, opponent_batch)] = list(other_agent_batches.values())
# also record the opponent obs and actions in the trajectory
sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]
# overwrite default VF prediction with the central VF
sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
sample_batch[OPPONENT_ACTION])
else:
# policy hasn't initialized yet, use zeros
sample_batch[OPPONENT_OBS] = np.zeros_like(
sample_batch[SampleBatch.CUR_OBS])
sample_batch[OPPONENT_ACTION] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS])
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
train_batch = compute_advantages(
sample_batch,
0.0,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
return train_batch
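
# The postprocessor above copies the opponent's observations and actions into
# the agent's own trajectory so a central value function can condition on
# them, falling back to zeros before the loss is initialized. A toy sketch of
# that data plumbing with plain dict batches; `add_opponent_view` and
# `central_vf` are stand-ins invented for the example, not RLlib's
# compute_central_vf.
import numpy as np

OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"

def add_opponent_view(my_batch, other_agent_batches, central_vf=None):
    """Attach opponent obs/actions and, optionally, central VF predictions."""
    if other_agent_batches:
        # `other_agent_batches` maps agent id -> (policy, batch); assumes
        # exactly one opponent, as in the two-agent snippet above.
        [(_, opponent_batch)] = list(other_agent_batches.values())
        my_batch[OPPONENT_OBS] = opponent_batch["obs"]
        my_batch[OPPONENT_ACTION] = opponent_batch["actions"]
        if central_vf is not None:
            my_batch["vf_preds"] = central_vf(
                my_batch["obs"], my_batch[OPPONENT_OBS],
                my_batch[OPPONENT_ACTION])
    else:
        # Policy/loss not initialized yet: use zeros of matching shape.
        my_batch[OPPONENT_OBS] = np.zeros_like(my_batch["obs"])
        my_batch[OPPONENT_ACTION] = np.zeros_like(my_batch["actions"])
        my_batch["vf_preds"] = np.zeros(len(my_batch["actions"]), np.float32)
    return my_batch

me = {"obs": np.zeros((2, 4)), "actions": np.array([0, 1])}
them = {"other": (None, {"obs": np.ones((2, 4)), "actions": np.array([1, 0])})}
me = add_opponent_view(me, them)
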
def add_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"])
def actor_critic_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
policy.loss = A3CLoss(action_dist, train_batch[SampleBatch.ACTIONS],
train_batch[Postprocessing.ADVANTAGES],
train_batch[Postprocessing.VALUE_TARGETS],
model.value_function(),
policy.config["vf_loss_coeff"],
policy.config["entropy_coeff"])
moa_loss = setup_moa_loss(logits, model, policy, train_batch)
policy.loss.total_loss += moa_loss.total_loss
# store this for future statistics
policy.moa_loss = moa_loss.total_loss
return policy.loss.total_loss
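
# actor_critic_loss above combines the standard A3C terms (policy gradient,
# value regression, entropy bonus) with an auxiliary model-of-other-agents
# (MOA) loss. A NumPy sketch of just the A3C part for a discrete policy;
# `a3c_loss` is an example name, and `log_probs` are log pi(a_t|s_t) of the
# actions actually taken.
import numpy as np

def a3c_loss(log_probs, advantages, value_targets, vf_preds,
             entropy, vf_loss_coeff=0.5, entropy_coeff=0.01):
    pi_loss = -np.sum(log_probs * advantages)                # policy gradient
    vf_loss = 0.5 * np.sum((vf_preds - value_targets) ** 2)  # value regression
    return pi_loss + vf_loss_coeff * vf_loss - entropy_coeff * entropy

loss = a3c_loss(
    log_probs=np.log(np.array([0.6, 0.3])),
    advantages=np.array([1.2, -0.4]),
    value_targets=np.array([1.0, 0.5]),
    vf_preds=np.array([0.8, 0.6]),
    entropy=0.9)
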
"is_training": policy._get_is_training_placeholder(),
}, [], None)
target_model_out_tp1, _ = policy.target_model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
# TODO(hartikainen): figure actions and log pis
policy_t, log_pis_t = model.get_policy_output(model_out_t)
policy_tp1, log_pis_tp1 = model.get_policy_output(model_out_tp1)
log_alpha = model.log_alpha
alpha = model.alpha
# q network evaluation
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
if policy.config["twin_q"]:
twin_q_t = model.get_twin_q_values(model_out_t,
train_batch[SampleBatch.ACTIONS])
# Q-values for current policy (no noise) in given current state
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
if policy.config["twin_q"]:
        twin_q_t_det_policy = model.get_twin_q_values(
            model_out_t, policy_t)
q_t_det_policy = tf.reduce_min(
(q_t_det_policy, twin_q_t_det_policy), axis=0)
# target q network evaluation
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
if policy.config["twin_q"]:
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1)
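
# The SAC fragment above evaluates both Q-heads and takes an elementwise
# minimum ("clipped double-Q") to curb overestimation, for the current
# policy's actions and for the target network at the next state. A NumPy
# sketch of the resulting soft Q-learning targets; `sac_q_targets` is an
# illustrative helper, not part of RLlib.
import numpy as np

def sac_q_targets(rewards, dones, q1_tp1, q2_tp1, log_pis_tp1,
                  gamma=0.99, alpha=0.2):
    """Targets: r + gamma * (min(Q1', Q2') - alpha * log pi')."""
    q_tp1 = np.minimum(q1_tp1, q2_tp1) - alpha * log_pis_tp1
    q_tp1 = (1.0 - dones.astype(np.float32)) * q_tp1  # no bootstrap at terminals
    return rewards + gamma * q_tp1

targets = sac_q_targets(
    rewards=np.array([1.0, 0.0]),
    dones=np.array([False, True]),
    q1_tp1=np.array([2.0, 1.5]),
    q2_tp1=np.array([1.8, 1.7]),
    log_pis_tp1=np.array([-0.7, -0.9]))
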
if get_batch_divisibility_req:
dummy_batch = {
k: tile_to(v, get_batch_divisibility_req(self))
for k, v in dummy_batch.items()
}
        # Execute a forward pass to initialize self.action_dist etc. and to
        # obtain the extra action fetches.
_, _, fetches = self.compute_actions(
dummy_batch[SampleBatch.CUR_OBS], state_batches,
dummy_batch.get(SampleBatch.PREV_ACTIONS),
dummy_batch.get(SampleBatch.PREV_REWARDS))
dummy_batch.update(fetches)
postprocessed_batch = self.postprocess_trajectory(
SampleBatch(dummy_batch))
# model forward pass for the loss (needed after postprocess to
# overwrite any tensor state from that call)
self.model.from_batch(dummy_batch)
postprocessed_batch = {
k: tf.convert_to_tensor(v)
for k, v in postprocessed_batch.items()
}
loss_fn(self, self.model, self.dist_class, postprocessed_batch)
if stats_fn:
stats_fn(self, postprocessed_batch)
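
# The snippet above tiles the dummy batch so its row count satisfies the
# policy's batch-divisibility requirement before the loss is traced. A NumPy
# sketch of that padding step; `tile_to` and `pad_batch_to_divisibility` here
# are stand-ins written for the example, not the RLlib helpers.
import numpy as np

def tile_to(array, n):
    """Repeat `array` along axis 0 until it has exactly `n` rows."""
    reps = -(-n // len(array))  # ceil division
    return np.tile(array, (reps,) + (1,) * (array.ndim - 1))[:n]

def pad_batch_to_divisibility(batch, divisor):
    count = len(next(iter(batch.values())))
    target = -(-count // divisor) * divisor
    return {k: tile_to(v, target) for k, v in batch.items()}

batch = {"obs": np.zeros((3, 4)), "rewards": np.zeros(3)}
padded = pad_batch_to_divisibility(batch, divisor=4)  # rows: 3 -> 4
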