How to use the `crypten.communicator.get` function in CrypTen

To help you get started, we’ve selected a few crypten examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: facebookresearch/CrypTen — test/test_nn.py (view on GitHub)
if isinstance(encr_module, crypten.nn.Graph):
                                for encr_node in encr_module.modules():
                                    if hasattr(encr_node, key):
                                        encr_param = getattr(encr_node, key)
                                        break

                            # or get it from the crypten Module directly:
                            else:
                                encr_param = getattr(encr_module, key)

                            # compare with reference:
                            # NOTE: Because some parameters are initialized randomly
                            # with different values on each process, we only want to
                            # check that they are consistent with source parameter value
                            reference = getattr(module, key)
                            src_reference = comm.get().broadcast(reference, src=0)
                            msg = "parameter %s in %s incorrect" % (key, module_name)
                            if not encrypted:
                                encr_param = crypten.cryptensor(encr_param)
                            self._check(encr_param, src_reference, msg)

                # compare model outputs:
                self.assertTrue(encr_module.training, "training value incorrect")
                reference = module(input)
                encr_output = encr_module(encr_input)
                self._check(encr_output, reference, "%s forward failed" % module_name)

                # test backward pass:
                reference.backward(torch.ones(reference.size()))
                encr_output.backward()
                if wrap:  # you cannot get input gradients on MPCTensor inputs
                    self._check(
GitHub: facebookresearch/CrypTen — crypten/__init__.py (view on GitHub)
# integer here so other parties cannot guess its value.

    # We sometimes get here from a forked process, which causes all parties
    # to have the same RNG state. Reset the seed to make sure RNG streams
    # are different in all the parties. We use numpy's random here since
    # setting its seed to None will produce different seeds even from
    # forked processes.
    import numpy

    numpy.random.seed(seed=None)
    next_seed = torch.tensor(numpy.random.randint(-2 ** 63, 2 ** 63 - 1, (1,)))
    prev_seed = torch.LongTensor([0])  # placeholder

    # Send random seed to next party, receive random seed from prev party
    world_size = comm.get().get_world_size()
    rank = comm.get().get_rank()
    if world_size >= 2:  # Otherwise sending seeds will segfault.
        next_rank = (rank + 1) % world_size
        prev_rank = (next_rank - 2) % world_size

        req0 = comm.get().isend(tensor=next_seed, dst=next_rank)
        req1 = comm.get().irecv(tensor=prev_seed, src=prev_rank)

        req0.wait()
        req1.wait()
    else:
        prev_seed = next_seed

    # Seed Generators
    comm.get().g0.manual_seed(next_seed.item())
    comm.get().g1.manual_seed(prev_seed.item())
GitHub: facebookresearch/CrypTen — crypten/__init__.py (view on GitHub)
# NOTE: Chosen seed can be any number, but we choose as a random 64-bit
    # integer here so other parties cannot guess its value.

    # We sometimes get here from a forked process, which causes all parties
    # to have the same RNG state. Reset the seed to make sure RNG streams
    # are different in all the parties. We use numpy's random here since
    # setting its seed to None will produce different seeds even from
    # forked processes.
    import numpy

    numpy.random.seed(seed=None)
    next_seed = torch.tensor(numpy.random.randint(-2 ** 63, 2 ** 63 - 1, (1,)))
    prev_seed = torch.LongTensor([0])  # placeholder

    # Send random seed to next party, receive random seed from prev party
    world_size = comm.get().get_world_size()
    rank = comm.get().get_rank()
    if world_size >= 2:  # Otherwise sending seeds will segfault.
        next_rank = (rank + 1) % world_size
        prev_rank = (next_rank - 2) % world_size

        req0 = comm.get().isend(tensor=next_seed, dst=next_rank)
        req1 = comm.get().irecv(tensor=prev_seed, src=prev_rank)

        req0.wait()
        req1.wait()
    else:
        prev_seed = next_seed

    # Seed Generators
    comm.get().g0.manual_seed(next_seed.item())
    comm.get().g1.manual_seed(prev_seed.item())
GitHub: facebookresearch/CrypTen — examples/tfe_benchmarks/tfe_benchmarks.py (view on GitHub)
def save_checkpoint(
    state, is_best, filename="checkpoint.pth.tar", model_best="model_best.pth.tar"
):
    """Write a training checkpoint, but only on the rank-0 party.

    Every party may call this; all ranks other than 0 return immediately
    without touching the filesystem.

    Args:
        state: Object passed unchanged to ``torch.save``.
        is_best: When truthy, the freshly written ``filename`` is also
            copied to ``model_best``.
        filename: Path the checkpoint is saved to.
        model_best: Path the best-so-far checkpoint is copied to.
    """
    # TODO: use crypten.save() in future.
    # Guard clause: non-zero ranks do nothing. NOTE(review): presumably this
    # keeps several processes from writing the same files — confirm.
    if comm.get().get_rank() != 0:
        return

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, model_best)
GitHub: facebookresearch/CrypTen — crypten/mpc/primitives/binary.py (view on GitHub)
def reveal(self):
        """Reconstruct the plaintext from every party's binary share.

        The shares of all parties are collected with ``all_gather`` and
        combined with bitwise XOR, left to right. No fixed-point downscaling
        is applied to the result.
        """
        # Split off the first gathered share and XOR the remaining ones into it.
        head, *tail = comm.get().all_gather(self.share)
        for piece in tail:
            head = head ^ piece
        return head
GitHub: facebookresearch/CrypTen — crypten/mpc/provider/ttp_provider.py (view on GitHub)
def generate_additive_triple(size0, size1, op, *args, **kwargs):
        """Produce a multiplicative triple ``(a, b, c)`` with help from the TTP.

        ``a`` and ``b`` are random ring elements of shapes ``size0`` and
        ``size1``, drawn from the TTP client's generator. Rank 0 obtains its
        share of ``c`` from the TTP through an ``"additive"`` request. Every
        other rank draws a random share of ``c``, shaped like the output of
        ``torch.<op>(a, b, *args, **kwargs)``.

        Returns:
            A tuple of three ``ArithmeticSharedTensor`` objects built from the
            raw shares with ``precision=0``.
        """
        generator = TTPClient.get().generator

        # Draw a first, then b, from the same generator.
        a, b = (
            generate_random_ring_element(shape, generator=generator)
            for shape in (size0, size1)
        )

        is_requester = comm.get().get_rank() == 0
        if is_requester:
            # NOTE(review): presumably the TTP mirrors the other ranks'
            # generator state, so it can choose rank 0's share of c so that
            # all shares combine to op(a, b). Confirm against TTPServer.
            c = TTPClient.get().ttp_request(
                "additive", size0, size1, op, *args, **kwargs
            )
        else:
            # TODO: Compute size without executing computation
            out_shape = getattr(torch, op)(a, b, *args, **kwargs).size()
            c = generate_random_ring_element(out_shape, generator=generator)

        # Wrap the raw shares as integer (precision=0) arithmetic shares.
        return tuple(
            ArithmeticSharedTensor.from_shares(share, precision=0)
            for share in (a, b, c)
        )
GitHub: facebookresearch/CrypTen — crypten/mpc/primitives/beaver.py (view on GitHub)
[theta_x] = theta_z + [beta_xr] - [theta_r] - [eta_xr]

    Where [theta_i] is the wraps for a variable i
          [beta_ij] is the differential wraps for variables i and j
          [eta_ij]  is the plaintext wraps for variables i and j

    Note: Since [eta_xr] = 0 with probability 1 - |x| / Q for modulus Q, we
    can make the assumption that [eta_xr] = 0 with high probability.
    """
    provider = crypten.mpc.get_default_provider()
    r, theta_r = provider.wrap_rng(x.size())
    beta_xr = theta_r.clone()
    beta_xr._tensor = count_wraps([x._tensor, r._tensor])

    z = x + r
    theta_z = comm.get().gather(z._tensor, 0)
    theta_x = beta_xr - theta_r

    # TODO: Incorporate eta_xr
    if x.rank == 0:
        theta_z = count_wraps(theta_z)
        theta_x._tensor += theta_z
    return theta_x