How to use the torchgan.trainer.Trainer class in torchgan

To help you get started, we’ve selected a few torchgan examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: torchgan / torchgan / tests / torchgan / test_trainer.py (view on GitHub, external link)
"discriminator": {
                "name": ACGANDiscriminator,
                "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [
            MinimaxGeneratorLoss(),
            MinimaxDiscriminatorLoss(),
            AuxiliaryClassifierGeneratorLoss(),
            AuxiliaryClassifierDiscriminatorLoss(),
        ]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
Source: torchgan / torchgan / tests / torchgan / test_trainer.py (view on GitHub, external link)
"optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
            "discriminator": {
                "name": ConditionalGANDiscriminator,
                "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
Source: torchgan / torchgan / tests / torchgan / test_trainer.py (view on GitHub, external link)
"optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
            "discriminator": {
                "name": DCGANDiscriminator,
                "args": {"in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
Source: torchgan / model-zoo / gman / gman.py (view on GitHub, external link)
trainer = ParallelTrainer(
            network_configuration,
            losses,
            args.list_gpus,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )
    else:
        if args.cpu == 1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:0")
        trainer = Trainer(
            network_configuration,
            losses,
            device=device,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )

    train_dataset = dataset(
        root=args.data_dir,
        train=True,
        download=True,
        transform=transformations)
Source: torchgan / model-zoo / binarygan / binarygan.py (view on GitHub, external link)
trainer = ParallelTrainer(
            network_config,
            losses_list,
            args.list_gpus,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )
    else:
        if args.cpu == 1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:0")
        trainer = Trainer(
            network_config,
            losses_list,
            device=device,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )

    # Transforms to get Binarized MNIST
    dataset = dsets.MNIST(
        root=args.data_dir,
        train=True,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),