Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, nvec, flat):
    """
    Joint distribution over several categorical variables, built from one flat logits input.

    :param nvec: ([int]) the number of categories of each categorical input; passed to
        ``tf.split`` as the split sizes along the last axis of ``flat``
    :param flat: ([float]) the concatenated categorical logits for all inputs
    """
    self.flat = flat
    # Carve the flat logits into one chunk per categorical input (last axis).
    logit_chunks = tf.split(flat, nvec, axis=-1)
    self.categoricals = [CategoricalProbabilityDistribution(chunk) for chunk in logit_chunks]
    super(MultiCategoricalProbabilityDistribution, self).__init__()
def proba_distribution_from_flat(self, flat):
    """
    Build a multi-categorical distribution from flat logits.

    :param flat: ([float]) the concatenated categorical logits
    :return: (MultiCategoricalProbabilityDistribution) a distribution built from
        ``self.n_vec`` and ``flat``
    """
    distribution_cls = MultiCategoricalProbabilityDistribution
    return distribution_cls(self.n_vec, flat)
def probability_distribution_class(self):
    """
    Give the concrete distribution class associated with this type.

    :return: (type) the ``MultiCategoricalProbabilityDistribution`` class itself (not an instance)
    """
    distribution_cls = MultiCategoricalProbabilityDistribution
    return distribution_cls