How to use the jaxlib.cublas_kernels.build_trsm_batched_descriptor function in jaxlib

To help you get started, we’ve selected a few examples of how this jaxlib function is used in popular public projects.


Example from google / jax on GitHub (jaxlib):
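  b_shape = c.GetShape(b)
  dtype = b_shape.element_type()
  dims = b_shape.dimensions()
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)  # number of matrices in the batch
  k = m if left_side else n  # A is (k, k): it acts on B from the left or the right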

  a_shape = c.GetShape(a)
  if (batch_dims + (k, k) != a_shape.dimensions() or
      a_shape.element_type() != dtype):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
      a_shape, b_shape))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

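  # Returns the workspace size lwork and an opaque descriptor that encodes the
  # solve parameters (dtype, batch size, matrix sizes and flags).
  lwork, opaque = cublas_kernels.build_trsm_batched_descriptor(
    np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
  # XLA minor-to-major layout: rows vary fastest, then columns, then the batch
  # dimensions, so each matrix is stored column-major as cuBLAS expects.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))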
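  out = c.CustomCall(
      operands=(a, b),
      # Result tuple: the solution (same shape as b) and two int8 workspaces
      # of lwork bytes each.
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(dtype, b_shape.dimensions(), layout),
          _Shape.array_shape(np.dtype(np.int8), (lwork,), (0,)),
          _Shape.array_shape(np.dtype(np.int8), (lwork,), (0,)))),
      operand_shapes_with_layout=(
          _Shape.array_shape(dtype, a_shape.dimensions(), layout),
          _Shape.array_shape(dtype, b_shape.dimensions(), layout),
      ),
      opaque=opaque)
  # Element 0 of the result tuple is the solution; the workspaces are dropped.
  return c.GetTupleElement(out, 0)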
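The excerpt only wires up an XLA custom call, so it never shows what the solve actually computes. As a minimal sketch, assuming the standard BLAS trsm convention (op(A) X = B when left_side is true, X op(A) = B otherwise, with only the lower or upper triangle of A referenced, and diag meaning a unit diagonal), the NumPy function below is a hypothetical CPU reference; reference_trsm_batched is not part of jaxlib. It repeats the excerpt's argument checks and then solves each batch element with np.linalg.solve, which is simple but ignores the triangular structure that cuBLAS exploits.

```python
import numpy as np


def reference_trsm_batched(a, b, left_side=False, lower=False, trans_a=False,
                           conj_a=False, diag=False):
    """Hypothetical NumPy reference for a batched triangular solve.

    Solves op(A) @ X = B (left_side) or X @ op(A) = B for each batch element,
    where op(A) is A, A^T (trans_a) or A^H (trans_a and conj_a).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    *batch_dims, m, n = b.shape
    k = m if left_side else n
    # Same checks as the excerpt: A must have shape batch_dims + (k, k).
    if a.shape != tuple(batch_dims) + (k, k):
        raise ValueError(f"Argument mismatch for trsm, got {a.shape} and {b.shape}")
    if conj_a and not trans_a:
        raise NotImplementedError("Conjugation without transposition not supported")

    # Keep only the referenced triangle (np.tril/np.triu act on the last two axes).
    tri = np.tril(a) if lower else np.triu(a)
    if diag:
        # Assumption: diag=True means a unit diagonal, so A's diagonal is ignored.
        idx = np.arange(k)
        tri[..., idx, idx] = 1

    op_a = tri
    if trans_a:
        op_a = np.swapaxes(op_a, -1, -2)
        if conj_a:
            op_a = op_a.conj()

    if left_side:
        return np.linalg.solve(op_a, b)  # op(A) @ X = B
    # X @ op(A) = B is equivalent to op(A)^T @ X^T = B^T.
    x_t = np.linalg.solve(np.swapaxes(op_a, -1, -2), np.swapaxes(b, -1, -2))
    return np.swapaxes(x_t, -1, -2)


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.5, size=(2, 3, 3))  # batch of two 3x3 matrices
b = rng.standard_normal((2, 3, 4))          # m=3, n=4, so A acts from the left
x = reference_trsm_batched(a, b, left_side=True, lower=True)
assert np.allclose(np.tril(a) @ x, b)
```

In current JAX the same set of options is exposed publicly by jax.lax.linalg.triangular_solve (left_side, lower, transpose_a, conjugate_a, unit_diagonal), which is a more convenient way to compare results on GPU than calling the descriptor builder directly.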
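The layout tuple in the excerpt is an XLA minor-to-major ordering: its first entry is the dimension whose index varies fastest in memory. The helpers below (trsm_layout and strides_from_minor_to_major are hypothetical names; this is plain Python, with NumPy used only for comparison) rebuild that tuple and turn it into element strides, which makes it easy to check that every matrix in the batch is stored column-major, the storage order cuBLAS uses.

```python
import numpy as np


def trsm_layout(num_bd):
    """Minor-to-major layout from the excerpt for shape batch_dims + (m, n)."""
    # Rows (index num_bd) vary fastest, then columns (num_bd + 1), then the
    # batch dimensions from last to first.
    return (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))


def strides_from_minor_to_major(shape, minor_to_major):
    """Element strides implied by an XLA-style minor-to-major layout."""
    strides = [0] * len(shape)
    step = 1
    for dim in minor_to_major:
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


print(trsm_layout(0))  # (0, 1)
print(trsm_layout(2))  # (2, 3, 1, 0)

# One batch dimension of 4 matrices, each m=3 by n=5.
shape = (4, 3, 5)
strides = strides_from_minor_to_major(shape, trsm_layout(1))
print(strides)  # (15, 1, 3)

# NumPy equivalent: a stack of column-major (Fortran-ordered) 3x5 matrices.
stack = np.zeros((4, 5, 3)).swapaxes(-1, -2)
assert stack.shape == shape
assert strides == tuple(s // stack.itemsize for s in stack.strides)
```

The row dimension comes first in the tuple because cuBLAS expects each matrix with its rows contiguous in memory (leading dimension m), while the batch dimensions are outermost so that the matrices sit one after another.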