# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sample(self):
    """
    Draw one sample from each component categorical and stack them
    along the last axis, yielding one multi-discrete action per row.

    :return: (tf.Tensor) stacked per-component samples
    """
    component_samples = []
    for categorical in self.categoricals:
        component_samples.append(categorical.sample())
    return tf.stack(component_samples, axis=-1)
@classmethod
def fromflat(cls, flat):
    """
    Alternate constructor from flat multi categorical logits.

    Subclasses are expected to override this; the base version is abstract.

    :param flat: ([float]) the multi categorical logits input
    :return: (ProbabilityDistribution) the instance from the given multi categorical input
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
class DiagGaussianProbabilityDistribution(ProbabilityDistribution):
    def __init__(self, flat):
        """
        Probability distributions from multivariate Gaussian input

        :param flat: ([float]) the multivariate Gaussian input data;
            the last axis holds the mean and log-std concatenated
        """
        self.flat = flat
        # Split the trailing axis in half: first half is the mean,
        # second half is the log standard deviation.
        last_axis = len(flat.shape) - 1
        self.mean, self.logstd = tf.split(axis=last_axis, num_or_size_splits=2, value=flat)
        self.std = tf.exp(self.logstd)
        super(DiagGaussianProbabilityDistribution, self).__init__()

    def flatparam(self):
        """Return the raw [mean, logstd] parameter tensor."""
        return self.flat
# a categorical distribution (see http://amid.fish/humble-gumbel)
# NOTE(review): this looks like an orphaned `sample()` body — its `def` header is
# not visible in this chunk. Gumbel-max trick: adding -log(-log(U)) noise to the
# logits and taking the argmax draws an index distributed as softmax(logits).
uniform = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
return tf.argmax(self.logits - tf.log(-tf.log(uniform)), axis=-1)
@classmethod
def fromflat(cls, flat):
    """
    Alternate constructor: wrap raw categorical logits in a distribution.

    :param flat: ([float]) the categorical logits input
    :return: (ProbabilityDistribution) the instance from the given categorical input
    """
    instance = cls(flat)
    return instance
class MultiCategoricalProbabilityDistribution(ProbabilityDistribution):
    def __init__(self, nvec, flat):
        """
        Probability distributions from multicategorical input

        :param nvec: ([int]) the sizes of the different categorical inputs
        :param flat: ([float]) the categorical logits input
        """
        self.flat = flat
        # One independent categorical distribution per action component,
        # each built from its own slice of the flat logits.
        split_logits = tf.split(flat, nvec, axis=-1)
        self.categoricals = [CategoricalProbabilityDistribution(logits)
                             for logits in split_logits]
        super(MultiCategoricalProbabilityDistribution, self).__init__()

    def flatparam(self):
        """Return the raw concatenated logits tensor."""
        return self.flat

    def mode(self):
        # The joint mode is simply each component's mode, stacked.
        component_modes = [categorical.mode() for categorical in self.categoricals]
        return tf.stack(component_modes, axis=-1)
def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0):
    """
    Build the distribution parameters and Q-values from latent features.

    :param pi_latent_vector: ([float]) the latent pi values
    :param vf_latent_vector: ([float]) the latent vf values
    :param init_scale: (float) the initial scale of the linear layers
    :param init_bias: (float) the initial bias of the linear layers
    :return: (tuple) (distribution instance, flat distribution parameters, Q-values)
    """
    pdparam = linear(pi_latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
    q_values = linear(vf_latent_vector, 'q', self.size, init_scale=init_scale, init_bias=init_bias)
    distribution = self.proba_distribution_from_flat(pdparam)
    return distribution, pdparam, q_values
def param_shape(self):
    """Shape of the flat distribution parameters: one entry per action dimension."""
    return [self.size]
def sample_shape(self):
    """Shape of one sampled action: one entry per action dimension."""
    return [self.size]
def sample_dtype(self):
    """Samples are integer action indices."""
    return tf.int32
class CategoricalProbabilityDistribution(ProbabilityDistribution):
    def __init__(self, logits):
        """
        Probability distributions from categorical input

        :param logits: ([float]) the categorical logits input
        """
        self.logits = logits
        super(CategoricalProbabilityDistribution, self).__init__()

    def flatparam(self):
        """Return the raw logits tensor."""
        return self.logits

    def mode(self):
        # The most likely action is the index of the largest logit.
        return tf.argmax(self.logits, axis=-1)
def neglogp(self, x):
    """
    Negative log likelihood of actions ``x`` under this categorical distribution.

    Bug fixed: the previous body read ``self.mean`` / ``self.std`` — attributes this
    class never defines (``__init__`` only sets ``self.logits``) — and computed a
    Gaussian *sample*, not a log probability. This would raise AttributeError at
    graph-construction time.

    :param x: (tf.Tensor) integer action indices
    :return: (tf.Tensor) the negative log likelihood of x

    Note: we can't use sparse_softmax_cross_entropy_with_logits here because that
    implementation does not allow second-order derivatives.
    """
    one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
    # The stop_gradient is important!
    # Otherwise, it changes the distribution and breaks PPO2 for instance
    return tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=self.logits,
        labels=tf.stop_gradient(one_hot_actions))
@classmethod
def fromflat(cls, flat):
    """
    Alternate constructor from flat multivariate Gaussian parameters.

    :param flat: ([float]) the multivariate Gaussian input data
    :return: (ProbabilityDistribution) the instance from the given multivariate Gaussian input data
    """
    return cls(flat)
class BernoulliProbabilityDistribution(ProbabilityDistribution):
    def __init__(self, logits):
        """
        Probability distributions from Bernoulli input

        :param logits: ([float]) the Bernoulli input data
        """
        self.logits = logits
        # Per-dimension success probabilities are the sigmoid of the logits.
        self.probabilities = tf.sigmoid(logits)
        super(BernoulliProbabilityDistribution, self).__init__()

    def flatparam(self):
        """Return the raw logits tensor."""
        return self.logits

    def mode(self):
        # Rounding picks the more likely of {0, 1} in each dimension.
        return tf.round(self.probabilities)
def __init__(self):
    # Base-class constructor: holds no state of its own, just continues the
    # cooperative __init__ chain. NOTE(review): this method appears detached
    # from its `class ProbabilityDistribution` header in this chunk — the
    # explicit super(ProbabilityDistribution, self) call requires that class.
    super(ProbabilityDistribution, self).__init__()