How to use the syft.TensorBase class in syft

To help you get started, we’ve selected a few syft examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_masked_scatter_braodcasting_1(self):
        """A length-3 mask is applied to each row of a 2x3 tensor of ones.

        Only column 1 is selected in every row, so ``masked_scatter_`` writes
        the first two source values (1, then 2) into those positions.
        """
        target = TensorBase(np.ones((2, 3)))
        values = TensorBase([1, 2, 3, 4, 5, 6])
        selector = TensorBase([0, 1, 0])
        target.masked_scatter_(selector, values)
        expected = TensorBase([[1, 1, 1], [1, 2, 1]])
        self.assertTrue(np.array_equal(target, expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_scatter_numerical_3(self):
        """Scatter along dim 0 with a single all-zero index row.

        The index has only one row, so only the first row of ``src`` is
        written, and it lands in row 0 of the target; rows 1-2 stay zero.
        """
        target = TensorBase(np.zeros((3, 5)))
        index = TensorBase(np.array([[0] * 5]))
        source = TensorBase(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]))
        target.scatter_(dim=0, index=index, src=source)
        expected = np.array([[1, 2, 3, 4, 5]] + [[0] * 5] * 2)
        self.assertTrue(np.array_equal(target.data, expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_log_1p(self):
        """``log1p`` returns log(1 + x) element-wise."""
        values = TensorBase(np.array([1, 2, 3]))
        # ln(2), ln(3), ln(4)
        expected = [0.69314718, 1.09861229, 1.38629436]
        result = values.log1p()
        self.assertTrue(np.allclose(result.data, expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_lerp_(self):
        """In-place ``lerp_`` with weight 0.5 leaves the element-wise midpoint."""
        start = TensorBase(np.array([1, 2, 3, 4]))
        end = TensorBase(np.array([3, 4, 5, 6]))
        start.lerp_(end, 0.5)
        midpoints = [2, 3, 4, 5]
        self.assertTrue(np.array_equal(start.data, midpoints))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def testMode_axis_col(self):
        """Check the column-wise mode (``axis=0``) of a 2x10 tensor.

        The expected value stacks, per column, the modal value (ties resolved
        to the smallest value, e.g. column 4 holds 5 and 4 -> 4) and its count.
        Each part is wrapped in an extra leading axis, giving shape (2, 1, 10).
        """
        t1 = TensorBase([[1, 2, 3, 4, 5, 1, 1, 1, 1, 1], [1, 2, 3, 4, 4, 5, 6, 7, 8, 1]])
        expected = np.array([[[1, 2, 3, 4, 4, 1, 1, 1, 1, 1]], [[2, 2, 2, 2, 1, 1, 1, 1, 1, 2]]])
        # assertTrue's second argument is only the failure *message*, so the
        # result has to be compared explicitly for the test to check anything.
        # NOTE(review): assumes mode(axis=0) returns something np.asarray turns
        # into this (2, 1, 10) (modes, counts) layout — confirm against TensorBase.mode.
        self.assertTrue(np.array_equal(t1.mode(axis=0), expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_half_1(self):
        """``half()`` yields the tensor's values cast to float16."""
        t1 = TensorBase(np.array([2, 3, 4]))
        expected = np.array([2, 3, 4]).astype('float16')
        # np.alltrue was deprecated in NumPy 1.25 and removed in NumPy 2.0;
        # np.all is the supported equivalent.
        self.assertTrue(np.all(t1.half() == expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_index_add_(self):
        """Exercise in-place ``index_add_`` along dim 0 and dim 1, plus error cases.

        ``index_add_(dim, index, t2)`` adds slice ``i`` of ``t2`` into slice
        ``index[i]`` of ``t1`` along ``dim``, as the expected tensors below show.
        """
        t1 = TensorBase(np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]]))
        t2 = TensorBase(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

        # dim 0: t2 rows 0, 1, 2 are added into t1 rows 0, 2, 1 respectively.
        expected_0 = TensorBase(np.array([[1, 2, 3], [8, 9, 10], [5, 6, 7]]))
        t1.index_add_(0, [0, 2, 1], t2)
        self.assertEqual(expected_0, t1)

        # Reset t1 because index_add_ mutated it above.
        t1 = TensorBase(np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]]))
        # dim 1: t2 columns 0, 1, 2 are added into t1 columns 0, 2, 1 respectively.
        expected_1 = TensorBase(np.array([[1, 3, 2], [5, 7, 6], [8, 10, 9]]))
        t1.index_add_(1, [0, 2, 1], t2)
        self.assertEqual(expected_1, t1)

        # A float in the index list is expected to raise TypeError.
        with pytest.raises(TypeError):
            t1.index_add_(0, [1.0, 2, 2], t2)
        # A source tensor that does not match the indexed slices is expected to raise IndexError.
        with pytest.raises(IndexError):
            t1.index_add_(0, [0, 1, 2], TensorBase([1, 2]))
        with pytest.raises(ValueError):
            # NOTE(review): the body of this block is cut off in this excerpt;
            # the call expected to raise ValueError is not shown.
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_topK(self):
        """``topk(3, largest=True)`` keeps the three largest values of each row."""
        rows = [
            [900, 800, 1000, 2000, 5, 10, 20, 40, 50],
            [10, 11, 12, 13, 5, 6, 7, 8, 9],
            [30, 40, 50, 10, 8, 1, 2, 3, 4],
        ]
        top3 = TensorBase(np.array(rows)).topk(3, largest=True)
        # The expected rows list the selected values in ascending order.
        expected = np.array([[900, 1000, 2000], [11, 12, 13], [30, 40, 50]])
        self.assertTrue(np.array_equal(top3.data, expected))
GitHub: OpenMined / PySyft / tests / test_tensor.py (View on GitHub, external link)
def test_random_(self):
        """``random_`` fills the tensor with integers in [low, high) deterministically.

        NOTE(review): the expected values assume ``random_`` draws from NumPy's
        global RNG, which is seeded to 0 here — confirm against TensorBase.random_.
        """
        np.random.seed(0)
        tensor = TensorBase(np.zeros(4))
        tensor.random_(low=0, high=5, size=4)
        expected = np.array([4, 0, 3, 3])
        self.assertTrue(np.array_equal(tensor.data, expected))
GitHub: OpenMined / PySyft / tests / test_math.py (View on GitHub, external link)
def test_matmul_2d_identity(self):
        """Left-multiplying by the 2x2 identity returns the other matrix unchanged."""
        identity = TensorBase(np.array([[1, 0], [0, 1]]))
        other = TensorBase(np.array([[5.8, 6.5], [7.8, 8.9]]))
        product = syft.matmul(identity, other)
        expected = [[5.8, 6.5], [7.8, 8.9]]
        self.assertTrue(syft.equal(product, expected))