How to use the numpyro.distributions.distribution.Distribution class in numpyro

To help you get started, we’ve selected a few numpyro examples based on popular ways it is used in public projects.


github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github
    @property
    def variance(self):
        # var is inf for alpha <= 2
        a = lax.div((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
        return jnp.where(self.alpha <= 2, jnp.inf, a)

    # override the default behaviour to save computations
    @property
    def support(self):
        return constraints.greater_than(self.scale)

    def tree_flatten(self):
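        # intentionally skip TransformedDistribution.tree_flatten and fall back
        # to the generic Distribution implementation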
        return super(TransformedDistribution, self).tree_flatten()
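This fragment is from numpyro's `Pareto`, a `TransformedDistribution` subclass. A minimal usage sketch, assuming the usual `Pareto(scale, alpha)` constructor:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Pareto(scale=1.0, alpha=3.0)
# finite variance for alpha > 2: scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))
assert jnp.allclose(d.variance, 3.0 / 4.0)
print(dist.Pareto(scale=1.0, alpha=1.5).variance)  # inf, since alpha <= 2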


class StudentT(Distribution):
    arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, df, loc=0., scale=1., validate_args=None):
        batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(loc), jnp.shape(scale))
        self.df = jnp.broadcast_to(df, batch_shape)
        self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
        self._chi2 = Chi2(self.df)
        super(StudentT, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        key_normal, key_chi2 = random.split(key)
        std_normal = random.normal(key_normal, shape=sample_shape + self.batch_shape)
        z = self._chi2.sample(key_chi2, sample_shape)
        y = std_normal * jnp.sqrt(self.df / z)
        return self.loc + self.scale * y
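A minimal usage sketch of `StudentT` (parameter values are arbitrary):

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.StudentT(df=3.0, loc=0.0, scale=2.0)  # arbitrary example parameters
samples = d.sample(key, sample_shape=(1000,))  # shape (1000,)
print(d.log_prob(jnp.array(0.5)))              # scalar log-density at 0.5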
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github
    @property
    def support(self):
        return constraints.multinomial(self.total_count)


def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
    if probs is not None:
        return MultinomialProbs(probs, total_count, validate_args=validate_args)
    elif logits is not None:
        return MultinomialLogits(logits, total_count, validate_args=validate_args)
    else:
        raise ValueError('One of `probs` or `logits` must be specified.')
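A quick sketch of the factory in use; it dispatches on whichever parameterization is supplied:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.Multinomial(total_count=10, probs=jnp.array([0.2, 0.3, 0.5]))
counts = d.sample(key)        # three counts summing to 10
print(d.log_prob(counts))
# dist.Multinomial(total_count=10) would raise the ValueError above,
# since neither `probs` nor `logits` is given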


class Poisson(Distribution):
    arg_constraints = {'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, rate, validate_args=None):
        self.rate = rate
        super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        # validation is already handled by the `validate_sample` decorator,
        # so an explicit `self._validate_sample(value)` call would be redundant
        return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate
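The closed form, k * log(rate) - log(k!) - rate, is easy to sanity-check against the distribution object:

import jax.numpy as jnp
from jax.scipy.special import gammaln
import numpyro.distributions as dist

d = dist.Poisson(4.0)
k = jnp.array(2.0)
manual = k * jnp.log(4.0) - gammaln(k + 1) - 4.0  # k*log(rate) - log(k!) - rate
assert jnp.allclose(d.log_prob(k), manual)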
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github
    @property
    def mean(self):
        return self.probs

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    def enumerate_support(self, expand=True):
        values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values


class BernoulliLogits(Distribution):
    arg_constraints = {'logits': constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits=None, validate_args=None):
        self.logits = logits
        super(BernoulliLogits, self).__init__(batch_shape=jnp.shape(self.logits), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
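        # `self.probs` is derived from `logits` (via a lazy sigmoid property)
        # that is not shown in this snippet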
        return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        return -binary_cross_entropy_with_logits(self.logits, value)
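A minimal usage sketch of `BernoulliLogits` (parameter values are arbitrary); `enumerate_support` works the same way as in the sibling class above:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.BernoulliLogits(logits=jnp.array([-1.0, 0.0, 2.0]))
draws = d.sample(key)                     # boolean array of shape (3,)
print(d.log_prob(jnp.array([0, 1, 1])))   # log-probability per batch element
print(d.enumerate_support())              # shape (2, 3): values 0 and 1 per element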
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github
    @validate_sample
    def log_prob(self, value):
        # as above, validation is handled by the `validate_sample` decorator
        return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate

    @property
    def mean(self):
        return self.rate

    @property
    def variance(self):
        return self.rate


class ZeroInflatedPoisson(Distribution):
    """
    A Zero Inflated Poisson distribution.

    :param numpy.ndarray gate: probability of extra zeros.
    :param numpy.ndarray rate: rate of Poisson distribution.
    """
    arg_constraints = {'gate': constraints.unit_interval, 'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, gate, rate=1., validate_args=None):
        batch_shape = lax.broadcast_shapes(jnp.shape(gate), jnp.shape(rate))
        self.gate, self.rate = promote_shapes(gate, rate)
        super(ZeroInflatedPoisson, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # body restored as a sketch of the usual two-step scheme: draw a
        # Bernoulli mask from `gate`, then zero out the matching Poisson draws
        key_bern, key_poisson = random.split(key)
        shape = sample_shape + self.batch_shape
        mask = random.bernoulli(key_bern, self.gate, shape)
        samples = random.poisson(key_poisson, self.rate, shape)
        return jnp.where(mask, 0, samples)
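A minimal usage sketch (parameter values are arbitrary):

from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.ZeroInflatedPoisson(gate=0.3, rate=5.0)
print(d.sample(key, (10,)))  # roughly 30% of draws are forced to zero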
github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github
    @validate_sample
    def log_prob(self, value):
        diff = value - self.loc
        # squared Mahalanobis distance under the low-rank-plus-diagonal covariance
        M = _batch_lowrank_mahalanobis(self.cov_factor,
                                       self.cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self.cov_factor,
                                        self.cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(self.cov_factor,
                                        self.cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self.loc.shape[-1] * (1.0 + jnp.log(2 * jnp.pi)) + log_det)
        return jnp.broadcast_to(H, self.batch_shape)
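These methods belong to `LowRankMultivariateNormal`, whose covariance is `cov_factor @ cov_factor.T + diag(cov_diag)`. A minimal sketch, assuming that constructor signature:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.LowRankMultivariateNormal(
    loc=jnp.zeros(3),
    cov_factor=jnp.ones((3, 1)),  # rank-1 factor W
    cov_diag=jnp.ones(3))         # diagonal term D; covariance = W @ W.T + diag(D)
print(d.log_prob(jnp.zeros(3)))
print(d.entropy())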


class Normal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(Normal, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        eps = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
        return self.loc + eps * self.scale

    @validate_sample
    def log_prob(self, value):
        normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
        value_scaled = (value - self.loc) / self.scale
        return -0.5 * value_scaled ** 2 - normalize_term
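With the return line restored, the density is easy to check by hand:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.Normal(loc=1.0, scale=2.0)
x = d.sample(key, (5,))
# closed form: -0.5 * ((x - loc) / scale)**2 - log(scale * sqrt(2*pi))
manual = -0.5 * ((x - 1.0) / 2.0) ** 2 - jnp.log(2.0 * jnp.sqrt(2 * jnp.pi))
assert jnp.allclose(d.log_prob(x), manual)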
github pyro-ppl / numpyro / numpyro / distributions / distribution.py View on Github
    @staticmethod
    def set_default_validate_args(value):
        if value not in [True, False]:
            raise ValueError
        Distribution._validate_args = value
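A sketch of toggling validation globally; with validation enabled, invalid parameters should fail at construction time:

import numpyro.distributions as dist
from numpyro.distributions.distribution import Distribution

Distribution.set_default_validate_args(True)
try:
    dist.Normal(0.0, -1.0)  # scale violates constraints.positive
except ValueError as err:
    print("caught:", err)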
github pyro-ppl / numpyro / numpyro / distributions / conjugate.py View on Github
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from jax import lax, random
import jax.numpy as np
from jax.scipy.special import betaln, gammaln

from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Gamma
from numpyro.distributions.discrete import Binomial, Poisson
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample


class BetaBinomial(Distribution):
    r"""
    Compound distribution comprising a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to a certain number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive,
                       'total_count': constraints.nonnegative_integer}

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        # body reconstructed as a sketch: promote/broadcast the parameters,
        # build the underlying Beta, and initialize the base Distribution
        self.concentration1, self.concentration0, self.total_count = promote_shapes(
            concentration1, concentration0, total_count)
        batch_shape = lax.broadcast_shapes(np.shape(concentration1), np.shape(concentration0),
                                           np.shape(total_count))
        self._beta = Beta(np.broadcast_to(concentration1, batch_shape),
                          np.broadcast_to(concentration0, batch_shape))
        super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)
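A minimal usage sketch (parameter values are arbitrary):

from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.BetaBinomial(concentration1=2.0, concentration0=5.0, total_count=10)
print(d.sample(key, (5,)))  # integer counts in [0, 10]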
github pyro-ppl / numpyro / numpyro / distributions / distribution.py View on Github
"""
        Masks a distribution by a boolean or boolean-valued array that is
        broadcastable to the distributions
        :attr:`Distribution.batch_shape` .

        :param mask: A boolean or boolean valued array.
        :type mask: bool or jnp.ndarray
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        """
        if mask is True:
            return self
        return MaskedDistribution(self, mask)
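A small sketch of the masking pattern, e.g. for ignoring a missing observation:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Normal(jnp.zeros(3), 1.0)
masked = d.mask(jnp.array([True, False, True]))
# masked-out entries contribute zero log-density
print(masked.log_prob(jnp.array([0.5, 100.0, -0.5])))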


class ExpandedDistribution(Distribution):
    arg_constraints = {}

    def __init__(self, base_dist, batch_shape=()):
        if isinstance(base_dist, ExpandedDistribution):
            batch_shape, _, _ = self._broadcast_shape(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist
        self.base_dist = base_dist
        super().__init__(base_dist.batch_shape, base_dist.event_shape)
        # adjust batch shape
        self.expand(batch_shape)

    def expand(self, batch_shape):
        # Do basic validation. e.g. we should not "unexpand" distributions even if that is possible.
        new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
        # Record interstitial and expanded dims/sizes w.r.t. the base distribution
        new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
                                                                              batch_shape)
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github
    @property
    def support(self):
        return constraints.integer_interval(0, self.total_count)

    def enumerate_support(self, expand=True):
        total_count = jnp.amax(self.total_count)
        if not_jax_tracer(total_count):
            # NB: the error can't be raised if inhomogeneous issue happens when tracing
            if jnp.amin(self.total_count) != total_count:
                raise NotImplementedError("Inhomogeneous total count not supported"
                                          " by `enumerate_support`.")
        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values


class BinomialLogits(Distribution):
    arg_constraints = {'logits': constraints.real,
                       'total_count': constraints.nonnegative_integer}
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits, total_count=1, validate_args=None):
        self.logits, self.total_count = promote_shapes(logits, total_count)
        batch_shape = lax.broadcast_shapes(jnp.shape(logits), jnp.shape(total_count))
        super(BinomialLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
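        # `binomial` is presumably the sampler from numpyro.distributions.util;
        # `self.probs` is derived from `logits` via a lazy property not shown here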
        return binomial(key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        log_factorial_n = gammaln(self.total_count + 1)
        # remainder reconstructed from the binomial log-pmf:
        # log C(n, k) + k*logits - n*log(1 + exp(logits))
        log_factorial_k = gammaln(value + 1)
        log_factorial_nmk = gammaln(self.total_count - value + 1)
        # n * log(1 + exp(logits)) computed stably as
        # n * (max(logits, 0) + log1p(exp(-|logits|)))
        normalize_term = self.total_count * (jnp.clip(self.logits, 0) +
                                             jnp.log1p(jnp.exp(-jnp.abs(self.logits))))
        return (value * self.logits + log_factorial_n - log_factorial_k -
                log_factorial_nmk - normalize_term)
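Usage goes through the `Binomial` factory, which dispatches like `Multinomial` above:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.Binomial(total_count=5, logits=jnp.array(0.0))  # logits=0 <=> probs=0.5
k = d.sample(key, (4,))
print(d.log_prob(k))
print(d.enumerate_support())  # array([0, 1, 2, 3, 4, 5])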
github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github
    @validate_sample
    def log_prob(self, value):
        return self._dirichlet.log_prob(jnp.stack([value, 1. - value], -1))

    @property
    def mean(self):
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def variance(self):
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total ** 2 * (total + 1))
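A quick numerical check of the two closed forms:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Beta(2.0, 5.0)  # concentration1=2, concentration0=5
assert jnp.allclose(d.mean, 2.0 / 7.0)                # a / (a + b)
assert jnp.allclose(d.variance, 10.0 / (49.0 * 8.0))  # a*b / ((a+b)**2 * (a+b+1))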


class Cauchy(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(Cauchy, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        eps = random.cauchy(key, shape=sample_shape + self.batch_shape)
        return self.loc + eps * self.scale

    @validate_sample
    def log_prob(self, value):
        return - jnp.log(jnp.pi) - jnp.log(self.scale) - jnp.log1p(((value - self.loc) / self.scale) ** 2)
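Finally, a check of the standard Cauchy log-density, -log(pi) - log1p(x**2) for loc=0, scale=1:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

key = random.PRNGKey(0)
d = dist.Cauchy(loc=0.0, scale=1.0)
x = d.sample(key, (3,))
assert jnp.allclose(d.log_prob(x), -jnp.log(jnp.pi) - jnp.log1p(x ** 2))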