How to use the jaxlib.cusolver_kernels module in jaxlib

To help you get started, we’ve selected a few jaxlib examples based on popular ways it is used in public projects.


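Source: google/jax on GitHub, jaxlib/cusolver.py. The excerpt below emits XLA custom calls to cusolver_gesvd, the cuSOLVER singular value decomposition kernel, and begins partway through the call for the first branch.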
            _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
            _Shape.array_shape(dtype, (n, n), (1, 0)),
            _Shape.array_shape(dtype, (m, m), (1, 0)),
            _Shape.array_shape(np.dtype(np.int32), (), ()),
            _Shape.array_shape(dtype, (lwork,), (0,)),
        )),
        operand_shapes_with_layout=(
            _Shape.array_shape(dtype, (m, n), (1, 0)),
        ),
        opaque=opaque)
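    # The row-major (1, 0) layouts hand cuSOLVER the transpose of a, so the
    # factors come back swapped: element 2 is vt (n x n), element 3 is u (m x m).
    s = c.GetTupleElement(out, 1)
    vt = c.GetTupleElement(out, 2)
    u = c.GetTupleElement(out, 3)
    info = c.GetTupleElement(out, 4)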
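  else:
    # Column-major (0, 1) layouts: a is passed to cuSOLVER unchanged (its gesvd
    # only supports m >= n). lwork sizes the workspace buffer, tuple element 5.
    lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)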

    out = c.CustomCall(
        b"cusolver_gesvd",
        operands=(a,),
        shape_with_layout=_Shape.tuple_shape((
            _Shape.array_shape(dtype, (m, n), (0, 1)),
            _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
            _Shape.array_shape(dtype, (m, m), (0, 1)),
            _Shape.array_shape(dtype, (n, n), (0, 1)),
            _Shape.array_shape(np.dtype(np.int32), (), ()),
            _Shape.array_shape(dtype, (lwork,), (0,)),
        )),
        operand_shapes_with_layout=(
            _Shape.array_shape(dtype, (m, n), (0, 1)),
        ),
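The excerpt relies on two details that are easy to miss. First, the third argument to _Shape.array_shape is a minor-to-major layout, so (1, 0) means row-major and (0, 1) means column-major. Second, cuSOLVER's gesvd only handles m >= n, so one branch hands cuSOLVER the transpose of a and swaps the roles of u and vt. The sketch below reproduces that branch logic in plain NumPy, with np.linalg.svd standing in for the GPU kernel. It is an illustration under assumptions, not the jaxlib implementation: the helper names layout_view and gesvd_like are hypothetical, nothing here calls jaxlib or cusolver_kernels, and inputs are assumed to be real 2-D arrays.

```python
import numpy as np


def layout_view(buf, shape, minor_to_major):
    """Hypothetical helper: read a flat buffer as a 2-D array using an
    XLA-style minor_to_major layout. (1, 0) is row-major (C order) and
    (0, 1) is column-major (F order)."""
    if tuple(minor_to_major) == (1, 0):
        return buf.reshape(shape, order="C")
    if tuple(minor_to_major) == (0, 1):
        return buf.reshape(shape, order="F")
    raise ValueError("only the 2-D layouts (1, 0) and (0, 1) are handled")


def gesvd_like(a, full_matrices=True):
    """Hypothetical sketch of the excerpt's two branches, with NumPy standing
    in for cuSOLVER (whose gesvd only accepts m >= n). Returns (s, u, vt)."""
    m, n = a.shape
    if m < n:
        # A row-major (m, n) buffer, read column-major as (n, m), is a.T.
        a_t = layout_view(a.ravel(order="C"), (n, m), (0, 1))
        u_t, s, vt_t = np.linalg.svd(a_t, full_matrices=full_matrices)
        # a.T = u_t @ diag(s) @ vt_t, so a = vt_t.T @ diag(s) @ u_t.T:
        # the solver's "vt" becomes u and its "u" becomes vt.
        return s, vt_t.T, u_t.T
    # Tall or square input: factor the column-major copy of a directly.
    a_f = layout_view(a.ravel(order="F"), (m, n), (0, 1))
    u, s, vt = np.linalg.svd(a_f, full_matrices=full_matrices)
    return s, u, vt


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for shape in [(3, 5), (5, 3), (4, 4)]:
        a = rng.standard_normal(shape)
        s, u, vt = gesvd_like(a, full_matrices=False)
        # Reduced factors reconstruct a, and s matches NumPy's own result.
        assert np.allclose(u @ np.diag(s) @ vt, a)
        assert np.allclose(s, np.linalg.svd(a, compute_uv=False))
    print("ok")
```

In application code, the public entry points are jax.numpy.linalg.svd and jax.lax.linalg.svd, which can dispatch to these cuSOLVER kernels on NVIDIA GPUs. The builder calls in the excerpt (c.CustomCall, c.GetTupleElement) come from the XLA client builder API that older jaxlib versions used.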