# How to use the jax.lax module in jax

## To help you get started, we’ve selected a few jax examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

google / jax / polysimp.py View on Github
# Demo: trace a redundant polynomial with make_jaxpr, then again after running
# it through the (experimental) polysimp interpreter to collect like terms.
from jax.interpreters import polysimp
from jax import lax
from jax import make_jaxpr
import jax.numpy as np

import jax.linear_util as lu

# f(x) = x + x^3 + 3x + 4x^3, i.e. 4x + 5x^3 once like terms are collected —
# written redundantly on purpose so polysimp has something to simplify.
f = lambda x: x + x * x * x + 3 * x + 4 * x * x * x

# Python 3: print is a function, not a statement (the original `print expr`
# form is a syntax error on Python 3).
print(make_jaxpr(f)(2))
print(f(2))

# Same function traced through the polynomial-simplification transform;
# wrap_init packages f for the transformation machinery.
print(make_jaxpr(polysimp.polysimp(lu.wrap_init(f)).call_wrapped)((2,)))
print(polysimp.polysimp(lu.wrap_init(f)).call_wrapped((2,)))

import numpy as onp
team-ocean / veros / veros / core / operators.py View on Github
# NOTE(review): excerpt is truncated — the opening `if` branch of this
# backend-dispatch chain (presumably `if runtime_settings.backend == 'numpy':`)
# and the *_numpy / *_jax helper definitions are outside this excerpt.
update_add = update_add_numpy
update_multiply = update_multiply_numpy
at = Index()
solve_tridiagonal = solve_tridiagonal_numpy
scan = scan_numpy

elif runtime_settings.backend == 'jax':
    import jax
    import jax.numpy
    numpy = jax.numpy
    # jax.ops.index_update / jax.ops.index were removed in later jax releases
    # (replaced by x.at[idx].set(...)) — assumes a pinned older jax; verify
    # against the project's dependency pin.
    update = jax.ops.index_update
    update_multiply = update_multiply_jax
    at = jax.ops.index
    solve_tridiagonal = solve_tridiagonal_jax
    scan = jax.lax.scan

else:
    # Unrecognized backend name — including the offending value in the
    # ValueError message would aid debugging.
    raise ValueError()
tensorflow / probability / discussion / fun_mcmc / tf_on_jax.py View on Github
``````def _while_loop(cond, body, loop_vars, **kwargs):  # pylint: disable=missing-docstring
del kwargs

# JAX doesn't do the automatic unwrapping of variables.
def cond_wrapper(loop_vars):
return cond(*loop_vars)

def body_wrapper(loop_vars):
return body(*loop_vars)

return lax.while_loop(cond_wrapper, body_wrapper, loop_vars)``````
google / jax / jax / initial_style.py View on Github
def _index_arrays(i, aval, xs):
    """Pick index `i` from the leading axis of every leaf of a nested tuple.

    Recurses through AbstractTuple avals, re-packing per-element results; at
    a non-tuple leaf it slices index `i` out of the leading dimension
    (dropping that dimension via keepdims=False).
    """
    if not isinstance(aval, core.AbstractTuple):
        return lax.dynamic_index_in_dim(xs, i, keepdims=False)
    index_child = partial(_index_arrays, i)
    return core.pack(map(index_child, aval, xs))
google / jax / jax / numpy / lax_numpy.py View on Github
# numpy-ufunc -> lax-primitive bindings for the jax.numpy namespace.
# _one_to_one_unop/_one_to_one_binop (defined elsewhere in this file) wrap a
# lax primitive with the corresponding numpy op's name/docstring; the trailing
# `True` flag presumably requests promoting integer inputs to an inexact
# (floating) dtype, since those ops are float-valued — TODO confirm against
# the helpers' definitions.
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
negative = _one_to_one_unop(onp.negative, lax.neg)
positive = _one_to_one_unop(onp.positive, lambda x: x)  # identity; no lax primitive needed
sign = _one_to_one_unop(onp.sign, lax.sign)

# Float-valued elementwise ops (flagged True — see note above).
floor = _one_to_one_unop(onp.floor, lax.floor, True)
ceil = _one_to_one_unop(onp.ceil, lax.ceil, True)
exp = _one_to_one_unop(onp.exp, lax.exp, True)
log = _one_to_one_unop(onp.log, lax.log, True)
expm1 = _one_to_one_unop(onp.expm1, lax.expm1, True)
log1p = _one_to_one_unop(onp.log1p, lax.log1p, True)
sin = _one_to_one_unop(onp.sin, lax.sin, True)
cos = _one_to_one_unop(onp.cos, lax.cos, True)
tan = _one_to_one_unop(onp.tan, lax.tan, True)
# numpy and lax spell the inverse-trig names differently (arcsin vs asin).
arcsin = _one_to_one_unop(onp.arcsin, lax.asin, True)
arccos = _one_to_one_unop(onp.arccos, lax.acos, True)
arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)

bitwise_and = _one_to_one_binop(onp.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(onp.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(onp.bitwise_xor, lax.bitwise_xor)
# numpy's right_shift on signed ints is arithmetic (sign-extending), hence
# shift_right_arithmetic rather than shift_right_logical.
right_shift = _one_to_one_binop(onp.right_shift, lax.shift_right_arithmetic)
left_shift = _one_to_one_binop(onp.left_shift, lax.shift_left)
equal = _one_to_one_binop(onp.equal, lax.eq)
# For bool inputs, multiply maps to bitwise_and (logical and) instead of mul.
multiply = _maybe_bool_binop(onp.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
google / trax / trax / layers / research / efficient_attention.py View on Github
# NOTE(review): excerpt starts mid-function — the `def` line of this mask
# helper is above this excerpt, and the final line below is truncated.
"""Masks out elements attending to self.

Args:
N: number of query positions
M: number of key positions
k: position of the initial query element

Returns:
N x M mask, where 1.0 indicates that attention is not allowed.
"""
# tie_in makes the iota arrays data-depend on `k` (numerically a no-op) so
# XLA keeps them inside the surrounding computation rather than hoisting.
x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
# NOTE(review): the next line is garbled by extraction — it reads like the
# tail of a lax.broadcast_in_dim(x, shape=(N, M), broadcast_dimensions=(0,))
# expression offset by k; recover the full statement from the upstream repo.
x, shape=(N, M), broadcast_dimensions=(0,)) + k),
google / jax / jax / scipy / stats / uniform.py View on Github
def logpdf(x, loc=0, scale=1):
    """Elementwise log-density of Uniform(loc, loc + scale) at x.

    Returns -log(scale) inside the support [loc, loc + scale] and -inf
    outside it. Arguments are broadcast and promoted to an inexact dtype.

    NOTE(review): the original excerpt was truncated mid-statement; the
    `where(...)` expression below is reconstructed from the public
    jax.scipy.stats.uniform implementation. `where`, `logical_or` and `inf`
    are module-level imports in that file (from jax.numpy).
    """
    x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
    log_probs = lax.neg(lax.log(scale))
    # Density is 0 outside the support (x < loc or x > loc + scale),
    # so its log is -inf there.
    return where(logical_or(lax.gt(x, lax.add(loc, scale)),
                            lax.lt(x, loc)),
                 -inf, log_probs)
tensorflow / tensor2tensor / tensor2tensor / trax / layers / attention.py View on Github
# NOTE(review): excerpt begins mid-function — this is the body of a loop step
# over query chunks in memory-efficient attention; `query`, `ct`, `key`,
# `value`, `q_loop_idx`, `q_loop_stride`, `do_backprop`, the *_accum carries,
# and the forward/vjp helper functions are all defined outside this excerpt.

# Slice the current chunk of queries (length q_loop_stride starting at
# q_loop_idx) along the sequence axis (-2).
query_slice = jax.lax.dynamic_slice_in_dim(
    query, q_loop_idx, q_loop_stride, axis=-2)

if do_backprop:
    # Slice the matching chunk of the incoming cotangent and run the
    # combined forward+VJP computation for this chunk.
    ct_slice = jax.lax.dynamic_slice_in_dim(
        ct, q_loop_idx, q_loop_stride, axis=-2)
    out_slice, partial_ct = forward_and_vjp_slice(
        query_slice, q_loop_idx, key, value, ct_slice)
    # Query cotangents are chunk-local, so this chunk's result is written
    # back in place; key/value cotangents get contributions from every
    # chunk, so they are accumulated by summation.
    query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
        query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
    key_ct_accum = key_ct_accum + partial_ct[1]
    value_ct_accum = value_ct_accum + partial_ct[2]
else:
    # Forward-only path for this chunk.
    out_slice = forward_slice(query_slice, q_loop_idx, key, value)

# Write this chunk's output into the output accumulator and advance the
# chunk index for the next iteration.
out_accum = jax.lax.dynamic_update_slice_in_dim(
    out_accum, out_slice, q_loop_idx, axis=-2)
q_loop_idx = q_loop_idx + q_loop_stride

if do_backprop:
    return (q_loop_idx, out_accum,
            query_ct_accum, key_ct_accum, value_ct_accum)
else:
    return (q_loop_idx, out_accum)
pyro-ppl / numpyro / numpyro / distributions / util.py View on Github
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter the trailing dim of `t` into the lower triangle of matrices.

    The last dimension of `t` holds the entries at or below `diagonal`, read
    row by row; they are placed into an (n, n) matrix (zeros elsewhere),
    batched over any leading dimensions.

    NB: the size formula below only works for diagonal <= 0.
    """
    side = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    flat_size = side * side
    # Flat positions (row * side + col) of the targeted triangle slots.
    positions = jnp.reshape(jnp.arange(flat_size), (side, side))
    positions = positions[jnp.tril_indices(side, diagonal)]
    batch_dims = t.ndim - 1
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=range(batch_dims),
        inserted_window_dims=(batch_dims,),
        scatter_dims_to_operand_dims=(batch_dims,))
    flat = lax.scatter_add(jnp.zeros(t.shape[:-1] + (flat_size,)),
                           jnp.expand_dims(positions, axis=-1), t, dnums)
    return jnp.reshape(flat, flat.shape[:-1] + (side, side))
JuliusKunze / jaxnet / examples / wavenet.py View on Github
# NOTE(review): excerpt begins mid-docstring — the enclosing `def` line and
# the top of the ASCII diagram are missing above this point.
|             |-> (*) -|
input -|-> [filter] -|        |-> 1x1 conv -|
|                                    |-> (+) -> dense output
|------------------------------------|

Where `[gate]` and `[filter]` are causal convolutions with a
non-linear activation at the output
"""
# Gate (sigmoid) and filter (tanh) branches of the gated activation unit.
gated = Sequential(Conv1D(dilation_channels, (filter_width,),
                          dilation=(dilation,)), sigmoid)(inputs)
filtered = Sequential(Conv1D(dilation_channels, (filter_width,),
                             dilation=(dilation,)), np.tanh)(inputs)
p = gated * filtered
# Add the transformed output of the resblock to the sliced input:
# (crop the input along the time axis so it matches the resblock output).
sliced_inputs = lax.dynamic_slice(
    inputs, [0, inputs.shape[1] - out.shape[1], 0],
    [inputs.shape[0], out.shape[1], inputs.shape[2]])
# NOTE(review): `out` is never assigned in this excerpt (perhaps `p` was
# intended), and builtin sum(iterable, start) is being used where elementwise
# addition (`out + sliced_inputs`) looks intended — verify against the full
# upstream source before relying on this example.
new_out = sum(out, sliced_inputs)
skip = Conv1D(residual_channels, (1,), padding='SAME')(skip_slice(p, output_width))
return new_out, skip

## jax

Differentiate, compile, and transform NumPy code.

Apache-2.0