How to use the dragon.vm.torch.ops.primitive.WrapScalar function in dragon

To help you get started, we’ve selected a few dragon examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / builtin.py View on Github external
def _fundamental(input, value, op='Add', out=None):
    """Run a binary elementwise op between ``input`` and ``value``.

    ``value`` may be a plain scalar; it is promoted to a tensor that
    matches ``input``'s dtype and device before dispatch.
    """
    if not isinstance(value, Tensor):
        # Promote the python scalar so both operands are tensors.
        value = WrapScalar(value, input.dtype, input.device)
    device = MakeDevice(inputs=[input, value])
    cache_key = '{}/{}'.format(op, device)
    op_module = get_module(Fundamental, cache_key, device, op_type=op)
    return op_module.forward(input, value, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / builtin.py View on Github external
def _rfundamental(input, value, op='RAdd', out=None):
    """Run a reflected binary elementwise op (``value <op> input``).

    Mirrors ``_fundamental`` but forwards the operands in reversed
    order, matching python's reflected operator protocol.
    """
    if not isinstance(value, Tensor):
        # Promote the python scalar so both operands are tensors.
        value = WrapScalar(value, input.dtype, input.device)
    device = MakeDevice(inputs=[input, value])
    cache_key = '{}/{}'.format(op, device)
    op_module = get_module(Fundamental, cache_key, device, op_type=op)
    # Reflected form: the scalar/value operand comes first.
    return op_module.forward(value, input, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / arithmetic.py View on Github external
def _fundamental(input, value, op='Add', out=None):
    """Run a binary elementwise op between ``input`` and ``value``.

    Older context-based variant: scalars are wrapped against
    ``input._ctx`` and module lookup is keyed on the (device, id)
    pair returned by ``MakeContext``.
    """
    if not isinstance(value, Tensor):
        # Promote the python scalar so both operands are tensors.
        value = WrapScalar(value, input.dtype, input._ctx)
    context = MakeContext(inputs=[input, value])
    cache_key = 'torch.ops.{}/{}:{}'.format(op.lower(), context[0], context[1])
    op_module = get_module(Fundamental, cache_key, context, op_type=op)
    return op_module.forward(input, value, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / arithmetic.py View on Github external
def _maximum(input, other, out=None):
    """Compute the elementwise maximum of two operands.

    Either operand may be a plain scalar; at most one is wrapped into
    a tensor using the other operand's dtype and context.
    """
    if not isinstance(input, Tensor):
        input = WrapScalar(input, other.dtype, other._ctx)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input._ctx)
    # NOTE(review): only the first operand feeds MakeContext here —
    # presumably both share a context by this point; verify upstream.
    context = MakeContext(inputs=[input])
    cache_key = 'torch.ops.maximum/{}:{}'.format(context[0], context[1])
    op_module = get_module(Maximum, cache_key, context)
    return op_module.forward(input, other, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / arithmetic.py View on Github external
def _minimum(input, other, out=None):
    """Compute the elementwise minimum of two operands.

    Either operand may be a plain scalar; at most one is wrapped into
    a tensor using the other operand's dtype and context.
    """
    if not isinstance(input, Tensor):
        input = WrapScalar(input, other.dtype, other._ctx)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input._ctx)
    # NOTE(review): only the first operand feeds MakeContext here —
    # presumably both share a context by this point; verify upstream.
    context = MakeContext(inputs=[input])
    cache_key = 'torch.ops.minimum/{}:{}'.format(context[0], context[1])
    op_module = get_module(Minimum, cache_key, context)
    return op_module.forward(input, other, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / builtin.py View on Github external
def _compare(input, other, operation, out=None):
    """Run an elementwise comparison between ``input`` and ``other``.

    ``other`` may be a plain scalar; it is promoted to a tensor that
    matches ``input``'s dtype and device before dispatch. The
    ``operation`` string selects which comparison the module performs.
    """
    if not isinstance(other, Tensor):
        # Promote the python scalar so both operands are tensors.
        other = WrapScalar(other, input.dtype, input.device)
    device = MakeDevice(inputs=[input, other])
    cache_key = 'Compare/{}/{}'.format(operation, device)
    op_module = get_module(Compare, cache_key, device, operation=operation)
    return op_module.forward(input, other, out)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / builtin.py View on Github external
The input tensor.
    other : dragon.vm.torch.Tensor or number
        The input tensor.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    """
    if not isinstance(input, Tensor):
        input = WrapScalar(input, other.dtype, other.device)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input.device)
    dev = MakeDevice(inputs=[input])
    key = 'Minimum/{}'.format(dev)
    module = get_module(Minimum, key, dev)
    return module.forward(input, other, out)