How to use the dragon.vm.torch.ops.modules.base.BaseModule function in dragon

To help you get started, we’ve selected a few dragon examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
# along with the software. If not, See,
#
#      
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dragon.vm.torch.autograd import no_grad
from dragon.vm.torch.tensor import _ReferenceTensor
from dragon.vm.torch.ops.modules.base import BaseModule


class Indexing(BaseModule):
    """This module imports the *CropOp* from backend.

    Arbitrary lengths of starts and sizes will be taken,
    and the resulting memory is deep copied.

    """
    def __init__(self, key, dev, **kwargs):
        """Create an ``Indexing`` module.

        ``nstarts``/``nsizes`` give the number of start/size
        descriptors to feed at run time (default: 0).
        """
        super(Indexing, self).__init__(key, dev, **kwargs)
        for attr in ('nstarts', 'nsizes'):
            setattr(self, attr, kwargs.get(attr, 0))
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'Crop',
            'arguments': {
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / control_flow.py View on Github external
def register_op(self):
    """Describe the backend *Compare* op for this module.

    ``to_uint8`` requests a uint8 result, the usual encoding for
    boolean comparison outputs.
    """
    arguments = {
        'operation': self.operation,
        'to_uint8': True,
    }
    self.op_meta = {'op_type': 'Compare', 'arguments': arguments}

    def forward(self, x1, x2, y):
        """Run the comparison on ``x1`` and ``x2``.

        Writes into ``y`` when it is truthy; otherwise a fresh
        output is registered.
        """
        inputs = [x1, x2]
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)


class Assign(BaseModule):
    """This module imports the *AssignOp* from backend.

    Arbitrary lengths of starts and sizes will be taken.

    """
    def __init__(self, key, dev, **kwargs):
        """Create an ``Assign`` module.

        ``nstarts``/``nsizes`` give the number of start/size
        descriptors to feed at run time (default: 0).
        """
        super(Assign, self).__init__(key, dev, **kwargs)
        for attr in ('nstarts', 'nsizes'):
            setattr(self, attr, kwargs.get(attr, 0))
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'Assign',
            'arguments': {
                'starts_desc': [
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / shape.py View on Github external
'arguments': {
                'dtype': self.dtype,
                'value': float(self.value),
                'dims_desc': [d for d in self.shape] if self.n_dim > 0 else None,
            }
        }

    def forward(self, x, shape):
        """Run the op in place: ``x`` doubles as the output.

        When ``shape`` is given, each runtime dimension is fed into
        the matching pre-registered ``self.shape`` argument before
        execution. Note the op takes no inputs — only outputs.
        """
        outputs = [x]
        self.unify_devices(outputs)
        if shape is not None:
            for axis in range(len(shape)):
                self.set_argument_i(self.shape[axis], shape[axis])
        return self.run([], outputs)


class Reshape(BaseModule):
    def __init__(self, key, ctx, **kwargs):
        """Create a ``Reshape`` module.

        ``n_dim`` is the number of mutable dimension descriptors to
        register (default: 0).
        """
        super(Reshape, self).__init__(key, ctx, **kwargs)
        self.n_dim = kwargs.get('n_dim', 0)
        self.register_arguments()
        self.register_op()

    def register_arguments(self):
         self.dims = [self.register_argument('dims[{}]'.format(i))
                for i in range(self.n_dim)]

    def register_op(self):
        self.op_meta = {
            'op_type': 'Reshape',
            'n_inputs': 1, 'n_outputs': 1,
            'arguments': {
                'dims_desc': [d for d in self.dims]
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
self.op_meta = {
            'op_type': 'Reduce',
            'arguments': {
                'operation': self.operation,
                'axes': [self.dim] if self.dim is not None else None,
                'keep_dims': self.keepdim,
            },
        }

    def forward(self, x, y):
        """Run the reduction on ``x``.

        Writes into ``y`` when it is truthy; otherwise a fresh
        output is registered.
        """
        inputs = [x]
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)


class ArgReduce(BaseModule):
    def __init__(self, key, dev, **kwargs):
        """Create an ``ArgReduce`` module.

        Defaults: operation ``'ARGMAX'``, axis ``None`` (flattened),
        ``keepdim=True``, ``topk=1``.
        """
        super(ArgReduce, self).__init__(key, dev, **kwargs)
        for name, default in (('operation', 'ARGMAX'),
                              ('axis', None),
                              ('keepdim', True),
                              ('topk', 1)):
            setattr(self, name, kwargs.get(name, default))
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'ArgReduce',
            'arguments': {
                'operation': self.operation
                    if 'ARG' in self.operation \
                    else 'ARG' + self.operation,
                'axis': self.axis if self.axis else 2147483647,
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
def register_op(self):
    """Describe the backend *ChannelShuffle* op for this module."""
    arguments = {
        'axis': self.axis,
        'group': self.group,
    }
    self.op_meta = {'op_type': 'ChannelShuffle', 'arguments': arguments}

    def forward(self, x, y):
        """Run the op on ``x``.

        Writes into ``y`` when it is truthy; otherwise a fresh
        output is registered.
        """
        inputs = [x]
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)


class Repeat(BaseModule):
    def __init__(self, key, dev, **kwargs):
        """Create a ``Repeat`` module.

        ``ntimes`` is the number of per-axis multiples descriptors
        (default: 0).
        """
        super(Repeat, self).__init__(key, dev, **kwargs)
        self.ntimes = kwargs.get('ntimes', 0)
        self.register_op()

    def register_op(self):
        """Describe the backend *Tile* op for this module.

        Each multiple is a handle-scoped descriptor resolved at run
        time (``${HANDLE}`` is substituted by the backend).
        """
        multiples = [
            '${{HANDLE}}/multiples[{}]'.format(i)
            for i in range(self.ntimes)
        ]
        self.op_meta = {
            'op_type': 'Tile',
            'arguments': {'multiples_desc': multiples},
        }
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
else:
            if y:
                if not isinstance(y, (tuple, list)):
                    raise TypeError('Excepted outputs as a tuple or list, got {}.'.format(type(y)))
                if len(y) != 2:
                    raise ValueError('Excepted 2 outputs, got {}.'.format(len(y)))
                outputs = [y[1], y[0]]
            else: outputs = [self.register_output(), self.register_output()]
            returns = self.run(inputs, outputs)
            # Return values only
            if self.axis is None: return returns[1]
            # Return values and indices
            return returns[1], returns[0]


class Reshape(BaseModule):
    def __init__(self, key, dev, **kwargs):
        """Create a ``Reshape`` module.

        ``ndim`` is the number of per-axis dims descriptors
        (default: 0).
        """
        super(Reshape, self).__init__(key, dev, **kwargs)
        self.ndim = kwargs.get('ndim', 0)
        self.register_op()

    def register_op(self):
        """Describe the backend *Reshape* op for this module.

        Each dim is a handle-scoped descriptor resolved at run time
        (``${HANDLE}`` is substituted by the backend).
        """
        dims = [
            '${{HANDLE}}/dims[{}]'.format(i)
            for i in range(self.ndim)
        ]
        self.op_meta = {
            'op_type': 'Reshape',
            'arguments': {'dims_desc': dims},
        }
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
},
        }

    def update_args(self, A, starts, sizes):
        """Feed runtime ``starts``/``sizes`` into the handle ``A``.

        One int64 argument per element; ``sizes`` is indexed in
        lockstep with ``starts``.
        """
        for i in range(len(starts)):
            self.set_arg_i64('{}/starts[{}]'.format(A, i), starts[i])
            self.set_arg_i64('{}/sizes[{}]'.format(A, i), sizes[i])

    def forward(self, x, starts, sizes):
        """Run the op on ``x`` with runtime ``starts``/``sizes``.

        The descriptors are injected via a callback so they are bound
        to the concrete handle chosen at execution time.
        """
        inputs = [x]
        self.unify_devices(inputs)
        outputs = [self.register_output()]

        def callback(A):
            self.update_args(A, starts, sizes)

        return self.run(inputs, outputs, callback=callback)


class Concat(BaseModule):
    """This module imports the *ConcatOp* from backend.

    Concatenate the inputs along the given axis.

    """
    def __init__(self, key, dev, **kwargs):
        """Create a ``Concat`` module.

        ``axis`` is the concatenation axis (default: 0).
        """
        super(Concat, self).__init__(key, dev, **kwargs)
        self.axis = kwargs.get('axis', 0)
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'Concat',
            'arguments': {
                'axis': self.axis
            },
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / arithmetic.py View on Github external
class Maximum(BaseModule):
    """This module imports the *MaximumOp* from backend."""

    def __init__(self, key, dev, **kwargs):
        """Create a ``Maximum`` module; takes no extra kwargs."""
        super(Maximum, self).__init__(key, dev, **kwargs)
        self.register_op()

    def register_op(self):
        # The backend op is fully specified by its type; no arguments.
        self.op_meta = {'op_type': 'Maximum', 'arguments': {}}

    def forward(self, x1, x2, y):
        """Run the op on ``x1``/``x2``; write into ``y`` when truthy."""
        inputs = [x1, x2]
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)


class Minimum(BaseModule):
    """This module imports the *MinimumOp* from backend."""

    def __init__(self, key, dev, **kwargs):
        """Create a ``Minimum`` module; takes no extra kwargs."""
        super(Minimum, self).__init__(key, dev, **kwargs)
        self.register_op()

    def register_op(self):
        # The backend op is fully specified by its type; no arguments.
        self.op_meta = {'op_type': 'Minimum', 'arguments': {}}

    def forward(self, x1, x2, y):
        """Run the op on ``x1``/``x2``; write into ``y`` when truthy."""
        inputs = [x1, x2]
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)


class Clamp(BaseModule):
    def __init__(self, key, dev, **kwargs):
        super(Clamp, self).__init__(key, dev, **kwargs)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / reduce.py View on Github external
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#      
#
# ------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dragon.vm.torch.ops.modules.base import BaseModule


class Reduce(BaseModule):
    def __init__(self, key, ctx, **kwargs):
        """Create a ``Reduce`` module.

        Defaults: operation ``'SUM'``, ``dim=None`` (reduce all),
        ``keepdim=True``.
        """
        super(Reduce, self).__init__(key, ctx, **kwargs)
        for name, default in (('operation', 'SUM'),
                              ('dim', None),
                              ('keepdim', True)):
            setattr(self, name, kwargs.get(name, default))
        self.register_arguments()
        self.register_op()

    def register_arguments(self):
        """Register no mutable arguments for the reduce op.

        Mutable ``axis`` and ``keep_dims`` are non-trivial for the
        backend to handle, so they are simply hashed into the
        persistent key instead.

        """
        pass
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / array.py View on Github external
def register_op(self):
    """Describe the backend *Stack* op for this module."""
    self.op_meta = {
        'op_type': 'Stack',
        'arguments': {'axis': self.axis},
    }

    def forward(self, xs, y):
        """Run the op on the sequence of inputs ``xs``.

        Writes into ``y`` when it is truthy; otherwise a fresh
        output is registered.
        """
        self.unify_devices(xs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(xs, outputs)


class Chunk(BaseModule):
    """This module imports the *SliceOp* from backend.

    Slice the inputs into several parts along the given axis.

    """
    def __init__(self, key, dev, **kwargs):
        """Create a ``Chunk`` module.

        ``axis`` is the slicing axis (default: 0) and ``chunks`` the
        number of parts to produce (default: 1).
        """
        super(Chunk, self).__init__(key, dev, **kwargs)
        for name, default in (('axis', 0), ('chunks', 1)):
            setattr(self, name, kwargs.get(name, default))
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'Slice',
            'arguments': {
                'axis': self.axis,