Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
dtype=util.tf_dtype('int'),
trainable=False,
initial_value=0
)
self.graph.add_to_collection(name='global-timestep', value=self.global_timestep)
self.graph.add_to_collection(name=tf.GraphKeys.GLOBAL_STEP, value=self.global_timestep)
else:
assert len(collection) == 1
self.global_timestep = collection[0]
# Global episode: reuse the variable already registered in the graph's
# 'global-episode' collection, or create and register a fresh one when the
# collection is empty.
collection = self.graph.get_collection(name='global-episode')
if collection:
    # A populated collection must contain exactly one entry.
    assert len(collection) == 1
    self.global_episode = collection[0]
else:
    episode_variable = tf.Variable(
        initial_value=0,
        trainable=False,
        dtype=util.tf_dtype('int'),
        name='global-episode'
    )
    self.global_episode = episode_variable
    self.graph.add_to_collection(name='global-episode', value=episode_variable)

# Create placeholders, tf functions, internals, etc
self.initialize(custom_getter=custom_getter)
# self.fn_actions_and_internals(
# states=states,
# internals=internals,
# update=update,
# deterministic=deterministic
# Input validation for `terminal`, `reward` and `parallel`.
# NOTE(review): this excerpt starts inside a function whose header is not
# visible; `terminal` and `reward` are defined outside the lines shown here.
parallel = self.parallel_input
# NOTE(review): `zero` is not used anywhere in the visible excerpt.
zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
# Entries of self.buffer_index selected by the indices in `parallel`.
buffer_index = tf.gather(params=self.buffer_index, indices=parallel)
# Assertions
# The results of the rank/value/equality checks are collected in this list;
# the assert_type / assert_scalar calls below are invoked directly and their
# results are not added to it.
assertions = list()
# terminal: type and shape
tf.debugging.assert_type(tensor=terminal, tf_type=util.tf_dtype(dtype='long'))
assertions.append(tf.debugging.assert_rank(x=terminal, rank=1))
# reward: type and shape
tf.debugging.assert_type(tensor=reward, tf_type=util.tf_dtype(dtype='float'))
assertions.append(tf.debugging.assert_rank(x=reward, rank=1))
# parallel: type, shape and value
tf.debugging.assert_type(tensor=parallel, tf_type=util.tf_dtype(dtype='long'))
tf.debugging.assert_scalar(tensor=parallel[0])
assertions.append(tf.debugging.assert_non_negative(x=parallel))
# Only the first element of `parallel` is compared against
# self.parallel_interactions (strictly less than).
assertions.append(tf.debugging.assert_less(
x=parallel[0],
y=tf.constant(value=self.parallel_interactions, dtype=util.tf_dtype(dtype='long'))
))
# shape of terminal equals shape of reward
assertions.append(tf.debugging.assert_equal(
x=tf.shape(input=terminal), y=tf.shape(input=reward)
))
# size of terminal equals buffer index
# Both sides are brought to int64 before the comparison.
assertions.append(tf.debugging.assert_equal(
x=tf.shape(input=terminal, out_type=tf.int64)[0],
y=tf.dtypes.cast(x=buffer_index, dtype=tf.int64)
))
# at most one terminal
# NOTE(review): the check announced by the comment above is not visible in
# this excerpt.
# Horizon-limited cumulative reward computation.
# NOTE(review): excerpt only — `len_`, `cumulate`, `horizon`, `discount`,
# `final_reward`, `terminal` and `reward` are defined outside the visible
# lines, and the enclosing function header is not shown.
# e.g. F T F F F F
# Store the steps until end of the episode(s) determined by the input terminal signals (True starts new count).
lengths = tf.scan(
fn=len_, elems=terminal,
initializer=tf.zeros_like(tensor=terminal[0], dtype=util.tf_dtype(dtype='int'))
)
# e.g. 1 1 2 3 4 5
# True wherever the running length is strictly greater than `horizon`.
off_horizon = tf.greater(x=lengths, y=tf.fill(dims=tf.shape(input=lengths), value=tf.constant(value=horizon, dtype=util.tf_dtype(dtype='int'))))
# e.g. F F F F T T
# Calculate the horizon-subtraction value for each step.
if horizon > 0:
# Each reward scaled by discount ** horizon.
horizon_subtractions = tf.map_fn(lambda x: (discount ** horizon) * x, reward, dtype=util.tf_dtype(dtype='float'))
# Shift right by size of horizon (fill rest with 0.0).
# NOTE(review): np.zeros is handed the result of util.tf_dtype(...), which
# looks like a TensorFlow dtype — confirm NumPy accepts it here; tf.zeros may
# be what is intended.
horizon_subtractions = tf.concat([np.zeros(shape=(horizon,), dtype=util.tf_dtype(dtype='float')), horizon_subtractions], axis=0)
# Trim back to the length of `reward`, dropping the entries shifted past the end.
horizon_subtractions = tf.slice(horizon_subtractions, begin=(0,), size=tf.shape(reward))
# e.g. 0.0, 0.0, 0.0, -1.0*g^3, 1.0*g^3, 0.5*g^3
# all 0.0 if infinite horizon (special case: horizon=0)
else:
horizon_subtractions = tf.zeros(shape=tf.shape(reward), dtype=util.tf_dtype(dtype='float'))
# Now do the scan, each time summing up the previous step (discounted by gamma) and
# subtracting the respective `horizon_subtraction`.
# A Python float `final_reward` is wrapped in a constant tensor; any other
# type is passed through unchanged.
if isinstance(final_reward, float):
final_reward = tf.constant(value=final_reward, dtype=util.tf_dtype(dtype='float'))
# With horizon == 1 the scan is seeded with 0.0 instead of `final_reward`.
reward = tf.scan(
fn=cumulate,
elems=(reward, terminal, off_horizon, horizon_subtractions),
initializer=final_reward if horizon != 1 else tf.constant(value=0.0, dtype=util.tf_dtype(dtype='float'))
)
# Re-reverse again to match input sequences.
# NOTE(review): the reversal steps this comment refers to are not visible in
# this excerpt.
def tf_apply(self, x):
    """Look up embeddings for the ids in `x` and pass them to the parent layer.

    `x` is first brought to an integer dtype usable as lookup ids: when the
    configured 'int' dtype is neither int32 nor int64, `x` is cast to int32;
    otherwise, when `util.dtype` reports `x` as 'bool', it is cast to the
    configured 'int' dtype. The ids are then looked up in `self.weights`
    (with `self.max_norm` forwarded as `max_norm`), and the result is returned
    through the parent class's `tf_apply`.
    """
    native_int_types = (tf.int32, tf.int64)

    if util.tf_dtype('int') in native_int_types:
        # Configured int dtype is directly usable; only bool input needs a cast.
        if util.dtype(x=x) == 'bool':
            x = tf.dtypes.cast(x=x, dtype=util.tf_dtype('int'))
    else:
        # Configured int dtype is not int32/int64: fall back to int32 ids.
        x = tf.dtypes.cast(x=x, dtype=tf.int32)

    embedded = tf.nn.embedding_lookup(
        params=self.weights, ids=x, max_norm=self.max_norm
    )
    return super().tf_apply(x=embedded)
def tf_enqueue(self, **values):
# Setup of the enqueue routine: builds constants and determines how many
# instances are being enqueued from the leading dimension of one input value.
# NOTE(review): only the setup is visible in this excerpt; the function body
# is cut off after the "fit into buffer" comment below.
# Constants
# NOTE(review): neither `zero` nor `capacity` is used in the visible lines.
zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
# Get number of values
# Pick a representative value: the first non-dict value, or the first entry
# of the first non-empty dict value (empty dicts are skipped).
# NOTE(review): if `values` is empty, `value` is never bound and the lines
# below raise a name error; if every value is an empty dict, `value` is left
# as the last (empty) dict.
for value in values.values():
if not isinstance(value, dict):
break
elif len(value) > 0:
value = next(iter(value.values()))
break
# Size of the leading dimension of `value` in the 'long' dtype: tf.shape is
# asked for that dtype directly when it is int32/int64, otherwise the default
# tf.shape output is cast afterwards.
if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
num_values = tf.shape(input=value, out_type=util.tf_dtype(dtype='long'))[0]
else:
num_values = tf.dtypes.cast(
x=tf.shape(input=value)[0], dtype=util.tf_dtype(dtype='long')
)
# Check whether instances fit into buffer
# NOTE(review): excerpt — the first lines are the end of a loop-body function
# whose header is not visible (it returns `deltas, perturbations`); the lines
# after that `return` belong to the enclosing scope, which is also not shown.
# Sign of the loss change: positive when the perturbed loss is lower than the
# unperturbed loss, negative when it is higher, zero when equal.
direction = tf.sign(x=(unperturbed_loss - perturbed_loss))
# Accumulate each perturbation, signed by `direction`, onto its running delta.
deltas = [
delta + direction * perturbation
for delta, perturbation in zip(deltas, perturbations)
]
return deltas, perturbations
num_samples = self.num_samples.value()
# Run `body` repeatedly; the iteration count is capped by
# maximum_iterations=num_samples (util.tf_always_true is presumably a
# condition that never stops the loop — confirm in util).
deltas, perturbations = self.while_loop(
cond=util.tf_always_true, body=body, loop_vars=(deltas, previous_perturbations),
back_prop=False, maximum_iterations=num_samples
)
with tf.control_dependencies(control_inputs=deltas):
# Average the accumulated deltas over the number of samples.
num_samples = tf.dtypes.cast(x=num_samples, dtype=util.tf_dtype(dtype='float'))
deltas = [delta / num_samples for delta in deltas]
# The step actually applied is the averaged delta minus the perturbation
# returned by the loop — presumably because that last perturbation is still
# applied to the variables; confirm against apply_step.
perturbation_deltas = [delta - pert for delta, pert in zip(deltas, perturbations)]
applied = self.apply_step(variables=variables, deltas=perturbation_deltas)
with tf.control_dependencies(control_inputs=(applied,)):
# Trivial operation to enforce control dependency
# The returned deltas are passed through util.identity_operation inside the
# control-dependency scope on `applied`.
return util.fmap(function=util.identity_operation, xs=deltas)
# NOTE(review): excerpt — the enclosing function header and the definition of
# `deltas` are not visible here.
# Concatenate `deltas` along axis 1 and average over that axis.
loss_per_instance = tf.reduce_mean(input_tensor=tf.concat(values=deltas, axis=1), axis=1)
# Optional Huber loss
huber_loss = self.huber_loss.value()
# Branch used when the Huber threshold is zero: plain square of the value.
# NOTE(review): this branch has no 0.5 factor, unlike the quadratic part of
# the Huber branch below — confirm that is intended.
def no_huber_loss():
return tf.square(x=loss_per_instance)
# Huber branch: 0.5 * x^2 where |x| <= huber_loss, and
# huber_loss * (|x| - 0.5 * huber_loss) elsewhere.
def apply_huber_loss():
return tf.where(
condition=(tf.abs(x=loss_per_instance) <= huber_loss),
x=(0.5 * tf.square(x=loss_per_instance)),
y=(huber_loss * (tf.abs(x=loss_per_instance) - 0.5 * huber_loss))
)
# A threshold exactly equal to 0.0 selects the plain squared branch.
zero = tf.constant(value=0.0, dtype=util.tf_dtype(dtype='float'))
skip_huber_loss = tf.math.equal(x=huber_loss, y=zero)
return self.cond(pred=skip_huber_loss, true_fn=no_huber_loss, false_fn=apply_huber_loss)
# One sampling iteration: draw a random perturbation per variable, apply it,
# evaluate the loss, and accumulate each perturbation signed by whether the
# loss went down or up.
# Reads `learning_rate`, `variables`, `fn_loss`, `arguments`,
# `unperturbed_loss` and `self` from an enclosing scope that is not visible
# in this excerpt.
# NOTE(review): the function's return statement is not visible here.
def body(deltas, previous_perturbations):
with tf.control_dependencies(control_inputs=deltas):
# One normally-distributed tensor per variable, scaled by learning_rate.
perturbations = [
learning_rate * tf.random.normal(
shape=util.shape(x=variable), dtype=util.tf_dtype(dtype='float')
) for variable in variables
]
# Apply the difference to the previous perturbation — presumably so the
# previous perturbation is undone in the same step; confirm that apply_step
# adds `deltas` to the variables.
perturbation_deltas = [
pert - prev_pert
for pert, prev_pert in zip(perturbations, previous_perturbations)
]
applied = self.apply_step(variables=variables, deltas=perturbation_deltas)
with tf.control_dependencies(control_inputs=(applied,)):
# Loss is evaluated inside the control-dependency scope on `applied`.
perturbed_loss = fn_loss(**arguments)
# Positive when the perturbed loss is lower than the unperturbed loss.
direction = tf.sign(x=(unperturbed_loss - perturbed_loss))
deltas = [
delta + direction * perturbation
for delta, perturbation in zip(deltas, perturbations)
]
# NOTE(review): excerpt — the lines up to the first closing parenthesis are
# the tail of a variable-creation call whose opening line is not visible;
# the `count` used further down is presumably its result. `tensor` is also
# defined outside the visible lines.
name='count',
dtype=util.tf_dtype('float'),
initializer=0.0,
trainable=False
)
# Non-trainable, zero-initialized variable of shape self.shape.
mean_estimate = tf.get_variable(
name='mean-estimate',
shape=self.shape,
dtype=util.tf_dtype('float'),
initializer=tf.zeros_initializer(),
trainable=False
)
# Non-trainable, zero-initialized variable of shape self.shape.
variance_sum_estimate = tf.get_variable(
name='variance-sum-estimate',
shape=self.shape,
dtype=util.tf_dtype('float'),
initializer=tf.zeros_initializer(),
trainable=False
)
# Initializer op over the three statistics variables, stored on self.
self.reset_op = tf.variables_initializer([count, mean_estimate, variance_sum_estimate], name='reset-op')
# Increment `count` by one; the mean update below is built under a control
# dependency on this increment.
assignment = tf.assign_add(ref=count, value=1.0)
with tf.control_dependencies(control_inputs=(assignment,)):
# Mean update
# Adds sum(tensor - mean_estimate, axis=0) / count to mean_estimate.
mean = tf.reduce_sum(input_tensor=(tensor - mean_estimate), axis=0) # reduce_mean?
assignment = tf.assign_add(ref=mean_estimate, value=(mean / count))
with tf.control_dependencies(control_inputs=(assignment,)):
# NOTE(review): the body of first_run is cut off in this excerpt.
def first_run():
# No meaningful mean and variance yet.
# Retrieve a batch from memory according to self.update_unit, stop gradients
# through it, and run the optimization on it.
# NOTE(review): excerpt — the enclosing function header and the definitions
# of `batch_size` and `sequence_length` are not visible here.
# NOTE(review): there is no else branch; for any other update_unit value
# `batch` stays unbound and the fmap call below fails with a name error.
if self.update_unit == 'timesteps':
# Timestep-based batch
batch = self.memory.retrieve_timesteps(n=batch_size)
elif self.update_unit == 'episodes':
# Episode-based batch
batch = self.memory.retrieve_episodes(n=batch_size)
elif self.update_unit == 'sequences':
# Timestep-sequence-based batch
batch = self.memory.retrieve_sequences(
n=batch_size, sequence_length=sequence_length
)
# Do not calculate gradients for memory-internal operations.
batch = util.fmap(function=tf.stop_gradient, xs=batch)
# Passes a constant True as `update` to Module.update_tensors.
Module.update_tensors(
update=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool'))
)
# `batch` is expanded into keyword arguments, so it must be a mapping.
optimized = self.optimization(**batch)
return optimized