Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_backends_installed():
    """
    Verify that every direct subclass of AbstractBackend can be instantiated.

    Fails if any backend is not installed or can't be imported; the collected
    exceptions are reported as the assertion message. Other tests will just
    work and only test installed backends.

    The backend whose ``framework_name`` is 'cupy' is skipped when the
    package-level ``skip_cupy`` flag is truthy.
    """
    from . import skip_cupy

    failures = []
    for backend_cls in AbstractBackend.__subclasses__():
        # Only read framework_name when cupy skipping is requested.
        if not skip_cupy or backend_cls.framework_name != 'cupy':
            try:
                # Instantiating the backend is what triggers the framework import.
                backend_cls()
            except Exception as exc:
                # Keep going so that all broken backends are reported at once.
                failures.append(exc)

    assert not failures, failures
def get_backend(tensor):
"""
Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
If needed, imports package and creates backend
"""
for framework_name, backend in _backends.items():
if backend.is_appropriate_type(tensor):
return backend
# Find backend subclasses recursively
backend_subclasses = []
backends = AbstractBackend.__subclasses__()
while backends:
backend = backends.pop()
backends += backend.__subclasses__()
backend_subclasses.append(backend)
for BackendSubclass in backend_subclasses:
if _debug_importing:
print('Testing for subclass of ', BackendSubclass)
if BackendSubclass.framework_name not in _backends:
# check that module was already imported. Otherwise it can't be imported
if BackendSubclass.framework_name in sys.modules:
if _debug_importing:
print('Imported backend for ', BackendSubclass.framework_name)
backend = BackendSubclass()
_backends[backend.framework_name] = backend
if backend.is_appropriate_type(tensor):