How to use the jaxlib.cusolver_kernels.build_syevd_descriptor function in jaxlib

To help you get started, we’ve selected a few jaxlib examples based on popular ways this function is used in public projects.


Source: google/jax, jaxlib/cusolver.py (GitHub)
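  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n  # eigendecomposition requires square matrices
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # XLA minor-to-major layout: the row dimension is most minor, so each matrix
  # is stored column-major (as cuSOLVER expects); batch dimensions are major.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

  # Small matrices (n <= 32) use the Jacobi-based solver (syevj); larger ones
  # use the divide-and-conquer solver (syevd). Each builder returns the
  # workspace length and an opaque descriptor passed to the custom call.
  if n <= 32:
    kernel = b"cusolver_syevj"
    lwork, opaque = cusolver_kernels.build_syevj_descriptor(
        np.dtype(dtype), lower, batch, n)
  else:
    kernel = b"cusolver_syevd"
    lwork, opaque = cusolver_kernels.build_syevd_descriptor(
        np.dtype(dtype), lower, batch, n)
  eigvals_type = _real_type(dtype)  # eigenvalues of a Hermitian matrix are real

  # The custom call returns a 4-tuple: eigenvectors (same shape and layout as
  # the input), eigenvalues, a per-matrix int32 status, and a workspace buffer.
  out = c.CustomCall(
      kernel,
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(dtype, dims, layout),
          _Shape.array_shape(
              np.dtype(eigvals_type), batch_dims + (n,),
              tuple(range(num_bd, -1, -1))),
          _Shape.array_shape(
              np.dtype(np.int32), batch_dims,
              tuple(range(num_bd - 1, -1, -1))),
          _Shape.array_shape(dtype, (lwork,), (0,))
      )),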
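The shape and layout bookkeeping in the excerpt can be checked without a GPU. The following is a minimal pure-Python sketch that mirrors that logic: it picks the kernel name with the same n <= 32 threshold and lists the (shape, minor_to_major) pair of each output in the tuple. The helpers describe_syevd_call, minor_to_major_layout and strides_from_layout are hypothetical illustrations, not jaxlib API, and lwork is a made-up stand-in for the workspace length the real descriptor builders return.

```python
import math


def minor_to_major_layout(num_bd):
  """Layout used in the excerpt for the input and eigenvector buffers.

  XLA layouts list dimensions from most minor (fastest varying) to most major.
  The row dimension (index num_bd) comes first, so each matrix is stored
  column-major; the batch dimensions are most major, in row-major order.
  """
  return (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))


def strides_from_layout(dims, minor_to_major, itemsize):
  """Byte strides implied by a dense minor-to-major layout (no padding)."""
  strides = [0] * len(dims)
  step = itemsize
  for d in minor_to_major:
    strides[d] = step
    step *= dims[d]
  return tuple(strides)


def describe_syevd_call(dims, lwork):
  """Hypothetical helper mirroring the excerpt's shape logic.

  Returns (kernel_name, batch, outputs), where outputs holds the
  (shape, minor_to_major) pairs of the four-element result tuple.
  """
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, "eigendecomposition needs square matrices"
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = math.prod(batch_dims)  # 1 when there are no batch dimensions
  kernel = b"cusolver_syevj" if n <= 32 else b"cusolver_syevd"
  outputs = [
      (tuple(dims), minor_to_major_layout(num_bd)),       # eigenvectors
      (batch_dims + (n,), tuple(range(num_bd, -1, -1))),  # eigenvalues
      (batch_dims, tuple(range(num_bd - 1, -1, -1))),     # int32 status
      ((lwork,), (0,)),                                   # workspace
  ]
  return kernel, batch, outputs
```

For dims (2, 3, 4, 4) the helper picks cusolver_syevj with a batch of 6, and strides_from_layout((2, 3, 4, 4), (2, 3, 1, 0), 4) gives (192, 64, 4, 16): stepping along the row index moves one 4-byte element, which is exactly what column-major matrix storage means.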
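To see what the syevj/syevd kernels compute, independent of the GPU plumbing, the sketch below uses NumPy's np.linalg.eigh as a CPU reference. It assumes that the lower flag selects which triangle of the input is read, analogous to eigh's UPLO argument; reference_eigh is a hypothetical helper, not part of jaxlib. It also shows why the excerpt asks _real_type for the eigenvalue dtype: a complex64 Hermitian input yields float32 eigenvalues.

```python
import numpy as np


def reference_eigh(a, lower=True):
  """CPU reference for the math the cuSOLVER kernels perform (not the GPU path).

  a: array of shape batch_dims + (n, n), symmetric (real) or Hermitian (complex).
  lower: read the lower triangle if True, else the upper one (assumed analogous
         to the `lower` flag passed to the descriptor builders).
  Returns (eigenvectors, eigenvalues) with shapes batch_dims + (n, n) and
  batch_dims + (n,); the eigenvalues have the real dtype matching `a`.
  """
  w, v = np.linalg.eigh(a, UPLO="L" if lower else "U")
  return v, w


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 5, 5)) + 1j * rng.standard_normal((3, 5, 5))
x = x.astype(np.complex64)
a = x + np.conj(np.swapaxes(x, -1, -2))  # make each matrix in the batch Hermitian
v, w = reference_eigh(a)
print(v.shape, w.shape, w.dtype)
```

In JAX, user code normally reaches these kernels through jax.numpy.linalg.eigh (or jax.lax.linalg.eigh) on a GPU backend rather than by calling the descriptor builders directly.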