How to use the jax.util.partial function in jax

To help you get started, we’ve selected a few examples of jax.util.partial, drawn from popular ways it is used in public jax projects.
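
Before diving into the real-world snippets below, here is a minimal usage sketch. jax.util.partial behaves like functools.partial (recent jax releases drop the alias, so the standard-library version is used here), and one of the most common patterns is pre-binding jit's static_argnums so it can be applied as a plain decorator. The scale function is purely illustrative:

from functools import partial  # jax.util.partial behaves the same way

import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(1,))   # pre-bind jit's keyword argument
def scale(x, factor):
  return x * factor

print(scale(jnp.arange(3.0), 2.0))   # factor is treated as a static argument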


github google / jax / jax / test_util.py View on Github external
def rand_default(scale=3):
  randn = npr.RandomState(0).randn
  return partial(_rand_dtype, randn, scale=scale)
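
Here partial pre-binds the NumPy generator and the scale keyword to _rand_dtype, so callers of the returned function only supply what varies. A rough, self-contained sketch of the same pattern follows; the _rand_dtype body and signature below are assumed for illustration, not copied from jax:

from functools import partial
import numpy as np

def _rand_dtype(randn, shape, dtype, scale=1.0):
  # Hypothetical helper: draw standard-normal samples, scale them, and cast.
  return np.asarray(scale * randn(*shape), dtype=dtype)

rand = partial(_rand_dtype, np.random.RandomState(0).randn, scale=3)
x = rand((2, 3), np.float32)   # only shape and dtype remain to be supplied
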
github google / jax / jax / test_util.py View on Github external
tree_all(tree_multimap(_assert_numpy_allclose, xs, ys))


def check_close(xs, ys, atol=None, rtol=None):
  assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol)
  tree_all(tree_multimap(assert_close, xs, ys))


def inner_prod(xs, ys):
  def contract(x, y):
    return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(onp.add, tree_multimap(contract, xs, ys))


add = partial(tree_multimap, lambda x, y: onp.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: onp.subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: onp.conj(x, dtype=_dtype(x)))
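
The add, sub, and conj helpers above use partial to turn a pytree map into a whole-tree arithmetic operator. A minimal sketch of the same idea against the current jax API, where tree_multimap's behaviour is available as jax.tree_util.tree_map with multiple trees:

from functools import partial

import jax.numpy as jnp
from jax import tree_util

# Bind the elementwise function once; the result maps it over whole pytrees.
tree_add = partial(tree_util.tree_map, lambda x, y: x + y)

xs = {'b': jnp.zeros(2), 'w': jnp.ones(3)}
ys = {'b': jnp.ones(2), 'w': jnp.full(3, 2.0)}
print(tree_add(xs, ys))   # {'b': [1. 1.], 'w': [3. 3. 3.]}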

def scalar_mul(xs, a):
  return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs)


def rand_like(rng, x):
  shape = onp.shape(x)
  dtype = _dtype(x)
  randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype)
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return randn() + dtype.type(1.0j) * randn()
  else:
    return randn()
github google / jax / jax / interpreters / batching.py View on Github external
def defvectorized(prim):
  primitive_batchers[prim] = partial(vectorized_batcher, prim)
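
In batching.py, partial bakes the primitive itself into its batching rule before the rule is stored, so every entry in the registry has the same call signature. A simplified, self-contained illustration of that registry pattern; none of the names below are jax internals:

from functools import partial

rule_table = {}

def generic_rule(prim_name, args):
  # Stand-in for vectorized_batcher: the first argument is pre-bound at
  # registration time, the rest arrive when the rule is invoked.
  return f"{prim_name} batched over {args}"

def defvectorized(prim_name):
  rule_table[prim_name] = partial(generic_rule, prim_name)

defvectorized("sin")
print(rule_table["sin"]([1.0, 2.0]))   # sin batched over [1.0, 2.0]
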
github google / jax / jax / numpy / lax_numpy.py View on Github external
@partial(jit, static_argnums=(1, 2))
def _pad(array, pad_width, mode, constant_values):
  array = asarray(array)
  nd = ndim(array)
  pad_width = onp.broadcast_to(onp.asarray(pad_width), (nd, 2))
  if any(pad_width < 0):
    raise ValueError("index can't contain negative values")

  if mode == "constant":
    constant_values = broadcast_to(asarray(constant_values), (nd, 2))
    constant_values = lax.convert_element_type(constant_values, array.dtype)
    for i in xrange(nd):
      widths = [(0, 0, 0)] * nd
      widths[i] = (pad_width[i, 0], 0, 0)
      array = lax.pad(array, constant_values[i, 0], widths)
      widths[i] = (0, pad_width[i, 1], 0)
      array = lax.pad(array, constant_values[i, 1], widths)
github JuliusKunze / jaxnet / examples / pixelcnn.py View on Github external
def PixelCNNPP(nr_resnet=5, nr_filters=160, nr_logistic_mix=10, dropout_p=.5):
    Resnet = partial(GatedResnet, dropout_p=dropout_p)
    ResnetDown = partial(Resnet, Conv=DownShiftedConv)
    ResnetDownRight = partial(Resnet, Conv=DownRightShiftedConv)

    ConvDown = partial(DownShiftedConv, out_chan=nr_filters)
    ConvDownRight = partial(DownRightShiftedConv, out_chan=nr_filters)

    HalveDown = partial(ConvDown, strides=(2, 2))
    HalveDownRight = partial(ConvDownRight, strides=(2, 2))

    DoubleDown = partial(DownShiftedConvTranspose, out_chan=nr_filters, strides=(2, 2))
    DoubleDownRight = partial(DownRightShiftedConvTranspose, out_chan=nr_filters, strides=(2, 2))

    def ResnetUpBlock():
        @parametrized
        def resnet_up_block(us, uls):
            for _ in range(nr_resnet):
                us.append(ResnetDown()(us[-1]))
                uls.append(ResnetDownRight()(uls[-1], us[-1]))

            return us, uls

        return resnet_up_block

    def ResnetDownBlock(nr_resnet):
        @parametrized
        def resnet_down_block(u, ul, us, uls):
            us = us.copy()
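
In this PixelCNN++ model, partial works as a configuration tool: each call fixes a few keyword arguments (the convolution flavour, the output channels, the strides) and produces a more specialized constructor, which can itself be specialized further. A minimal sketch of that layering with a made-up constructor:

from functools import partial

def Conv(out_chan, kernel=(3, 3), strides=(1, 1)):
  # Hypothetical layer constructor; returns a description rather than a layer.
  return f"Conv(out_chan={out_chan}, kernel={kernel}, strides={strides})"

ConvDown = partial(Conv, out_chan=160)          # fix the channel count
HalveDown = partial(ConvDown, strides=(2, 2))   # further fix the strides

print(HalveDown())   # Conv(out_chan=160, kernel=(3, 3), strides=(2, 2))
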
github google / jax / jax / lax.py View on Github external
  if bdim is not None:
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

  operand = _reduce_window_max(
      operand, window_dimensions, window_strides, padding)

  return operand, 0

_reduce_window_max_translation_rule = partial(
    _reduce_window_chooser_translation_rule, max_p, _get_max_identity)
reduce_window_max_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    _reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = _reduce_window_max_batch_rule

_reduce_window_min_translation_rule = partial(
    _reduce_window_chooser_translation_rule, min_p, _get_min_identity)
reduce_window_min_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min',
    _reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p))
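
Both reduce-window rules above come from one generic chooser implementation; partial simply binds max_p or min_p (and the matching identity) to produce each specialization. The same pattern in a self-contained form, with purely illustrative names:

from functools import partial

def reduce_window_chooser(op, identity, xs):
  # Generic reduction parameterized by the operation and its identity element.
  acc = identity
  for x in xs:
    acc = op(acc, x)
  return acc

reduce_max = partial(reduce_window_chooser, max, float("-inf"))
reduce_min = partial(reduce_window_chooser, min, float("inf"))

print(reduce_max([3.0, 1.0, 2.0]), reduce_min([3.0, 1.0, 2.0]))   # 3.0 1.0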


def _select_and_scatter_shape_rule(
    operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  _check_shapelike("select_and_scatter", "window_dimensions", window_dimensions)
  _check_shapelike("select_and_scatter", "window_strides", window_strides)
  if len(window_dimensions) != len(window_strides):
github google / jax / jax / numpy / lax_numpy.py View on Github external
  @partial(jit, static_argnums=(1, 2))
  def _cumulative_reduction(a, axis, dtype):
    if axis is None or isscalar(a):
      a = ravel(a)
      axis = 0

    a_shape = list(shape(a))
    num_dims = len(a_shape)

    if axis < 0:
      axis = axis + num_dims
    if axis < 0 or axis >= num_dims:
      raise ValueError(
          "axis {} is out of bounds for array of dimension {}".format(
              axis, num_dims))

    if squash_nan:
github google / jax / jax / lax / lax_control_flow.py View on Github external
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims

while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.initial_style_translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule


### cond

def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
github google / jax / jax / lax.py View on Github external
def PmapPrimitive(name):
  prim = Primitive(name)
  prim.def_impl(partial(unbound_name_error, name))
  prim.def_abstract_eval(lambda x, *args, **kwargs: x)  # default
  return prim
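
Finally, partial bakes the primitive's name into a placeholder implementation, so calling the primitive outside its intended context raises an error that names the offending primitive. unbound_name_error is not shown in the excerpt, so the sketch below assumes a plausible signature:

from functools import partial

def unbound_name_error(name, *args, **kwargs):
  # Assumed reconstruction: always raise, mentioning the pre-bound name.
  raise NameError(f"primitive '{name}' called outside of a pmap context")

impl = partial(unbound_name_error, "psum")
# impl(x) raises: NameError: primitive 'psum' called outside of a pmap context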