How to use the dragon.core.workspace module in Dragon

To help you get started, we’ve selected a few dragon examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github seetaresearch / Dragon / Dragon / python / dragon / core / helper.py View on Github external
def get_index_and_name(cls, prefix='Op'):
        """Obtain an operator name from the workspace and parse its index.

        Requests a dummy name in the ``'Operator'`` domain built from
        ``prefix`` and splits it on ``'_'`` to extract the numeric suffix.

        Args:
            prefix: Base name passed to ``_workspace.GetDummyName``.

        Returns:
            tuple: ``(index, name)``, where ``index`` is ``int`` of the text
            after the single ``'_'`` in ``name``.

        Raises:
            ValueError: if the retried name still does not split into exactly
                two ``'_'``-separated parts, or its suffix is not an integer.
        """
        name = _workspace.GetDummyName(prefix, domain='Operator')
        try:
            _, op_idx = name.split('_')
        except ValueError:
            # Unpacking fails when the name has no '_<index>' suffix; request
            # one more name. NOTE(review): presumably GetDummyName returns the
            # bare prefix on first use and a suffixed name afterwards — confirm.
            name = _workspace.GetDummyName(prefix, domain='Operator')
            _, op_idx = name.split('_')
        return int(op_idx), name
github seetaresearch / Dragon / Dragon / python / dragon / updaters.py View on Github external
def register_in_workspace(self):
        """Feed the optimizer's default hyper-parameters into the workspace once.

        Every ``(key, value)`` in ``self._defaults`` is fed as a float32 CPU
        tensor named ``"<self._slot>/<key>"``. Afterwards ``self._registered``
        is set so later calls do nothing. When ``self._verbose`` is truthy,
        the optimizer type and its defaults are printed.
        """
        if self._registered:
            return
        for key, value in self._defaults.items():
            tensor_name = self._slot + "/" + key
            _workspace.FeedTensor(
                tensor_name, value, dtype='float32', force_cpu=True)
        self._registered = True
        if not self._verbose:
            return
        print('---------------------------------------------------------')
        print('Optimizer: {}, Using config:'.format(self.type(True)))
        pprint.pprint(self._defaults)
        print('---------------------------------------------------------')
github seetaresearch / Dragon / Dragon / python / dragon / vm / caffe / net.py View on Github external
----------
        The implementation of `ForwardBackward(net.cpp, L85)`_.

        """
        if hasattr(self, '_function'): return self._function

        for loss in self.losses:
            for var in self.trainable_variables:
                _Grad(loss, var)

        self._function = _Function(
            outputs=[self.blobs[key].data
                for key in self.outputs])

        if hasattr(self, '_model'):
            _workspace.Restore(self._model, format='caffe')

        return self._function
github seetaresearch / Dragon / Dragon / python / dragon / core / tensor.py View on Github external
def name(self, value):
        """Assign the tensor's name.

        An empty string is stored verbatim. For any other value, a name is
        requested from ``_workspace.GetDummyName`` in the ``'Tensor'`` domain.
        A truthy value is prefixed with the current default name scope; a
        falsy one (e.g. ``None``) is replaced by ``'Tensor'``.
        """
        if value == '':
            # Set it manually for some cases: keep the empty name as given.
            self._name = value
            return
        if value:
            base_name = _scope.get_default_name_scope() + value
        else:
            base_name = 'Tensor'
        self._name = _workspace.GetDummyName(base_name, domain='Tensor')
github seetaresearch / Dragon / examples / Seg-FCN / surgery.py View on Github external
print 'dropping', p
            continue
        for i in range(len(net.params[p])):
            if i > (len(new_net.params[p]) - 1):
                print 'dropping', p, i
                break
            print 'copying', p, i
            net_param = ws.FetchTensor(net.params[p][i].data)
            new_net_param = ws.FetchTensor(new_net.params[p][i].data)
            name = new_net.params[p][i].data._name
            if net_param.shape != new_net_param.shape:
                print 'coercing', p, i, 'from', net_param.shape, 'to', new_net_param.shape
            else:
                pass
            new_net_param.flat = new_net_param.flat
            ws.FeedTensor(name, new_net_param)
github seetaresearch / Dragon / Dragon / python / dragon / vm / tensorflow / client / session.py View on Github external
def run(self, feed_dict=None):
        """Execute each function in ``self.functions`` in order.

        Args:
            feed_dict: Optional mapping of tensor -> value. Every entry is
                passed to ``_workspace.FeedTensor`` once, just before the
                first function runs. Nothing is fed if there are no functions.

        Each function is invoked with ``return_outputs=False``.
        """
        pending_feed = feed_dict
        for function in self.functions:
            if pending_feed is not None:
                for tensor, value in pending_feed.items():
                    _workspace.FeedTensor(tensor, value)
                pending_feed = None
            function(return_outputs=False)
github seetaresearch / Dragon / Dragon / python / dragon / vm / torch / ops / modules / base.py View on Github external
def set_arg_i64(self, name, value):
        """Feed ``value`` as an int64 array into the default workspace.

        Args:
            name: Tensor name to feed under.
            value: Anything ``numpy.array`` can convert with dtype ``'int64'``.

        The device argument passed to ``FeedTensor`` is ``self._arg_dev``.
        """
        default_ws = _workspace.get_default_workspace()
        feed = default_ws.FeedTensor
        int64_array = numpy.array(value, 'int64')
        feed(name, int64_array, self._arg_dev)