How to use the jaxlib.xla_client.register_custom_call_target function in jaxlib

To help you get started, we’ve selected a few jaxlib examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: google/jax — jaxlib/cusolver.py (View on GitHub, external link)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import operator

import numpy as np
from six.moves import reduce

from jaxlib import xla_client

def _register_gpu_kernels(kernel_module):
  """Registers every custom-call target exported by `kernel_module` on GPU.

  Args:
    kernel_module: an extension module exposing `registrations()`, which
      returns a mapping from custom-call target name to the value that is
      handed to `xla_client.register_custom_call_target`.
  """
  for name, target in kernel_module.registrations().items():
    xla_client.register_custom_call_target(name, target, platform="gpu")


# The GPU kernel extension modules are optional: if one cannot be imported
# (presumably a jaxlib build without CUDA support -- confirm), its targets are
# simply not registered. Only the import itself is guarded, so an ImportError
# raised while *registering* targets is not silently mistaken for a missing
# module.
try:
  from jaxlib import cublas_kernels
except ImportError:
  pass
else:
  _register_gpu_kernels(cublas_kernels)

try:
  from jaxlib import cusolver_kernels
except ImportError:
  pass
else:
  _register_gpu_kernels(cusolver_kernels)


# Module-private short alias for `xla_client.Shape` (presumably used by the
# op builders later in this file to construct XLA shapes -- confirm).
_Shape = xla_client.Shape


def _real_type(dtype):
  """Returns the real equivalent of 'dtype'."""