How to use torchgan - 10 common examples

To help you get started, we’ve selected a few torchgan examples, based on popular ways it is used in public projects.

torchgan/torchgan: tests/torchgan/test_trainer.py (view on GitHub)
"discriminator": {
                "name": ACGANDiscriminator,
                "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [
            MinimaxGeneratorLoss(),
            MinimaxDiscriminatorLoss(),
            AuxiliaryClassifierGeneratorLoss(),
            AuxiliaryClassifierDiscriminatorLoss(),
        ]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
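
The mnist_dataloader() helper called by these tests is not shown in the excerpts. A minimal stand-in, assuming the standard torchvision MNIST dataset (the helper name and its defaults here are hypothetical), could look like this:

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

def mnist_dataloader(batch_size=128):
    # Scale images to [-1, 1], the usual input range for GAN training
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    dataset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    # Yields (image, label) batches; the labels are what the
    # ACGAN-style losses above consume
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)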
torchgan/torchgan: tests/torchgan/test_trainer.py (view on GitHub)
"optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
            "discriminator": {
                "name": ConditionalGANDiscriminator,
                "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
torchgan/torchgan: tests/torchgan/test_trainer.py (view on GitHub)
"optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
            "discriminator": {
                "name": DCGANDiscriminator,
                "args": {"in_channels": 1, "step_channels": 4},
                "optimizer": {
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        }
        losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
        trainer = Trainer(
            network_params,
            losses_list,
            sample_size=1,
            epochs=1,
            device=torch.device("cpu"),
        )
        trainer(mnist_dataloader())
torchgan/torchgan: torchgan/losses/energybased.py (view on GitHub)
            # (excerpt begins mid-train_ops, after the early return through
            # override_train_ops when one is supplied)
            if generator.label_type != "none":
                raise Exception("EBGAN PT supports models which do not require labels")
            if not discriminator.embeddings:
                raise Exception("EBGAN PT requires the embeddings for loss computation")
            noise = torch.randn(batch_size, generator.encoding_dims, device=device)
            optimizer_generator.zero_grad()
            fake = generator(noise)
            d_hid, dgz = discriminator(fake)
            loss = self.forward(dgz, d_hid)
            loss.backward()
            optimizer_generator.step()
            return loss.item()


class EnergyBasedDiscriminatorLoss(DiscriminatorLoss):
    r"""Energy Based GAN generator loss from `"Energy Based Generative Adversarial Network
    by Zhao et. al." `_ paper

    The loss can be described as:

    .. math:: L(D) = D(x) + \max(0, m - D(G(z)))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`m` : Margin Hyperparameter
    - :math:`z` : A sample from the noise prior

    .. note::
        The convergence of EBGAN is highly sensitive to hyperparameters. The ``margin``
        hyperparameter in particular usually needs careful tuning.
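
For reference, the formula in the docstring above is only a couple of lines of PyTorch. This is a minimal functional sketch, not torchgan's own implementation; dx and dgz are assumed to be the discriminator's energies for a real and a generated batch:

import torch

def energy_based_discriminator_loss(dx, dgz, margin=1.0):
    # L(D) = D(x) + max(0, m - D(G(z))), averaged over the batch
    return torch.mean(dx + torch.clamp(margin - dgz, min=0))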
torchgan/torchgan: torchgan/losses/aaeloss.py (view on GitHub)
    def train_ops(self, generator, discriminator, optimizer_generator, real_inputs,
                  device, batch_size, labels=None):
        if self.override_train_ops is not None:
            return self.override_train_ops(self, generator, discriminator, optimizer_generator,
                                           real_inputs, device, labels)
        else:
            if isinstance(generator, AdversarialAutoEncodingGenerator):
                setattr(generator, "embeddings", False)
            recon, encodings = generator(real_inputs)
            optimizer_generator.zero_grad()
            dgz = discriminator(encodings)
            loss = self.forward(real_inputs, recon, dgz)
            loss.backward()
            optimizer_generator.step()
            return loss.item()

class AdversarialAutoEncoderDiscriminatorLoss(DiscriminatorLoss):
    def forward(self, dx, dgz):
        return minimax_discriminator_loss(dx, dgz)

    def train_ops(self, generator, discriminator, optimizer_discriminator, real_inputs,
                  device, batch_size, labels=None):
        if self.override_train_ops is not None:
            return self.override_train_ops(self, generator, discriminator, optimizer_discriminator,
                                           real_inputs, device, labels)
        else:
            if isinstance(generator, AdversarialAutoEncodingGenerator):
                setattr(generator, "embeddings", True)
            encodings = generator(real_inputs).detach()
            noise = torch.randn(real_inputs.size(0), generator.encoding_dims, device=device)
            optimizer_discriminator.zero_grad()
            dx = discriminator(noise)
            dgz = discriminator(encodings)
            # (excerpt truncated here: the loss computation, backward pass and
            # optimizer step follow in the full source)
torchgan/torchgan: torchgan/losses/minimax.py (view on GitHub)
    def forward(self, dgz):
        r"""Computes the loss for the given input.

        Args:
            dgz (torch.Tensor) : Output of the Discriminator with generated data. It must have the
                                 dimensions (N, \*) where \* means any number of additional
                                 dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return minimax_generator_loss(dgz, self.nonsaturating, self.reduction)


class MinimaxDiscriminatorLoss(DiscriminatorLoss):
    r"""Minimax game discriminator loss from the original GAN paper `"Generative Adversarial Networks
    by Goodfellow et. al." `_

    The loss can be described as:

    .. math:: L(D) = -[\log(D(x)) + \log(1 - D(G(z)))]

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`x` : A sample from the data distribution
    - :math:`z` : A sample from the noise prior

    Args:
        label_smoothing (float, optional): The factor by which the labels (1 in this case)
            need to be smoothed; e.g. a smoothing of 0.1 turns the real-sample target
            1.0 into 0.9.
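
The minimax discriminator loss is binary cross-entropy with target 1 for real samples and 0 for generated ones. A minimal sketch of the functional helper this class delegates to (and which the AAE loss above also imports), assuming the discriminator returns raw logits; torchgan's actual implementation may differ in details:

import torch
import torch.nn.functional as F

def minimax_discriminator_loss(dx, dgz, label_smoothing=0.0):
    # Real samples target 1 (optionally smoothed), generated samples target 0
    target_real = torch.ones_like(dx) * (1.0 - label_smoothing)
    target_fake = torch.zeros_like(dgz)
    return F.binary_cross_entropy_with_logits(
        dx, target_real
    ) + F.binary_cross_entropy_with_logits(dgz, target_fake)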
torchgan/torchgan: torchgan/trainer/base_trainer.py (view on GitHub)
        ... any custom per-iteration behavior is instead meant to go in
        ``train_iter_custom``.

        .. warning::
            This function is needed in this exact state for the Trainer to work correctly. So it is
            highly recommended that this function is not changed even if the ``Trainer`` is subclassed.

        Returns:
            A 4-tuple of the ``generator loss``, ``discriminator loss``, ``number of times the
            generator was trained`` and the ``number of times the discriminator was trained``.
        """
        self.train_iter_custom()
        ldis, lgen, dis_iter, gen_iter = 0.0, 0.0, 0, 0
        loss_logs = self.logger.get_loss_viz()
        grad_logs = self.logger.get_grad_viz()
        for name, loss in self.losses.items():
            if isinstance(loss, GeneratorLoss) and isinstance(loss, DiscriminatorLoss):
                # NOTE(avik-pal): In most cases this loss is meant to optimize the Discriminator
                #                 but we might need to think of a better solution
                if self.loss_information["generator_iters"] % self.ngen == 0:
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
                    if type(cur_loss) is tuple:
                        lgen, ldis, gen_iter, dis_iter = (
                            lgen + cur_loss[0],
                            ldis + cur_loss[1],
                            gen_iter + 1,
                            dis_iter + 1,
                        )
                    else:
                        # NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
torchgan/torchgan: torchgan/losses/mutualinfo.py (view on GitHub)
import torch

from .functional import mutual_information_penalty
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["MutualInformationPenalty"]


class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss):
    r"""Mutual Information Penalty as defined in
    `"InfoGAN : Interpretable Representation Learning by Information Maximising Generative Adversarial Nets
    by Chen et. al." `_ paper

    The loss is the variational lower bound of the mutual information between
    the latent codes and the generator distribution and is defined as

    .. math:: L(G,Q) = \log Q(c|x)

    where

    - :math:`x` is drawn from the generator distribution :math:`G(z,c)`
    - :math:`c` is drawn from the latent code prior :math:`P(c)`

    Args:
        lambd (float, optional): The scaling factor for the loss.
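
For categorical latent codes, the penalty reduces to a cross-entropy between the code fed to the generator and the recognition head's prediction. This is a minimal sketch under that assumption, not torchgan's exact implementation (the library's version may also handle continuous codes):

import torch.nn.functional as F

def mutual_information_penalty(c_true, q_logits, lambd=1.0):
    # -log Q(c|x): how well the Q head recovers the sampled code c
    # (a LongTensor of class indices) from the generated sample x
    return lambd * F.cross_entropy(q_logits, c_true)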
torchgan/torchgan: torchgan/losses/auxclassifier.py (view on GitHub)
            # (excerpt begins mid-train_ops: the noise sampling and the
            # label_type == "required" branch are cut off above)
            label_gen = torch.randint(
                0, generator.num_classes, (batch_size,), device=device
            )
            fake = generator(noise, label_gen)
        cgz = discriminator(fake, mode="classifier")
        if generator.label_type == "required":
            loss = self.forward(cgz, labels)
        else:
            label_gen = label_gen.type(torch.LongTensor).to(device)
            loss = self.forward(cgz, label_gen)
        loss.backward()
        optimizer_generator.step()
        return loss.item()


class AuxiliaryClassifierDiscriminatorLoss(DiscriminatorLoss):
    r"""Auxiliary Classifier GAN (ACGAN) loss based on a from
    `"Conditional Image Synthesis With Auxiliary Classifier GANs
    by Odena et. al. " `_ paper

    Args:
       reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
            If ``sum`` the elements of the output are summed.
       override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def forward(self, logits, labels):
        return auxiliary_classification_loss(logits, labels, self.reduction)

    def train_ops(
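
The excerpt cuts off at train_ops, but the classifier-head loss used by both the generator and discriminator train ops is plain cross-entropy over the class logits. A minimal sketch of what a functional auxiliary_classification_loss amounts to (an assumption about its form, not the library's exact code):

import torch.nn.functional as F

def auxiliary_classification_loss(logits, labels, reduction="mean"):
    # Cross-entropy between the discriminator's class logits and the
    # real or sampled class labels; reduction is "none", "mean" or "sum"
    return F.cross_entropy(logits, labels, reduction=reduction)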
torchgan/torchgan: torchgan/trainer/base_trainer.py (view on GitHub)
                        # NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
                        ldis, dis_iter = ldis + cur_loss, dis_iter + 1
                for model_name in self.model_names:
                    grad_logs.update_grads(model_name, getattr(self, model_name))
            elif isinstance(loss, GeneratorLoss):
                if self.loss_information["discriminator_iters"] % self.ncritic == 0:
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
                    lgen, gen_iter = lgen + cur_loss, gen_iter + 1
                for model_name in self.model_names:
                    model = getattr(self, model_name)
                    if isinstance(model, Generator):
                        grad_logs.update_grads(model_name, model)
            elif isinstance(loss, DiscriminatorLoss):
                if self.loss_information["generator_iters"] % self.ngen == 0:
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
                    ldis, dis_iter = ldis + cur_loss, dis_iter + 1
                for model_name in self.model_names:
                    model = getattr(self, model_name)
                    if isinstance(model, Discriminator):
                        grad_logs.update_grads(model_name, model)
        return lgen, ldis, gen_iter, dis_iter
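
The ngen and ncritic counters in this loop implement alternating update schedules: discriminator losses only run every ngen generator steps, and generator losses every ncritic discriminator steps. Assuming the Trainer exposes these as constructor arguments (as the attribute accesses above suggest), a WGAN-style schedule of five critic updates per generator update would look like this, reusing network_params, losses_list and mnist_dataloader from the earlier excerpts:

trainer = Trainer(
    network_params,
    losses_list,
    ncritic=5,  # train the discriminator 5 times per generator step
    epochs=20,
    device=torch.device("cpu"),
)
trainer(mnist_dataloader())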