How to use the jaxlib.xla_client.Shape class in jaxlib

To help you get started, we’ve selected a few jaxlib examples based on popular ways xla_client.Shape is used in public projects.


Example from google/jax on GitHub: jaxlib/cusolver.py
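import operator
from functools import reduce  # reduce is not a builtin in Python 3

import numpy as np
from jaxlib import xla_client

# Register the cuBLAS and cuSOLVER kernels as XLA custom-call targets on the
# GPU platform. These kernel modules are only present in jaxlib builds with
# CUDA support, so a missing module is skipped.
try:
  from jaxlib import cublas_kernels
  for _name, _value in cublas_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  pass

try:
  from jaxlib import cusolver_kernels
  for _name, _value in cusolver_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  pass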


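# Short alias for XLA's Shape class, used to describe the operand and result
# shapes (element type, dimensions, layout) of the custom calls.
_Shape = xla_client.Shape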
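
An xla_client.Shape describes an array by its element type, its dimensions, and a minor-to-major layout that lists dimensions from fastest- to slowest-varying. In jaxlib releases of this era, the wrappers in cusolver.py build shapes with `_Shape.array_shape(dtype, dims, layout)`, passing a layout that makes the last two dimensions column-major, as cuSOLVER expects. The sketch below computes such a layout tuple and the element strides it implies in plain Python, so it runs without jaxlib; the helpers `column_major_layout`, `row_major_layout`, and `strides_from_layout` are hypothetical names, not jaxlib API.

```python
from typing import Sequence, Tuple


def column_major_layout(ndim: int) -> Tuple[int, ...]:
  """Hypothetical helper: minor-to-major order for a batch of column-major matrices.

  The last two dimensions are the matrix (rows, cols). Listing the row
  dimension first makes rows vary fastest (column-major order); the batch
  dimensions follow from innermost to outermost.
  """
  if ndim < 2:
    raise ValueError("need at least two dimensions for a matrix")
  num_bd = ndim - 2  # number of leading batch dimensions
  return (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))


def row_major_layout(ndim: int) -> Tuple[int, ...]:
  """Hypothetical helper: row-major order, last dimension varies fastest."""
  return tuple(range(ndim - 1, -1, -1))


def strides_from_layout(dims: Sequence[int],
                        minor_to_major: Sequence[int]) -> Tuple[int, ...]:
  """Element strides implied by a minor-to-major layout."""
  strides = [0] * len(dims)
  stride = 1
  for d in minor_to_major:  # walk from the fastest-varying dimension outward
    strides[d] = stride
    stride *= dims[d]
  return tuple(strides)


dims = (2, 3, 4)  # one batch dimension, then a 3x4 matrix
layout = column_major_layout(len(dims))
print(layout)                             # (1, 2, 0)
print(strides_from_layout(dims, layout))  # (12, 1, 3)
```

For a plain 2-D array the column-major layout is (0, 1), giving element strides (1, rows): the same strides NumPy reports, in units of elements, for a Fortran-ordered array.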


def _real_type(dtype):
  """Returns the real equivalent of 'dtype'."""
  if dtype == np.float32:
    return np.float32
  elif dtype == np.float64:
    return np.float64
  elif dtype == np.complex64:
    return np.float32
  elif dtype == np.complex128:
    return np.float64
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
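
The if/elif chain in _real_type maps each supported dtype to the dtype of its real part. A minimal sketch of the same mapping written as a lookup table, assuming the same four dtypes are the only ones to support; the names `_REAL_TYPES` and `real_type` are illustrative, not taken from jaxlib.

```python
import numpy as np

# Hypothetical table: supported dtype -> dtype of its real component.
_REAL_TYPES = {
    np.dtype(np.float32): np.float32,
    np.dtype(np.float64): np.float64,
    np.dtype(np.complex64): np.float32,
    np.dtype(np.complex128): np.float64,
}


def real_type(dtype):
  """Returns the real equivalent of `dtype`, mirroring _real_type."""
  try:
    return _REAL_TYPES[np.dtype(dtype)]
  except KeyError:
    raise NotImplementedError("Unsupported dtype {}".format(dtype)) from None


print(real_type(np.complex64))  # <class 'numpy.float32'>
```

Normalizing the key with np.dtype(...) means callers may pass a scalar type, a np.dtype instance, or a string such as "complex64", which is slightly more permissive than the original comparison chain.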

# Product of a sequence of integers; 1 for an empty sequence.
_prod = lambda xs: reduce(operator.mul, xs, 1)