How to use the tensorly.backend.core.Backend class in tensorly

To help you get started, we’ve selected a few tensorly examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tensorly / tensorly / tensorly / backend / pytorch_backend.py View on Github external
message = ('Impossible to import PyTorch.\n'
               'To use TensorLy with the PyTorch backend, '
               'you must first install PyTorch!')
    raise ImportError(message) from error

# Refuse to continue when the installed PyTorch is older than 0.4.0.
if LooseVersion(torch.__version__) < LooseVersion('0.4.0'):
    # The adjacent literals are concatenated into one message, so the first
    # one ends with a space to keep the two sentences apart.
    raise ImportError('You are using version=%r of PyTorch. '
                      'Please update to "0.4.0" or higher.'
                      % torch.__version__)

import numpy as np

from .core import Backend


class PyTorchBackend(Backend):
    backend_name = 'pytorch'

    @staticmethod
    def context(tensor):
        """Return a dict holding the dtype, device and requires_grad of `tensor`.

        The keys are 'dtype', 'device' and 'requires_grad', in that order, and
        each value is read from the attribute of the same name on `tensor`.
        """
        attribute_names = ('dtype', 'device', 'requires_grad')
        return {name: getattr(tensor, name) for name in attribute_names}

    @staticmethod
    def tensor(data, dtype=torch.float32, device='cpu', requires_grad=False):
        """Create a torch tensor from `data`.

        Parameters
        ----------
        data : object accepted by ``torch.tensor``
            A NumPy array is copied before being handed to ``torch.tensor``;
            any other input is passed through unchanged.
        dtype, device, requires_grad
            Forwarded as keyword arguments to ``torch.tensor``.

        Returns
        -------
        The result of ``torch.tensor``.
        """
        source = data.copy() if isinstance(data, np.ndarray) else data
        options = {'dtype': dtype,
                   'device': device,
                   'requires_grad': requires_grad}
        return torch.tensor(source, **options)

    @staticmethod
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Backend-agnostic entry points: each name below is bound to the wrapper that
# dispatch() returns for the Backend method of the same name.
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)


# Initialise the backend to the default one
initialize_backend()
# Pass this module's name and the two lookup helpers to
# override_module_dispatch (presumably so that attribute access and dir() on
# the module follow the active backend -- TODO confirm against its definition).
override_module_dispatch(__name__, 
                         _get_backend_method,
                         _get_backend_dir)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Backend-agnostic entry points: each name below is bound to the wrapper that
# dispatch() returns for the Backend method of the same name.
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)


# Initialise the backend to the default one
initialize_backend()
# Pass this module's name and the two lookup helpers to
# override_module_dispatch (presumably so that attribute access and dir() on
# the module follow the active backend -- TODO confirm against its definition).
override_module_dispatch(__name__, 
                         _get_backend_method,
                         _get_backend_dir)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
if 'self' in sig.parameters:
        parameters = [v for k, v in sig.parameters.items() if k != 'self']
        sig = sig.replace(parameters=parameters)
    inner.__signature__ = sig

    return inner

# Generic methods, exposed as part of the public API
# Each name below is bound to the wrapper that dispatch() returns for the
# Backend method of the same name.
context = dispatch(Backend.context)
tensor = dispatch(Backend.tensor)
is_tensor = dispatch(Backend.is_tensor)
shape = dispatch(Backend.shape)
ndim = dispatch(Backend.ndim)
to_numpy = dispatch(Backend.to_numpy)
copy = dispatch(Backend.copy)
concatenate = dispatch(Backend.concatenate)
stack = dispatch(Backend.stack)
reshape = dispatch(Backend.reshape)
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min and all rebind builtin names at module scope, so the
# builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Backend-agnostic entry points: each name below is bound to the wrapper that
# dispatch() returns for the Backend method of the same name.
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)


# Initialise the backend to the default one
initialize_backend()
override_module_dispatch(__name__,
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Generic methods, exposed as part of the public API
# Each name below is bound to the wrapper that dispatch() returns for the
# Backend method of the same name.
context = dispatch(Backend.context)
tensor = dispatch(Backend.tensor)
is_tensor = dispatch(Backend.is_tensor)
shape = dispatch(Backend.shape)
ndim = dispatch(Backend.ndim)
to_numpy = dispatch(Backend.to_numpy)
copy = dispatch(Backend.copy)
concatenate = dispatch(Backend.concatenate)
stack = dispatch(Backend.stack)
reshape = dispatch(Backend.reshape)
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
return inner

# Generic methods, exposed as part of the public API
# Each name below is bound to the wrapper that dispatch() returns for the
# Backend method of the same name.
context = dispatch(Backend.context)
tensor = dispatch(Backend.tensor)
is_tensor = dispatch(Backend.is_tensor)
shape = dispatch(Backend.shape)
ndim = dispatch(Backend.ndim)
to_numpy = dispatch(Backend.to_numpy)
copy = dispatch(Backend.copy)
concatenate = dispatch(Backend.concatenate)
stack = dispatch(Backend.stack)
reshape = dispatch(Backend.reshape)
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
# NOTE: max, min, all and sum rebind builtin names at module scope, so the
# builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Backend-agnostic entry points: each name below is bound to the wrapper that
# dispatch() returns for the Backend method of the same name.
clip = dispatch(Backend.clip)
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)


# Initialise the backend to the default one
initialize_backend()
# Pass this module's name and the two lookup helpers to
# override_module_dispatch (presumably so that attribute access and dir() on
# the module follow the active backend -- TODO confirm against its definition).
override_module_dispatch(__name__, 
                         _get_backend_method,
                         _get_backend_dir)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
try:
            setattr(inner, attr, getattr(method, attr))
        except AttributeError:
            pass

    sig = inspect.signature(method)
    if 'self' in sig.parameters:
        parameters = [v for k, v in sig.parameters.items() if k != 'self']
        sig = sig.replace(parameters=parameters)
    inner.__signature__ = sig

    return inner

# Generic methods, exposed as part of the public API
# Each name below is bound to the wrapper that dispatch() returns for the
# Backend method of the same name; that wrapper carries the method's
# signature with any `self` parameter removed.
context = dispatch(Backend.context)
tensor = dispatch(Backend.tensor)
is_tensor = dispatch(Backend.is_tensor)
shape = dispatch(Backend.shape)
ndim = dispatch(Backend.ndim)
to_numpy = dispatch(Backend.to_numpy)
copy = dispatch(Backend.copy)
concatenate = dispatch(Backend.concatenate)
stack = dispatch(Backend.stack)
reshape = dispatch(Backend.reshape)
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
github tensorly / tensorly / tensorly / backend / __init__.py View on Github external
# Backend-agnostic entry points: each name below is bound to the wrapper that
# dispatch() returns for the Backend method of the same name.
# NOTE: max, min, all, sum and abs rebind builtin names at module scope, so
# the builtins are hidden from any code in this module that runs afterwards.
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)


# Initialise the backend to the default one
initialize_backend()
# Pass this module's name and the two lookup helpers to
# override_module_dispatch (presumably so that attribute access and dir() on
# the module follow the active backend -- TODO confirm against its definition).
override_module_dispatch(__name__, 
                         _get_backend_method,
                         _get_backend_dir)