How to use the jaxlib.cusolver_kernels.registrations function in jaxlib

To help you get started, we’ve selected a few jaxlib examples, based on popular ways it is used in public projects.


Source: google/jax, jaxlib/cusolver.py (GitHub)
import numpy as np
from six.moves import reduce

from jaxlib import xla_client

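# Register the cuBLAS custom-call targets with XLA, if this jaxlib build ships them.
try:
  from jaxlib import cublas_kernels
  for _name, _value in cublas_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  pass  # Build without CUDA support: nothing to register.

# Same pattern for the cuSOLVER kernels: registrations() yields (name, target) pairs.
try:
  from jaxlib import cusolver_kernels
  for _name, _value in cusolver_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  pass  # Build without CUDA support: nothing to register.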


_Shape = xla_client.Shape


def _real_type(dtype):
  """Returns the real equivalent of 'dtype'."""
  if dtype == np.float32:
    return np.float32
  elif dtype == np.float64:
    return np.float64
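  elif dtype == np.complex64:
    return np.float32
  elif dtype == np.complex128:
    return np.float64
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))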
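`_real_type` maps each supported dtype to its real counterpart: real types map to themselves, and complex types map to the type of their real and imaginary components. A table-driven variant is sketched below as an alternative, not as jaxlib's implementation; the names `_REAL_TYPES` and `real_type` are hypothetical. Normalizing through `np.dtype` lets it accept type objects, dtype instances, or strings such as `"complex128"`.

```python
import numpy as np

# Hypothetical lookup table: supported dtype -> real dtype (complex -> component type).
_REAL_TYPES = {
    np.dtype(np.float32): np.dtype(np.float32),
    np.dtype(np.float64): np.dtype(np.float64),
    np.dtype(np.complex64): np.dtype(np.float32),
    np.dtype(np.complex128): np.dtype(np.float64),
}


def real_type(dtype):
    """Return the real dtype equivalent to `dtype`; raise for unsupported dtypes."""
    try:
        return _REAL_TYPES[np.dtype(dtype)]
    except KeyError:
        raise NotImplementedError(f"Unsupported dtype {dtype}") from None


print(real_type(np.complex64))  # float32
print(real_type("complex128"))  # float64
```

A dict keeps the supported set in one place, so adding a dtype means adding one entry rather than another `elif` branch.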