How to use the poutyne.torch_apply function in Poutyne

To help you get started, we’ve selected a few Poutyne examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_dict(self):
        """Run ``torch_apply`` over a flat dict of mocked tensors.

        The dict values are then handed to ``self._test_method_calls``
        (helper defined elsewhere in the test class).
        """
        def move_to_cpu(tensor):
            return tensor.cpu()

        my_dict = {key: MagicMock(spec=torch.Tensor) for key in ['a', 'b', 'c']}
        torch_apply(my_dict, move_to_cpu)
        mocked_tensors = list(my_dict.values())
        self._test_method_calls(mocked_tensors)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_object_with_no_tensor(self):
        """Run ``torch_apply`` over a nested structure that holds no tensors.

        The returned structure must compare equal to the input while being
        a distinct object from it.
        """
        my_obj = {'a': 5, 'b': 3.141592, 'c': {'d': [1, 2, 3]}}
        ret = torch_apply(my_obj, lambda t: t.cpu())
        self.assertEqual(ret, my_obj)
        # assertIsNot reports both operands on failure, whereas
        # assertFalse(ret is my_obj) would only say "True is not false".
        self.assertIsNot(ret, my_obj)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_recursive_list(self):
        """Run ``torch_apply`` over a list that contains a nested list.

        Layout: two mocked tensors, a nested list of three, then one more.
        The flattened mocks are handed to ``self._test_method_calls``
        (helper defined elsewhere in the test class).
        """
        leading = [MagicMock(spec=torch.Tensor) for _ in range(2)]
        nested = [MagicMock(spec=torch.Tensor) for _ in range(3)]
        trailing = [MagicMock(spec=torch.Tensor) for _ in range(1)]
        my_list = [*leading, nested, *trailing]

        torch_apply(my_list, lambda t: t.cpu())

        # Flatten by reading back from my_list itself, as the assertion
        # helper expects a single flat list of mocks.
        flattened = [*my_list[:2], *my_list[2], *my_list[3:]]
        self._test_method_calls(flattened)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_recursive_tuple(self):
        """Run ``torch_apply`` over a tuple that contains a nested tuple.

        Layout: two mocked tensors, a nested tuple of three, then one more.
        The flattened mocks (still a tuple) are handed to
        ``self._test_method_calls`` (helper defined elsewhere in the class).
        """
        leading = tuple(MagicMock(spec=torch.Tensor) for _ in range(2))
        nested = tuple(MagicMock(spec=torch.Tensor) for _ in range(3))
        trailing = tuple(MagicMock(spec=torch.Tensor) for _ in range(1))
        my_tuple = leading + (nested, ) + trailing

        torch_apply(my_tuple, lambda t: t.cpu())

        # Keep the flattened collection a tuple, matching what is sliced
        # out of my_tuple.
        flattened = (*my_tuple[:2], *my_tuple[2], *my_tuple[3:])
        self._test_method_calls(flattened)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_list(self):
        """Run ``torch_apply`` over a flat list of ten mocked tensors.

        The list is then handed to ``self._test_method_calls`` (helper
        defined elsewhere in the test class).
        """
        def move_to_cpu(tensor):
            return tensor.cpu()

        my_list = []
        for _ in range(10):
            my_list.append(MagicMock(spec=torch.Tensor))

        torch_apply(my_list, move_to_cpu)
        self._test_method_calls(my_list)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_tuple(self):
        """Run ``torch_apply`` over a flat tuple of ten mocked tensors.

        The tuple is then handed to ``self._test_method_calls`` (helper
        defined elsewhere in the test class).
        """
        def move_to_cpu(tensor):
            return tensor.cpu()

        mocks = [MagicMock(spec=torch.Tensor) for _ in range(10)]
        my_tuple = tuple(mocks)

        torch_apply(my_tuple, move_to_cpu)
        self._test_method_calls(my_tuple)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_recursive_data_structure(self):
        """Run ``torch_apply`` over a dict mixing lists, tuples and dicts.

        The structure holds a list under ``'a'``, a tuple under ``'b'``, a
        nested dict with a list under ``'c'``/``'d'`` and a bare mocked
        tensor under ``'e'``. All mocks are flattened into one list and
        handed to ``self._test_method_calls`` (helper defined elsewhere in
        the test class).
        """
        list_part = [MagicMock(spec=torch.Tensor) for _ in range(3)]
        tuple_part = tuple(MagicMock(spec=torch.Tensor) for _ in range(2))
        inner_list = [MagicMock(spec=torch.Tensor) for _ in range(3)]
        lone_tensor = MagicMock(spec=torch.Tensor)
        my_obj = {
            'a': list_part,
            'b': tuple_part,
            'c': {'d': inner_list},
            'e': lone_tensor,
        }

        torch_apply(my_obj, lambda t: t.cpu())

        flattened = [*my_obj['a'], *my_obj['b'], *my_obj['c']['d'], my_obj['e']]
        self._test_method_calls(flattened)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_with_replacement_to_no_tensor(self):
        """Run ``torch_apply`` with a function that returns a non-tensor.

        A single mocked tensor in a list is mapped to the constant ``123``;
        the result is expected to equal ``[123]``.
        """
        def replace_with_constant(_tensor):
            return 123

        single_mock = [MagicMock(spec=torch.Tensor)]
        replaced = torch_apply(single_mock, replace_with_constant)
        self.assertEqual(replaced, [123])
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_apply_on_recursive_dict(self):
        """Run ``torch_apply`` over a dict that contains a nested dict.

        ``'a'`` maps to a mocked tensor and ``'b'`` maps to a dict with
        mocked tensors under ``'c'`` and ``'d'``. All mocks are flattened
        into one list and handed to ``self._test_method_calls`` (helper
        defined elsewhere in the test class).
        """
        my_dict = {
            'a': MagicMock(spec=torch.Tensor),
            'b': {key: MagicMock(spec=torch.Tensor) for key in ['c', 'd']},
        }

        torch_apply(my_dict, lambda t: t.cpu())

        flattened = [my_dict['a']] + list(my_dict['b'].values())
        self._test_method_calls(flattened)