How to use the torchgan.losses.loss.GeneratorLoss function in torchgan

To help you get started, we've selected a few torchgan examples based on popular ways it is used in public projects.

github torchgan / torchgan / torchgan / losses / aaeloss.py
import torch
import torch.nn.functional as F
from .loss import GeneratorLoss, DiscriminatorLoss
from ..models import AdversarialAutoEncodingGenerator
from .functional import minimax_generator_loss, minimax_discriminator_loss

__all__ = ['AdversarialAutoEncoderGeneratorLoss', 'AdversarialAutoEncoderDiscriminatorLoss']

class AdversarialAutoEncoderGeneratorLoss(GeneratorLoss):
    def __init__(self, recon_weight=0.999, gen_weight=0.001, reduction='mean', override_train_ops=None):
        super(AdversarialAutoEncoderGeneratorLoss, self).__init__(reduction, override_train_ops)
        self.gen_weight = gen_weight
        self.recon_weight = recon_weight

    def forward(self, real_inputs, gen_outputs, dgz):
        # Weighted sum of the reconstruction (MSE) term and the adversarial
        # (minimax) term, each scaled by its configured weight.
        return self.recon_weight * F.mse_loss(real_inputs, gen_outputs) + \
            self.gen_weight * minimax_generator_loss(dgz, reduction=self.reduction)

    def train_ops(self, generator, discriminator, optimizer_generator, real_inputs,
                  device, batch_size, labels=None):
        if self.override_train_ops is not None:
            return self.override_train_ops(self, generator, discriminator, optimizer_generator,
                                           real_inputs, device, batch_size, labels)
        else:
            if isinstance(generator, AdversarialAutoEncodingGenerator):
                ...  # remainder of train_ops truncated in the original listing
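For orientation, a minimal sketch of calling this loss's forward by hand. The tensors real_batch, reconstructed and dgz are illustrative stand-ins (shapes assumed), not part of the torchgan API:

import torch
from torchgan.losses import AdversarialAutoEncoderGeneratorLoss

loss_fn = AdversarialAutoEncoderGeneratorLoss(recon_weight=0.999, gen_weight=0.001)

real_batch = torch.randn(16, 3, 32, 32)     # a batch of real images (assumed shape)
reconstructed = torch.randn(16, 3, 32, 32)  # the generator's reconstructions
dgz = torch.rand(16, 1)                     # discriminator outputs on the latent codes

# Weighted sum of the MSE reconstruction term and the minimax adversarial term.
total = loss_fn(real_batch, reconstructed, dgz)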
github torchgan / torchgan / torchgan / losses / leastsquares.py
import torch

from .functional import least_squares_discriminator_loss, least_squares_generator_loss
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["LeastSquaresGeneratorLoss", "LeastSquaresDiscriminatorLoss"]


class LeastSquaresGeneratorLoss(GeneratorLoss):
    r"""Least Squares GAN generator loss from `"Least Squares Generative Adversarial Networks
    by Mao et. al." `_ paper

    The loss can be described as

    .. math:: L(G) = \frac{(D(G(z)) - c)^2}{2}

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`c` : target generator label
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
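The formula maps directly to code. A hedged check, where dgz is an assumed tensor of discriminator outputs on generated samples and the target label c defaults to 1:

import torch
from torchgan.losses import LeastSquaresGeneratorLoss

loss_fn = LeastSquaresGeneratorLoss()  # reduction defaults to 'mean'

dgz = torch.rand(16, 1)                # D(G(z)): discriminator scores on fakes
loss = loss_fn(dgz)

# Hand-computed equivalent: (1/2) * mean((D(G(z)) - 1)^2) for target label c = 1.
manual = 0.5 * torch.mean((dgz - 1.0) ** 2)
assert torch.isclose(loss, manual)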
github torchgan / torchgan / torchgan / trainer / base_trainer.py
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
                    if type(cur_loss) is tuple:
                        lgen, ldis, gen_iter, dis_iter = (
                            lgen + cur_loss[0],
                            ldis + cur_loss[1],
                            gen_iter + 1,
                            dis_iter + 1,
                        )
                    else:
                        # NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
                        ldis, dis_iter = ldis + cur_loss, dis_iter + 1
                for model_name in self.model_names:
                    grad_logs.update_grads(model_name, getattr(self, model_name))
            elif isinstance(loss, GeneratorLoss):
                if self.loss_information["discriminator_iters"] % self.ncritic == 0:
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
                    lgen, gen_iter = lgen + cur_loss, gen_iter + 1
                for model_name in self.model_names:
                    model = getattr(self, model_name)
                    if isinstance(model, Generator):
                        grad_logs.update_grads(model_name, model)
            elif isinstance(loss, DiscriminatorLoss):
                if self.loss_information["generator_iters"] % self.ngen == 0:
                    cur_loss = loss.train_ops(
                        **self._get_arguments(self.loss_arg_maps[name])
                    )
                    loss_logs.logs[name].append(cur_loss)
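The dispatch above is driven by the losses list handed to the trainer: GeneratorLoss subclasses update the generator, DiscriminatorLoss subclasses update the discriminator, and ncritic/ngen gate how often each side trains. A hedged configuration sketch; the dict layout follows torchgan's documented convention, but treat the exact constructor arguments as assumptions:

import torch
from torch.optim import Adam
from torchgan.models import DCGANGenerator, DCGANDiscriminator
from torchgan.losses import WassersteinGeneratorLoss, WassersteinDiscriminatorLoss
from torchgan.trainer import Trainer

network = {
    "generator": {"name": DCGANGenerator, "args": {"out_channels": 1},
                  "optimizer": {"name": Adam, "args": {"lr": 1e-4}}},
    "discriminator": {"name": DCGANDiscriminator, "args": {"in_channels": 1},
                      "optimizer": {"name": Adam, "args": {"lr": 1e-4}}},
}
losses = [WassersteinGeneratorLoss(), WassersteinDiscriminatorLoss()]

# ncritic=5 runs five discriminator updates per generator update,
# which is exactly the gating visible in the excerpt above.
trainer = Trainer(network, losses, device=torch.device("cpu"), ncritic=5, epochs=10)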
github torchgan / torchgan / torchgan / losses / wasserstein.py
from .functional import (
    wasserstein_discriminator_loss,
    wasserstein_generator_loss,
    wasserstein_gradient_penalty,
)
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = [
    "WassersteinGeneratorLoss",
    "WassersteinDiscriminatorLoss",
    "WassersteinGradientPenalty",
]


class WassersteinGeneratorLoss(GeneratorLoss):
    r"""Wasserstein GAN generator loss from
    `"Wasserstein GAN by Arjovsky et. al." `_ paper

    The loss can be described as:

    .. math:: L(G) = -f(G(z))

    where

    - :math:`G` : Generator
    - :math:`f` : Critic/Discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the output is taken.
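A quick hedged check of the formula, where fgz is an assumed tensor of critic outputs on generated samples:

import torch
from torchgan.losses import WassersteinGeneratorLoss

loss_fn = WassersteinGeneratorLoss()

fgz = torch.randn(16, 1)   # f(G(z)): critic scores on fakes
loss = loss_fn(fgz)        # -mean(f(G(z))) under the default 'mean' reduction
assert torch.isclose(loss, -fgz.mean())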
github torchgan / torchgan / torchgan / losses / historical.py
import torch

from ..utils import reduce
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["HistoricalAverageGeneratorLoss", "HistoricalAverageDiscriminatorLoss"]


class HistoricalAverageGeneratorLoss(GeneratorLoss):
    r"""Historical Average Generator Loss from
    `"Improved Techniques for Training GANs
    by Salimans et al." <https://arxiv.org/abs/1606.03498>`_ paper

    The loss can be described as

    .. math:: || \theta - \frac{1}{t} \sum_{i=1}^t \theta[i] ||^2

    where

    - :math:`G` : Generator
    - :math:`\theta[i]` : Generator parameters at past timestep :math:`i`

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
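To make the formula concrete, a hand-rolled sketch of the penalty over stored parameter snapshots; the bookkeeping here is illustrative, not torchgan's internal implementation:

import torch

def historical_average_penalty(params, history):
    """|| theta - (1/t) * sum_i theta[i] ||^2 over flattened parameters."""
    theta = torch.cat([p.flatten() for p in params])
    running_mean = torch.stack(history).mean(dim=0)  # (1/t) * sum_i theta[i]
    return torch.sum((theta - running_mean) ** 2)

# history would hold one flattened parameter vector per past timestep,
# e.g. history.append(torch.cat([p.detach().flatten() for p in params])).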
github torchgan / torchgan / torchgan / losses / energybased.py
from ..models import AutoEncodingDiscriminator
from .functional import (
    energy_based_discriminator_loss,
    energy_based_generator_loss,
    energy_based_pulling_away_term,
)
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = [
    "EnergyBasedGeneratorLoss",
    "EnergyBasedDiscriminatorLoss",
    "EnergyBasedPullingAwayTerm",
]


class EnergyBasedGeneratorLoss(GeneratorLoss):
    r"""Energy Based GAN generator loss from `"Energy Based Generative Adversarial Network
    by Zhao et. al." `_ paper.

    The loss can be described as:

    .. math:: L(G) = D(G(z))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
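Since the generator loss is just the energy the discriminator assigns to generated samples, usage reduces to the following (dgz is an assumed tensor of those energies):

import torch
from torchgan.losses import EnergyBasedGeneratorLoss

loss_fn = EnergyBasedGeneratorLoss()

dgz = torch.rand(16, 1)   # D(G(z)): per-sample energies from the autoencoding discriminator
loss = loss_fn(dgz)       # mean energy under the default 'mean' reduction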
github torchgan / torchgan / torchgan / losses / minimax.py
import torch

from .functional import minimax_discriminator_loss, minimax_generator_loss
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["MinimaxGeneratorLoss", "MinimaxDiscriminatorLoss"]


class MinimaxGeneratorLoss(GeneratorLoss):
    r"""Minimax game generator loss from the original GAN paper `"Generative Adversarial Networks
    by Goodfellow et al." <https://arxiv.org/abs/1406.2661>`_

    The loss can be described as:

    .. math:: L(G) = \log(1 - D(G(z)))

    The nonsaturating heuristic is also supported:

    .. math:: L(G) = -\log(D(G(z)))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`z` : A sample from the noise prior
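A hedged usage sketch; dgz is an assumed tensor of raw discriminator logits on generated samples:

import torch
from torchgan.losses import MinimaxGeneratorLoss

# nonsaturating=True selects the -log(D(G(z))) heuristic, which gives
# stronger gradients early in training than the saturating form.
loss_fn = MinimaxGeneratorLoss(nonsaturating=True)

dgz = torch.randn(16, 1)
loss = loss_fn(dgz)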
github torchgan / torchgan / torchgan / losses / energybased.py
            if isinstance(discriminator, AutoEncodingDiscriminator):
                setattr(discriminator, "embeddings", False)
            loss = super(EnergyBasedGeneratorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )
            if isinstance(discriminator, AutoEncodingDiscriminator):
                setattr(discriminator, "embeddings", True)
            return loss


class EnergyBasedPullingAwayTerm(GeneratorLoss):
    r"""Energy Based Pulling Away Term from `"Energy Based Generative Adversarial Network
    by Zhao et. al." `_ paper.

    The loss can be described as:

    .. math:: f_{PT}(S) = \frac{1}{N(N-1)}\sum_i\sum_{j \neq i}\bigg(\frac{S_i^T S_j}{||S_i||\ ||S_j||}\bigg)^2

    where

    - :math:`S` : The feature output from the encoder for generated images
    - :math:`N` : Batch Size of the Input

    Args:
        pt_ratio (float, optional): The weight given to the pulling away term.
        override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
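The pulling-away formula translates directly to a few lines of PyTorch. A hedged reference sketch, not torchgan's internal code:

import torch

def pulling_away_term(S):
    """f_PT(S): mean squared cosine similarity between distinct rows of S."""
    n = S.size(0)
    S = S / S.norm(dim=1, keepdim=True)                # unit-normalize each feature row
    cos_sq = (S @ S.t()) ** 2                          # squared pairwise cosine similarities
    off_diag = cos_sq - torch.eye(n, device=S.device)  # zero out the i == j terms
    return off_diag.sum() / (n * (n - 1))

features = torch.randn(16, 128)   # assumed encoder features for generated images
pt = pulling_away_term(features)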
github torchgan / torchgan / torchgan / losses / featurematching.py
import torch
import torch.nn.functional as F

from ..utils import reduce
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["FeatureMatchingGeneratorLoss"]


class FeatureMatchingGeneratorLoss(GeneratorLoss):
    r"""Feature Matching Generator loss from
    `"Improved Training of GANs by Salimans et. al." `_ paper

    The loss can be described as:

    .. math:: L(G) = ||f(x)-f(G(z))||_2

    where

    - :math:`G` : Generator
    - :math:`f` : An intermediate activation from the discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
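In spirit, the loss matches statistics of an intermediate discriminator layer on real and generated batches. A hedged sketch with assumed feature tensors f_real and f_fake:

import torch

f_real = torch.randn(16, 256)   # intermediate activations on a real batch
f_fake = torch.randn(16, 256)   # intermediate activations on a generated batch

# L2 distance between the batch-mean feature vectors.
loss = torch.norm(f_real.mean(dim=0) - f_fake.mean(dim=0), p=2)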
github torchgan / torchgan / torchgan / losses / boundaryequilibrium.py
import torch

from .functional import (
    boundary_equilibrium_discriminator_loss,
    boundary_equilibrium_generator_loss,
)
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["BoundaryEquilibriumGeneratorLoss", "BoundaryEquilibriumDiscriminatorLoss"]


class BoundaryEquilibriumGeneratorLoss(GeneratorLoss):
    r"""Boundary Equilibrium GAN generator loss from
    `"BEGAN : Boundary Equilibrium Generative Adversarial Networks
    by Berthelot et al." <https://arxiv.org/abs/1703.10717>`_ paper

    The loss can be described as

    .. math:: L(G) = D(G(z))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
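As with the energy-based loss, the generator term is just the discriminator's autoencoder reconstruction loss on generated samples. A hedged usage sketch, with dgz an assumed tensor of those per-sample losses:

import torch
from torchgan.losses import BoundaryEquilibriumGeneratorLoss

loss_fn = BoundaryEquilibriumGeneratorLoss()

dgz = torch.rand(16, 1)   # D(G(z)): per-sample reconstruction losses on fakes
loss = loss_fn(dgz)       # mean under the default 'mean' reduction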