How to use the `dragon.vm.torch.ops.factory.get_module` function in Dragon

To help you get started, we’ve selected a few Dragon examples based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/builtin.py (view on GitHub)
def _compare(input, other, operation, out=None):
    """Run the ``Compare`` module on ``input`` and ``other``.

    A non-tensor ``other`` is first wrapped with ``WrapScalar`` using
    ``input``'s dtype and device, so Python scalars can be compared directly.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left-hand operand.
    other : dragon.vm.torch.Tensor or scalar
        The right-hand operand.
    operation : str
        The comparison name. It appears in the module cache key and is
        passed to ``Compare`` as ``operation``.
    out : dragon.vm.torch.Tensor, optional
        Forwarded unchanged to ``Compare.forward``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Compare.forward``.

    """
    rhs = other if isinstance(other, Tensor) \
        else WrapScalar(other, input.dtype, input.device)
    device = MakeDevice(inputs=[input, rhs])
    # One module per (operation, device) key.
    compare_module = get_module(
        Compare, 'Compare/{}/{}'.format(operation, device), device,
        operation=operation,
    )
    return compare_module.forward(input, rhs, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/builtin.py (view on GitHub; the excerpt begins partway through the function's docstring)
w : dragon.vm.torch.Tensor
        The w.
    bias : dragon.vm.torch.Tensor, optional
        The bias.
    transW : boolean
        Whether to transpose the ``w``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    """
    dev = MakeDevice(inputs=[x, w] + ([bias] if bias else []))
    key = 'FullyConnected/{}/transW:{}'.format(dev, transW)
    module = get_module(FullyConnected, key, dev, transW=transW)
    return module.forward(x, w, bias, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/vision.py (view on GitHub)
def roi_align(feature, rois, pooled_h, pooled_w,
              spatial_scale, sampling_ratio=2):
    """Run the ``RoIAlign`` module on ``feature`` and ``rois``.

    The execution context is taken from ``feature`` alone.

    Parameters
    ----------
    feature : dragon.vm.torch.Tensor
        The input feature tensor.
    rois : dragon.vm.torch.Tensor
        The regions of interest, forwarded to ``RoIAlign.forward``.
    pooled_h : int
        Forwarded as ``pooled_h`` (presumably the pooled output height).
    pooled_w : int
        Forwarded as ``pooled_w`` (presumably the pooled output width).
    spatial_scale : float
        Forwarded as ``spatial_scale``.
    sampling_ratio : int, optional, default=2
        Forwarded as ``sampling_ratio``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``RoIAlign.forward``.

    """
    ctx = MakeContext(inputs=[feature])
    key_template = ('torch.ops.roi_align/{}:{}/pool_h:{}/pool_w:{}/'
                    'spatial_scale:{}/sampling_ratio:{}')
    key = key_template.format(ctx[0], ctx[1], pooled_h, pooled_w,
                              spatial_scale, sampling_ratio)
    options = dict(
        pooled_h=pooled_h,
        pooled_w=pooled_w,
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )
    roi_module = get_module(RoIAlign, key, ctx, **options)
    return roi_module.forward(feature, rois)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/builtin.py (view on GitHub)
def _fundamental(input, value, op='Add', out=None):
    """Apply the binary ``Fundamental`` op named ``op`` to two operands.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left-hand operand.
    value : dragon.vm.torch.Tensor or scalar
        The right-hand operand. A non-tensor is wrapped with ``WrapScalar``
        using ``input``'s dtype and device.
    op : str, optional, default='Add'
        The op type. It is part of the cache key and is passed as ``op_type``.
    out : dragon.vm.torch.Tensor, optional
        Forwarded unchanged to ``Fundamental.forward``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Fundamental.forward``.

    """
    if isinstance(value, Tensor):
        operand = value
    else:
        operand = WrapScalar(value, input.dtype, input.device)
    device = MakeDevice(inputs=[input, operand])
    fundamental_module = get_module(
        Fundamental, '{}/{}'.format(op, device), device, op_type=op)
    return fundamental_module.forward(input, operand, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/builtin.py (view on GitHub)
def _uniform(input, shape, low, high):
    """Run the ``RandomUniform`` module on ``input`` for the given ``shape``.

    NOTE(review): presumably this samples values from U(low, high) into a
    tensor of ``shape``; confirm against ``RandomUniform``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The tensor whose device and dtype configure the module.
    shape : sequence of int
        The target shape. Only its length (``ndim``) configures the module.
    low : number
        The lower bound. It is converted with ``float()``.
    high : number
        The upper bound. It is converted with ``float()``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``RandomUniform.forward``.

    """
    # Normalize the bounds once. The cache key and the cached module must
    # agree: otherwise a first call with ``0`` would build a module holding
    # int bounds under a key that claims ``0.0``, and later float calls would
    # silently reuse it.
    low, high = float(low), float(high)
    dev = MakeDevice(inputs=[input])
    ndim = len(shape)
    key = 'Uniform/{}/dtype:{}/ndim:{}/low:{}/high:{}'.format(
        dev, input.dtype, ndim, low, high)
    module = get_module(
        RandomUniform, key, dev,
        ndim=ndim,
        low=low,
        high=high,
        dtype=input.dtype,
    )
    return module.forward(input, shape)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/arithmetic.py (view on GitHub)
def _maximum(input, other, out=None):
    """Run the ``Maximum`` module on ``input`` and ``other``.

    Either operand may be a Python scalar. It is wrapped with ``WrapScalar``
    using the other operand's dtype and context. At least one operand must be
    a tensor.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor or scalar
        The first operand.
    other : dragon.vm.torch.Tensor or scalar
        The second operand.
    out : dragon.vm.torch.Tensor, optional
        Forwarded unchanged to ``Maximum.forward``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Maximum.forward``.

    """
    if not isinstance(input, Tensor):
        input = WrapScalar(input, other.dtype, other._ctx)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input._ctx)
    # Derive the context from both operands, as the binary ``_fundamental``
    # op does, rather than from ``input`` alone.
    # NOTE(review): confirm how MakeContext treats operands whose contexts
    # disagree.
    ctx = MakeContext(inputs=[input, other])
    key = 'torch.ops.maximum/{}:{}'.format(ctx[0], ctx[1])
    module = get_module(Maximum, key, ctx)
    return module.forward(input, other, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/arithmetic.py (view on GitHub)
def _log(input, out=None):
    """Run the ``Log`` module on ``input``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The input tensor. Its context selects the cached module.
    out : dragon.vm.torch.Tensor, optional
        Forwarded unchanged to ``Log.forward``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Log.forward``.

    """
    ctx = MakeContext(inputs=[input])
    log_module = get_module(
        Log, 'torch.ops.log/{}:{}'.format(ctx[0], ctx[1]), ctx)
    return log_module.forward(input, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/arithmetic.py (view on GitHub)
def _fundamental(input, value, op='Add', out=None):
    """Apply the binary ``Fundamental`` op named ``op`` to two operands.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left-hand operand.
    value : dragon.vm.torch.Tensor or scalar
        The right-hand operand. A non-tensor is wrapped with ``WrapScalar``
        using ``input``'s dtype and context.
    op : str, optional, default='Add'
        The op type. Its lower-cased form is part of the cache key, and it is
        passed unchanged as ``op_type``.
    out : dragon.vm.torch.Tensor, optional
        Forwarded unchanged to ``Fundamental.forward``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Fundamental.forward``.

    """
    rhs = value if isinstance(value, Tensor) \
        else WrapScalar(value, input.dtype, input._ctx)
    ctx = MakeContext(inputs=[input, rhs])
    op_name = op.lower()
    key = 'torch.ops.{}/{}:{}'.format(op_name, ctx[0], ctx[1])
    fundamental_module = get_module(Fundamental, key, ctx, op_type=op)
    return fundamental_module.forward(input, rhs, out)
GitHub: seetaresearch/Dragon — Dragon/python/dragon/vm/torch/ops/builtin.py (view on GitHub)
def _fill(input, shape, value):
    """Run the ``Fill`` module on ``input`` for the given ``shape``.

    NOTE(review): presumably this fills a tensor of ``shape`` with ``value``;
    confirm against ``Fill``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The tensor whose device and dtype configure the module.
    shape : sequence of int
        The target shape. Only its length (``ndim``) configures the module.
    value : number
        The fill value. It is part of the cache key exactly as given.

    Returns
    -------
    dragon.vm.torch.Tensor
        The result of ``Fill.forward``.

    """
    dev = MakeDevice(inputs=[input])
    ndim = len(shape)
    dtype = input.dtype
    key = 'Fill/{}/dtype:{}/ndim:{}/value:{}'.format(dev, dtype, ndim, value)
    fill_module = get_module(
        Fill, key, dev, ndim=ndim, value=value, dtype=dtype)
    return fill_module.forward(input, shape)