How to use the einops._backends.AbstractBackend.__subclasses__ method in einops

To help you get started, we’ve selected a few einops examples, based on popular ways this method is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: arogozhnikov / einops / tests / test_other.py — View on GitHub
def test_backends_installed():
    """
    Check that every einops backend can be instantiated.

    This test will fail if some of the backends are not installed or can't be
    imported. Other tests will just work and only test installed backends.

    Backends are discovered by walking the whole ``AbstractBackend`` subclass
    tree (not just direct subclasses), mirroring the recursive search done by
    ``get_backend``. Each failure is reported together with the backend's
    ``framework_name`` so it is clear which framework is broken.
    """
    from . import skip_cupy

    # Collect subclasses recursively; ``seen`` guards against visiting a class
    # twice when it is reachable through several parents (diamond inheritance).
    backend_types = []
    seen = set()
    pending = AbstractBackend.__subclasses__()
    while pending:
        backend_type = pending.pop()
        if backend_type in seen:
            continue
        seen.add(backend_type)
        pending.extend(backend_type.__subclasses__())
        backend_types.append(backend_type)

    errors = []
    for backend_type in backend_types:
        if skip_cupy and backend_type.framework_name == 'cupy':
            continue
        try:
            # instantiate
            backend_type()
        except Exception as e:
            # Keep the framework name next to the exception for a readable report.
            errors.append((backend_type.framework_name, e))
    assert not errors, errors
GitHub: arogozhnikov / einops / einops / _backends.py — View on GitHub
def get_backend(tensor):
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    for framework_name, backend in _backends.items():
        if backend.is_appropriate_type(tensor):
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    backends = AbstractBackend.__subclasses__()
    while backends:
        backend = backends.pop()
        backends += backend.__subclasses__()
        backend_subclasses.append(backend)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        if BackendSubclass.framework_name not in _backends:
            # check that module was already imported. Otherwise it can't be imported
            if BackendSubclass.framework_name in sys.modules:
                if _debug_importing:
                    print('Imported backend for ', BackendSubclass.framework_name)
                backend = BackendSubclass()
                _backends[backend.framework_name] = backend
                if backend.is_appropriate_type(tensor):