How to use the crypten.init function in crypten

To help you get started, we’ve selected a few crypten.init examples, based on popular ways it is used in public projects.

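Before any other CrypTen call, each process runs crypten.init(), which sets up the communicator and the rest of the CrypTen runtime; after that, plaintext tensors can be encrypted with crypten.cryptensor and recovered with get_plain_text. A minimal single-process sketch (the tensor values are illustrative):

import torch
import crypten

# Set up the CrypTen runtime (communicator, etc.) for this process.
crypten.init()

# Encrypt a plaintext tensor and decrypt it again.
x = torch.tensor([1.0, 2.0, 3.0])
x_enc = crypten.cryptensor(x)
print(x_enc.get_plain_text())  # tensor([1., 2., 3.])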

github facebookresearch / CrypTen / test / test_arithmetic.py
def setUp(self):
        super().setUp()
        # We don't want the main process (rank -1) to initialize the communicator
        if self.rank >= 0:
            crypten.init()

github facebookresearch / CrypTen / test / test_common.py
for float in [False, True]:
            if float:
                fpe = FixedPointEncoder(precision_bits=16)
            else:
                fpe = FixedPointEncoder(precision_bits=0)
            tensor = get_test_tensor(float=float)
            decoded = fpe.decode(fpe.encode(tensor))
            self._check(
                decoded,
                tensor,
                "Encoding/decoding a %s failed." % "float" if float else "long",
            )

        # Make sure encoding a subclass of CrypTensor is a no-op
        crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
        crypten.init()

        tensor = get_test_tensor(float=True)
        encrypted_tensor = crypten.cryptensor(tensor)
        encrypted_tensor = fpe.encode(encrypted_tensor)
        self._check(
            encrypted_tensor.get_plain_text(),
            tensor,
            "Encoding an EncryptedTensor failed.",
        )

        # Try a few other types.
        fpe = FixedPointEncoder(precision_bits=0)
        for dtype in [torch.uint8, torch.int8, torch.int16]:
            tensor = torch.zeros(5, dtype=dtype).random_()
            decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
            self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)
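
The snippet above round-trips tensors through FixedPointEncoder, which represents floats as scaled integers with a configurable number of fractional bits. A standalone sketch of the same round trip (assuming FixedPointEncoder is importable from crypten.encoder; values chosen to be exactly representable with 16 fractional bits):

import torch
from crypten.encoder import FixedPointEncoder  # assumed module path

fpe = FixedPointEncoder(precision_bits=16)

x = torch.tensor([0.5, -1.25, 3.0])
encoded = fpe.encode(x)        # integer representation, scaled by 2**16
decoded = fpe.decode(encoded)  # back to float
print(decoded)                 # tensor([ 0.5000, -1.2500,  3.0000])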

github facebookresearch / CrypTen / test / test_crypten.py
def setUp(self):
        super().setUp()
        if self.rank >= 0:
            crypten.init()
            crypten.set_default_backend(crypten.mpc)

github facebookresearch / CrypTen / test / test_tensorboard.py
def setUp(self):
        super().setUp()
        if self.rank >= 0:
            crypten.init()

github facebookresearch / CrypTen / test / test_binary.py
def setUp(self):
        super().setUp()
        # We don't want the main process (rank -1) to initialize the communicator
        if self.rank >= 0:
            crypten.init()

github facebookresearch / CrypTen / test / benchmark_mpc.py
def setUp(self):
        super().setUp()
        # We don't want the main process (rank -1) to initialize the communicator
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        crypten.init()

        torch.manual_seed(0)

        self.sizes = [(1, 8), (1, 16), (1, 32)]

        self.int_tensors = [
            get_random_test_tensor(size=size, is_float=False) for size in self.sizes
        ]
        self.int_operands = [
            (
                get_random_test_tensor(size=size, is_float=False),
                get_random_test_tensor(size=size, is_float=False),
            )
            for size in self.sizes
        ]
        self.float_tensors = [

github facebookresearch / CrypTen / examples / mpc_linear_svm / mpc_linear_svm.py
def run_mpc_linear_svm(
    epochs=50, examples=50, features=100, lr=0.5, skip_plaintext=False
):
    crypten.init()

    # Set random seed for reproducibility
    torch.manual_seed(1)

    # Initialize x, y, w, b
    x = torch.randn(features, examples)
    w_true = torch.randn(1, features)
    b_true = torch.randn(1)
    y = w_true.matmul(x) + b_true
    y = y.sign()

    if not skip_plaintext:
        logging.info("==================")
        logging.info("PyTorch Training")
        logging.info("==================")
        w_torch, b_torch = train_linear_svm(x, y, lr=lr, print_time=True)

github facebookresearch / CrypTen / examples / bandits / launcher.py
def _run_experiment(args):
    if args.plaintext:
        import plain_contextual_bandits as bandits
    else:
        import private_contextual_bandits as bandits

    learner_func = build_learner(args, bandits, download_mnist)
    import crypten

    crypten.init()
    learner_func()

github facebookresearch / CrypTen / examples / tfe_benchmarks / tfe_benchmarks.py
start_epoch=0,
    batch_size=256,
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-6,
    print_freq=10,
    resume="",
    evaluate=True,
    seed=None,
    skip_plaintext=False,
    save_checkpoint_dir="/tmp/tfe_benchmarks",
    save_modelbest_dir="/tmp/tfe_benchmarks_best",
    context_manager=None,
    mnist_dir=None,
):
    crypten.init()

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # create model
    model = create_benchmark_model(network)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )

    # optionally resume from a checkpoint

github facebookresearch / CrypTen / examples / multiprocess_launcher.py
def _run_process(cls, rank, world_size, env, run_process_fn, fn_args):
        for env_key, env_value in env.items():
            os.environ[env_key] = env_value
        os.environ["RANK"] = str(rank)
        orig_logging_level = logging.getLogger().level
        logging.getLogger().setLevel(logging.INFO)
        crypten.init()
        logging.getLogger().setLevel(orig_logging_level)
        if fn_args is None:
            run_process_fn()
        else:
            run_process_fn(fn_args)
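
The launcher above shows the low-level pattern: the parent process sets RANK and the other environment variables for each child, and every child then calls crypten.init() itself before running the target function. For quick experiments, CrypTen also provides a crypten.mpc.run_multiprocess decorator that spawns the parties and performs this per-process setup for you; a sketch, assuming that decorator and the crypten.print helper behave as in the CrypTen tutorials:

import torch
import crypten
import crypten.mpc as mpc
import crypten.communicator as comm

# Assumption: mpc.run_multiprocess spawns `world_size` processes and calls
# crypten.init() in each of them, much like MultiProcessLauncher above.
@mpc.run_multiprocess(world_size=2)
def secret_share_demo():
    rank = comm.get().get_rank()
    x_enc = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]))
    # get_plain_text() reconstructs the value from the parties' shares.
    crypten.print(f"Rank {rank} sees: {x_enc.get_plain_text()}", in_order=True)

if __name__ == "__main__":
    secret_share_demo()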