How to use the jax.numpy.shape function in jax

To help you get started, we’ve selected a few jax examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github external
def _batch_mahalanobis(bL, bx):
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = jnp.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
github pyro-ppl / numpyro / test / test_distributions.py View on Github external
x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
github pyro-ppl / numpyro / numpyro / distributions / util.py View on Github external
def _binomial(key, p, n, shape):
    """Apply ``_binomial_dispatch`` element-wise over broadcast ``p`` and ``n``.

    :param key: random number generator key; it is split into one sub-key
        per flattened element
    :param p: first parameter forwarded to ``_binomial_dispatch``
        (broadcast to ``shape``)
    :param n: second parameter forwarded to ``_binomial_dispatch``
        (broadcast to ``shape``)
    :param shape: output shape; when falsy (e.g. ``()`` or ``None``) the
        broadcast shape of ``p`` and ``n`` is used instead
    :return: the per-element results reshaped to ``shape``
    """
    if not shape:
        shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))

    # Flatten both parameters so the mapping below runs over a single axis 0.
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    keys = random.split(key, jnp.size(flat_p))

    # The CPU backend maps sequentially with lax.map; other backends use vmap.
    on_cpu = xla_bridge.get_backend().platform == 'cpu'
    if on_cpu:
        draws = lax.map(lambda args: _binomial_dispatch(*args),
                        (keys, flat_p, flat_n))
    else:
        draws = vmap(lambda *args: _binomial_dispatch(*args))(keys, flat_p, flat_n)
    return jnp.reshape(draws, shape)
github pyro-ppl / numpyro / numpyro / distributions / util.py View on Github external
def _multinomial(key, p, n, n_max, shape=()):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
github pyro-ppl / numpyro / numpyro / distributions / util.py View on Github external
def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
    """ Draw centered von Mises samples via rejection sampling from [1], using a wrapped Cauchy proposal.

        *** References ***
        [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
            Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf


        :param key: random number generator key
        :param concentration: concentration of distribution
        :param shape: shape of samples; when empty, the shape of ``concentration`` is used
        :param dtype: float precision used for choosing the correct s cutoff
        :return: centered samples from von Mises
    """
    if not shape:
        shape = jnp.shape(concentration)
    dtype = canonicalize_dtype(dtype)
    # Cast first, then broadcast, so the sampler receives a full-shape array of the target dtype.
    concentration = jnp.broadcast_to(
        lax.convert_element_type(concentration, dtype), shape
    )
    return _von_mises_centered(key, concentration, shape, dtype)
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github external
def __init__(self, logits=None, validate_args=None):
        # Store the parameter, then hand its shape to the base class as the batch shape.
        self.logits = logits
        batch_shape = jnp.shape(self.logits)
        super(BernoulliLogits, self).__init__(
            batch_shape=batch_shape,
            validate_args=validate_args,
        )
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github external
def __init__(self, logits, total_count=1, validate_args=None):
        """Set up the distribution from ``logits`` and ``total_count``.

        :param logits: array whose last axis becomes the event shape; must
            have at least one dimension
        :param total_count: value broadcast against the leading axes of
            ``logits`` to form the batch shape
        :param validate_args: forwarded unchanged to the base class
        :raises ValueError: if ``logits`` has fewer than one dimension
        """
        if jnp.ndim(logits) < 1:
            raise ValueError("`logits` parameter must be at least one-dimensional.")

        logits_shape = jnp.shape(logits)
        # Leading axes of `logits` broadcast with `total_count` give the batch shape.
        batch_shape = lax.broadcast_shapes(logits_shape[:-1], jnp.shape(total_count))
        full_logits_shape = batch_shape + logits_shape[-1:]

        self.logits = promote_shapes(logits, shape=full_logits_shape)[0]
        self.total_count = promote_shapes(total_count, shape=batch_shape)[0]

        event_shape = jnp.shape(self.logits)[-1:]
        super(MultinomialLogits, self).__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )
github pyro-ppl / numpyro / numpyro / infer / autoguide.py View on Github external
            return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                            unpacked_samples)
github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github external
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        """Validate the parameter shapes and initialise the distribution.

        :param loc: array with at least one dimension; its last axis defines
            the event shape
        :param cov_factor: array with at least two dimensions whose
            second-to-last axis must match the event size
        :param cov_diag: array whose last axis must match the event size
        :param validate_args: forwarded unchanged to the base class
        :raises ValueError: if any of the shape requirements above is violated
        """
        if jnp.ndim(loc) < 1:
            raise ValueError("`loc` must be at least one-dimensional.")
        event_shape = jnp.shape(loc)[-1:]
        if jnp.ndim(cov_factor) < 2:
            raise ValueError("`cov_factor` must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if jnp.shape(cov_factor)[-2:-1] != event_shape:
            raise ValueError("`cov_factor` must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if jnp.shape(cov_diag)[-1:] != event_shape:
            # Use the local `event_shape`: the base-class __init__ has not run
            # yet, so reading `self.event_shape` here would fail before the
            # intended ValueError could be raised.
            raise ValueError("`cov_diag` must be a batch of vectors with shape {}".format(event_shape))

        # Give `loc` and `cov_diag` a trailing axis so all three arguments are
        # matrix-shaped when passed to promote_shapes; the axis is dropped again below.
        loc, cov_factor, cov_diag = promote_shapes(loc[..., jnp.newaxis], cov_factor, cov_diag[..., jnp.newaxis])
        # Everything except the last two (matrix) axes is the batch shape.
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(cov_factor), jnp.shape(cov_diag))[:-2]
        self.loc = jnp.broadcast_to(loc[..., 0], batch_shape + event_shape)
        self.cov_factor = cov_factor
        cov_diag = cov_diag[..., 0]
        self.cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal, self).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
            )
github pyro-ppl / numpyro / numpyro / distributions / distribution.py View on Github external
def __init__(self, log_factor, validate_args=None):
        # Keep the factor; its shape serves as the batch shape.
        self.log_factor = log_factor
        # A zero-length event axis makes the event shape satisfy .size == 0.
        empty_event = (0,)
        super(Unit, self).__init__(
            jnp.shape(log_factor), empty_event, validate_args=validate_args
        )