How to use crypten - 10 common examples

To help you get started, we’ve selected a few CrypTen examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github facebookresearch / CrypTen / test / test_arithmetic.py View on Github external
"""Test convolution of encrypted tensor with public/private tensors."""
        for kernel_type in [lambda x: x, ArithmeticSharedTensor]:
            for matrix_width in range(2, 5):
                for kernel_width in range(1, matrix_width):
                    for padding in range(kernel_width // 2 + 1):
                        matrix_size = (5, matrix_width)
                        matrix = get_random_test_tensor(size=matrix_size, is_float=True)

                        kernel_size = (kernel_width, kernel_width)
                        kernel = get_random_test_tensor(size=kernel_size, is_float=True)

                        matrix = matrix.unsqueeze(0).unsqueeze(0)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)

                        reference = F.conv2d(matrix, kernel, padding=padding)
                        encrypted_matrix = ArithmeticSharedTensor(matrix)
                        encrypted_kernel = kernel_type(kernel)
                        with self.benchmark(
                            kernel_type=kernel_type.__name__, matrix_width=matrix_width
                        ) as bench:
                            for _ in bench.iters:
                                encrypted_conv = encrypted_matrix.conv2d(
                                    encrypted_kernel, padding=padding
                                )

                        self._check(encrypted_conv, reference, "conv2d failed")
github facebookresearch / CrypTen / crypten / mpc / mpc.py View on Github external
if isinstance(p, float) and int(p) == p:
            p = int(p)

        if not isinstance(p, int):
            raise TypeError(
                "pow must take an integer exponent. For non-integer powers, use"
                " pos_pow with positive-valued base."
            )
        if p < -1:
            return self.reciprocal(**kwargs).pow(-p)
        elif p == -1:
            return self.reciprocal(**kwargs)
        elif p == 0:
            # Note: This returns 0 ** 0 -> 1 when inputs have zeros.
            # This is consistent with PyTorch's pow function.
            return MPCTensor(torch.ones(self.size()))
        elif p == 1:
            return self.clone()
        elif p == 2:
            return self.square()
        elif p % 2 == 0:
            return self.square().pow(p // 2)
        else:
            return self.square().mul_(self).pow((p - 1) // 2)
github facebookresearch / CrypTen / crypten / gradients.py View on Github external
grad_kernel.size(1) // batch_size,
            grad_kernel.size(2),
            grad_kernel.size(3),
        )
        grad_kernel = (
            grad_kernel.sum(dim=0)
            .view(in_channels, out_channels, grad_kernel.size(2), grad_kernel.size(3))
            .transpose(0, 1)
        )
        grad_kernel = grad_kernel.narrow(2, 0, kernel_size_y)
        grad_kernel = grad_kernel.narrow(3, 0, kernel_size_x)
        return (grad_input, grad_kernel)


@register_function("batchnorm")
class AutogradBatchNorm(AutogradFunction):
    @staticmethod
    def forward(
        ctx,
        input,
        running_mean=None,
        running_var=None,
        training=False,
        eps=1e-05,
        momentum=0.1,
    ):
        """
        Computes forward step of batch norm by normalizing x
            and returning weight * x_norm + bias.

        Running mean and var are computed over the `C` dimension for an input
        of size `(N, C, +)`.
github facebookresearch / CrypTen / crypten / mpc / primitives / arithmetic.py View on Github external
def div_(self, y):
        """Divide this tensor element-wise by ``y`` in place.

        Args:
            y: a Python ``int`` or ``float``, or a tensor recognised by
                ``is_int_tensor`` / ``is_float_tensor`` (presumably a plain
                torch tensor -- confirm against those helpers). A float, or a
                float tensor whose values are all whole numbers, is converted
                to an integer so the integer (truncation) branch is taken.

        Returns:
            ``self``, after the in-place update.

        Raises:
            AssertionError: if ``y`` is none of the supported types. This is
                an ``assert`` and is therefore skipped when Python runs
                with ``-O``.
        """
        # TODO: Add test coverage for this code path (next 4 lines)
        # `float.is_integer()` is used rather than `int(y) == y`: `int()`
        # raises OverflowError for +/-inf and ValueError for NaN, whereas
        # `is_integer()` returns False so such values fall through to the
        # reciprocal path below.
        if isinstance(y, float) and y.is_integer():
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            # Truncate protocol for dividing by public integers:
            # NOTE(review): in-place `/=` on an integer share tensor relies on
            # older PyTorch semantics (truncating integer division); newer
            # releases reject true division into an integer tensor in place.
            # Confirm the targeted torch version.
            if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share /= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share /= y
            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.FloatTensor([y])

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
github facebookresearch / CrypTen / crypten / mpc / primitives / arithmetic.py View on Github external
if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share /= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share /= y
            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.FloatTensor([y])

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
github facebookresearch / CrypTen / crypten / mpc / primitives / arithmetic.py View on Github external
if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share /= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share /= y
            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.FloatTensor([y])

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
github facebookresearch / CrypTen / crypten / mpc / primitives / arithmetic.py View on Github external
def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share /= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share /= y
            return self

        # Otherwise multiply by reciprocal
github facebookresearch / CrypTen / test / test_common.py View on Github external
else:
                fpe = FixedPointEncoder(precision_bits=0)
            tensor = get_test_tensor(float=float)
            decoded = fpe.decode(fpe.encode(tensor))
            self._check(
                decoded,
                tensor,
                "Encoding/decoding a %s failed." % "float" if float else "long",
            )

        # Make sure encoding a subclass of CrypTensor is a no-op
        crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
        crypten.init()

        tensor = get_test_tensor(float=True)
        encrypted_tensor = crypten.cryptensor(tensor)
        encrypted_tensor = fpe.encode(encrypted_tensor)
        self._check(
            encrypted_tensor.get_plain_text(),
            tensor,
            "Encoding an EncryptedTensor failed.",
        )

        # Try a few other types.
        fpe = FixedPointEncoder(precision_bits=0)
        for dtype in [torch.uint8, torch.int8, torch.int16]:
            tensor = torch.zeros(5, dtype=dtype).random_()
            decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
            self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)
github facebookresearch / CrypTen / test / test_arithmetic.py View on Github external
def setUp(self):
        """Run the parent fixture setup, then initialize CrypTen on workers.

        The main process (rank -1) must not initialize the communicator, so
        ``crypten.init()`` is only invoked when ``self.rank`` is non-negative.
        """
        super().setUp()
        # Worker processes have rank >= 0; only they bring up the communicator.
        should_init = self.rank >= 0
        if should_init:
            crypten.init()
github facebookresearch / CrypTen / test / test_common.py View on Github external
for float in [False, True]:
            if float:
                fpe = FixedPointEncoder(precision_bits=16)
            else:
                fpe = FixedPointEncoder(precision_bits=0)
            tensor = get_test_tensor(float=float)
            decoded = fpe.decode(fpe.encode(tensor))
            self._check(
                decoded,
                tensor,
                "Encoding/decoding a %s failed." % "float" if float else "long",
            )

        # Make sure encoding a subclass of CrypTensor is a no-op
        crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
        crypten.init()

        tensor = get_test_tensor(float=True)
        encrypted_tensor = crypten.cryptensor(tensor)
        encrypted_tensor = fpe.encode(encrypted_tensor)
        self._check(
            encrypted_tensor.get_plain_text(),
            tensor,
            "Encoding an EncryptedTensor failed.",
        )

        # Try a few other types.
        fpe = FixedPointEncoder(precision_bits=0)
        for dtype in [torch.uint8, torch.int8, torch.int16]:
            tensor = torch.zeros(5, dtype=dtype).random_()
            decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
            self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)