# How to use the `jax.numpy.shape` function in JAX

## To help you get started, we’ve selected a few `jax.numpy.shape` examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

pyro-ppl / numpyro / numpyro / distributions / continuous.py View on GitHub
``````def _batch_mahalanobis(bL, bx):
# Computes, over the last axis, sum((bL^{-1} bx)^2), treating bL as lower
# triangular (see solve_triangular(..., lower=True) below). When bL is a
# Cholesky factor of a covariance matrix this is the squared Mahalanobis
# distance -- presumably the intent given the name; TODO confirm with callers.
# NOTE(review): this excerpt stops partway through the general-case branch
# (after the "permute" comment at the end); its remaining steps and return
# are not shown here.
if bL.shape[:-1] == bx.shape:
# no need to use the below optimization procedure
# (bx already has exactly bL's batch dims plus the n axis, so the batched
# triangular solve needs no broadcasting of bL)
solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
return jnp.sum(jnp.square(solve_bL_bx), -1)

# NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
# because we don't want to broadcast bL to the shape (i, j, n, n).

# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
out_shape = jnp.shape(bx)[:-1]  # shape of output
# Reshape bx with the shape (..., 1, i, j, 1, n)
# (each batch dim sx of bx is split into (sx // sL, sL) against the matching
# batch dim sL of bL; the trailing -1 keeps the n axis)
bx_new_shape = out_shape[:sample_ndim]
for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
bx_new_shape += (sx // sL, sL)
bx_new_shape += (-1,)
bx = jnp.reshape(bx, bx_new_shape)
# Permute bx to make it have shape (..., 1, j, i, 1, n)
# (sample dims first, then every "sx // sL" dim, then every "sL" dim,
# then the n axis)
permute_dims = (tuple(range(sample_ndim))
+ tuple(range(sample_ndim, bx.ndim - 1, 2))
+ tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
+ (bx.ndim - 1,))
bx = jnp.transpose(bx, permute_dims)

# reshape to (-1, i, 1, n)
xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
# permute to (i, 1, n, -1)``````
pyro-ppl / numpyro / test / test_distributions.py View on GitHub
``````x = biject_to(transform.domain)(random.normal(rng_key, shape))
# NOTE(review): excerpt from the body of a test whose `def` line (and the
# definitions of transform, shape, batch_shape, event_shape, rng_key) is not shown.
y = transform(x)

# test codomain
assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

# test inv
# (the round trip x -> y -> z must recover x)
z = transform.inv(y)
assert_allclose(x, z, atol=1e-6, rtol=1e-6)

# test domain
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

# test log_abs_det_jacobian
actual = transform.log_abs_det_jacobian(x, y)
# one log|det J| value per batch element
assert jnp.shape(actual) == batch_shape
if len(shape) == transform.event_dim:
if len(event_shape) == 1:
# reference values: log|det| of the Jacobians computed by jax.jacobian,
# for the forward map and for its inverse
expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
else:
# NOTE(review): the body of this else-branch is missing from the excerpt;
# as shown it has no statements, and expected/inv_expected are only
# assigned in the `if` branch above.

assert_allclose(actual, expected, atol=1e-6)
assert_allclose(actual, -inv_expected, atol=1e-6)``````
pyro-ppl / numpyro / numpyro / distributions / util.py View on GitHub
``````def _binomial(key, p, n, shape):
shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
# reshape to map over axis 0
key = random.split(key, jnp.size(p))
if xla_bridge.get_backend().platform == 'cpu':
ret = lax.map(lambda x: _binomial_dispatch(*x),
(key, p, n))
else:
ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
return jnp.reshape(ret, shape)``````
pyro-ppl / numpyro / numpyro / distributions / util.py View on GitHub
``````def _multinomial(key, p, n, n_max, shape=()):
# Draws n_max category indices per batch element from `categorical` and masks
# them by the per-element counts n (presumably to build multinomial counts;
# the gather/return step lies past the end of this excerpt).
# NOTE(review): `&gt;`/`&lt;` below are HTML-escaped `>`/`<` left by the page
# scrape; the code as shown does not parse until they are unescaped.
if jnp.shape(n) != jnp.shape(p)[:-1]:
# NOTE(review): this `if` has no visible body in the excerpt (the lines that
# presumably broadcast n and p to a common batch shape were dropped).
shape = shape or p.shape[:-1]
# get indices from categorical distribution then gather the result
indices = categorical(key, p, (n_max,) + shape)
# mask out values when counts is heterogeneous
if jnp.ndim(n) &gt; 0:
# NOTE(review): mask is built with n_max as its LAST axis while indices has
# n_max FIRST; check the axes are aligned before `indices * mask` below
# (an axis move may have been dropped from this excerpt).
mask = promote_shapes(jnp.arange(n_max) &lt; jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
else:
# NOTE(review): `mask` is not assigned on this path, yet it is used
# unconditionally in `indices * mask` below -- as shown, scalar n raises
# NameError. A neutral mask here (e.g. mask = 1) would avoid that.
excess = 0
# NB: we transpose to move batch shape to the front
indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T``````
pyro-ppl / numpyro / numpyro / distributions / util.py View on GitHub
``````def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
""" Compute centered von Mises samples using rejection sampling from [1] with wrapped Cauchy proposal.

*** References ***
[1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

:param key: random number generator key
:param concentration: concentration of distribution
:param shape: shape of samples
:param dtype: float precesions for choosing correct s cutfoff
:return: centered samples from von Mises
"""
shape = shape or jnp.shape(concentration)
dtype = canonicalize_dtype(dtype)
concentration = lax.convert_element_type(concentration, dtype)
return _von_mises_centered(key, concentration, shape, dtype)``````
pyro-ppl / numpyro / numpyro / distributions / discrete.py View on GitHub
``````def __init__(self, logits=None, validate_args=None):
self.logits = logits
super(BernoulliLogits, self).__init__(batch_shape=jnp.shape(self.logits), validate_args=validate_args)``````
pyro-ppl / numpyro / numpyro / distributions / discrete.py View on GitHub
``````def __init__(self, logits, total_count=1, validate_args=None):
if jnp.ndim(logits) &lt; 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
self.logits = promote_shapes(logits, shape=batch_shape + jnp.shape(logits)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
super(MultinomialLogits, self).__init__(batch_shape=batch_shape,
event_shape=jnp.shape(self.logits)[-1:],
validate_args=validate_args)``````
pyro-ppl / numpyro / numpyro / infer / autoguide.py View on GitHub
``````            return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                # Reshape every leaf: its leading axis is replaced by sample_shape and
                # the trailing axes are kept (so the leading axis length must equal
                # the product of sample_shape for the reshape to succeed).
unpacked_samples)``````
pyro-ppl / numpyro / numpyro / distributions / continuous.py View on GitHub
``````def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
if jnp.ndim(loc) &lt; 1:
raise ValueError("`loc` must be at least one-dimensional.")
event_shape = jnp.shape(loc)[-1:]
if jnp.ndim(cov_factor) &lt; 2:
raise ValueError("`cov_factor` must be at least two-dimensional, "
if jnp.shape(cov_factor)[-2:-1] != event_shape:
raise ValueError("`cov_factor` must be a batch of matrices with shape {} x m"
.format(event_shape[0]))
if jnp.shape(cov_diag)[-1:] != event_shape:
raise ValueError("`cov_diag` must be a batch of vectors with shape {}".format(self.event_shape))

loc, cov_factor, cov_diag = promote_shapes(loc[..., jnp.newaxis], cov_factor, cov_diag[..., jnp.newaxis])
self.loc = jnp.broadcast_to(loc[..., 0], batch_shape + event_shape)
self.cov_factor = cov_factor
cov_diag = cov_diag[..., 0]
self.cov_diag = cov_diag
self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
super(LowRankMultivariateNormal, self).__init__(
batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
)``````
pyro-ppl / numpyro / numpyro / distributions / distribution.py View on GitHub
``````def __init__(self, log_factor, validate_args=None):
batch_shape = jnp.shape(log_factor)
event_shape = (0,)  # This satisfies .size == 0.
self.log_factor = log_factor
super(Unit, self).__init__(batch_shape, event_shape, validate_args=validate_args)``````

## jax

Differentiate, compile, and transform NumPy code.

Apache-2.0