Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_num_threads(self):
    """Return the maximum number of threads available to use.

    Base-class placeholder; the per-library module classes override it.
    """
    ...  # pragma: no cover
@abstractmethod
def set_num_threads(self, num_threads):
    """Set the maximum number of threads to use.

    Abstract: every concrete module class supplies its own implementation.
    """
    ...  # pragma: no cover
@abstractmethod
def _get_extra_info(self):
"""Add additional module specific information"""
pass # pragma: no cover
class _OpenBLASModule(_Module):
    """Module class for OpenBLAS.

    Queries the loaded OpenBLAS library, held in ``self._dynlib``, through
    its ``openblas_*`` entry points.
    """
def get_version(self):
    """Return the OpenBLAS version string, or None if it cannot be determined.

    OpenBLAS only exposes its version (through ``openblas_get_config``)
    since 0.3.4. None is returned when that symbol is missing, when it
    returns NULL/an empty string, or when the config string does not start
    with ``OpenBLAS <version>``.
    """
    get_config = getattr(self._dynlib, "openblas_get_config", None)
    if get_config is None:
        # Symbol absent: OpenBLAS is not loaded or its version is < 0.3.4.
        return None
    get_config.restype = ctypes.c_char_p
    config = get_config()
    if not config:
        # A NULL char* comes back as None with a c_char_p restype; treat an
        # empty string the same way instead of indexing into it.
        return None
    fields = config.split()
    if len(fields) >= 2 and fields[0] == b"OpenBLAS":
        return fields[1].decode("utf-8")
    return None
def get_num_threads(self):
    """Return the maximum number of threads OpenBLAS may use.

    Returns None when the loaded library does not export
    ``openblas_get_num_threads``.
    """
    try:
        getter = self._dynlib.openblas_get_num_threads
    except AttributeError:
        return None
    return getter()
# NOTE(review): orphaned fragment. The ``def`` header and the opening
# ``set_func = getattr(self._dynlib, ...,`` line are missing from this copy,
# so the unmatched ``)`` below does not parse. It appears to be the tail of a
# ``set_num_threads`` method; which library symbol it wrapped is not visible
# here -- TODO restore from the upstream source.
lambda num_threads: None)
return set_func(num_threads)
def _get_extra_info(self):
self.threading_layer = self.get_threading_layer()
def get_threading_layer(self):
    """Return the threading layer of BLIS.

    Returns "openmp" or "pthreads" according to the library's
    ``bli_info_get_enable_openmp`` / ``bli_info_get_enable_pthreads`` flags,
    and "disabled" when neither is enabled. A flag function the library does
    not export is treated as "not enabled" rather than raising
    AttributeError, matching the ``getattr`` fallbacks the other getters use.
    """
    is_openmp = getattr(self._dynlib, "bli_info_get_enable_openmp",
                        lambda: 0)
    if is_openmp():
        return "openmp"
    # Only probe pthreads when OpenMP is off, as before.
    is_pthreads = getattr(self._dynlib, "bli_info_get_enable_pthreads",
                          lambda: 0)
    if is_pthreads():
        return "pthreads"
    return "disabled"
class _MKLModule(_Module):
    """Module class for MKL.

    Queries the loaded Intel MKL library, held in ``self._dynlib``, through
    its ``MKL_*`` / ``mkl_*`` entry points.
    """
def get_version(self):
    """Return the MKL version number.

    The full version text is read with ``mkl_get_version_string`` into a
    200-byte buffer. The token following ``"Version "`` is returned when
    present; otherwise the whole text is returned. Surrounding whitespace is
    stripped in both cases.
    """
    buf_size = 200
    buf = ctypes.create_string_buffer(buf_size)
    self._dynlib.mkl_get_version_string(buf, buf_size)
    full_text = buf.value.decode("utf-8")
    match = re.search(r"Version ([^ ]+) ", full_text)
    version = full_text if match is None else match.group(1)
    return version.strip()
def get_num_threads(self):
    """Return MKL's maximum thread count via ``MKL_Get_Max_Threads``.

    Returns None when the loaded library does not export that symbol.
    """
    try:
        getter = self._dynlib.MKL_Get_Max_Threads
    except AttributeError:
        return None
    return getter()
# NOTE(review): this ``set_num_threads`` has no body in this copy -- the next
# line starts another ``def`` -- so it does not parse. TODO restore the
# implementation (the sibling module classes forward to a library setter).
def set_num_threads(self, num_threads):
def _get_extra_info(self):
self.threading_layer = self.get_threading_layer()
def get_threading_layer(self):
    """Return the threading layer of MKL.

    Returns one of "intel", "sequential", "pgi", "gnu", "tbb", or
    "not specified" (also used when ``MKL_Set_Threading_Layer`` is not
    exported).
    """
    # MKL_Set_Threading_Layer reports the layer currently in effect. Passing
    # -1, which is not a valid layer, reads it without changing anything.
    setter = getattr(self._dynlib, "MKL_Set_Threading_Layer",
                     lambda layer: -1)
    current = setter(-1)
    names = {
        -1: "not specified",
        0: "intel",
        1: "sequential",
        2: "pgi",
        3: "gnu",
        4: "tbb",
    }
    return names[current]
class _OpenMPModule(_Module):
    """Module class for OpenMP.

    Reads and sets the thread count through the ``omp_*`` entry points of the
    loaded runtime, held in ``self._dynlib``.
    """
def get_version(self):
    """Return None: OpenMP offers no programmatic way to query its version."""
    return None
def get_num_threads(self):
    """Return OpenMP's maximum thread count via ``omp_get_max_threads``.

    Returns None when the loaded runtime does not export that symbol.
    """
    try:
        getter = self._dynlib.omp_get_max_threads
    except AttributeError:
        return None
    return getter()
def set_num_threads(self, num_threads):
    """Set the maximum number of threads via ``omp_set_num_threads``.

    Returns whatever the library call returns, or None when the runtime does
    not export that symbol.
    """
    try:
        setter = self._dynlib.omp_set_num_threads
    except AttributeError:
        return None
    return setter(num_threads)
def _get_extra_info(self):
pass
# NOTE(review): orphaned ``return``. Its ``def`` header and the ``set_func``
# assignment it depends on are missing from this copy; TODO restore.
return set_func(num_threads)
def _get_extra_info(self):
self.threading_layer = self.get_threading_layer()
def get_threading_layer(self):
    """Return the threading layer of OpenBLAS.

    Maps the code returned by ``openblas_get_parallel`` to a name:
    2 -> "openmp", 1 -> "pthreads", anything else -> "disabled".
    Returns "unknown" when the library does not export
    ``openblas_get_parallel``, instead of raising AttributeError.
    """
    get_parallel = getattr(self._dynlib, "openblas_get_parallel", None)
    if get_parallel is None:
        # Cannot tell without the symbol; do not claim "disabled".
        return "unknown"
    threading_layer = get_parallel()
    if threading_layer == 2:
        return "openmp"
    elif threading_layer == 1:
        return "pthreads"
    return "disabled"
class _BLISModule(_Module):
    """Module class for BLIS.

    Queries the loaded BLIS library, held in ``self._dynlib``, through its
    ``bli_*`` entry points.
    """
def get_version(self):
    """Return the BLIS version string, or None if it cannot be determined.

    Reads ``bli_info_get_version_str``. None is returned when the library
    does not export that symbol or when it returns NULL, instead of raising
    AttributeError from ``None.decode``.
    """
    get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None)
    if get_version_ is None:
        return None
    get_version_.restype = ctypes.c_char_p
    version = get_version_()
    if version is None:
        # A NULL char* comes back as None with a c_char_p restype.
        return None
    return version.decode("utf-8")
def get_num_threads(self):
    """Return the maximum number of threads BLIS may use.

    Returns None when ``bli_thread_get_num_threads`` is not exported.
    """
    try:
        getter = self._dynlib.bli_thread_get_num_threads
    except AttributeError:
        return None
    n_threads = getter()
    # BLIS is single-threaded by default and then reports -1; report 1
    # instead so the value is consistent with the other libraries.
    if n_threads == -1:
        return 1
    return n_threads
def set_num_threads(self, num_threads):