def forward(self, c_dis, c_cont, dist_dis, dist_cont):
    r"""Computes the loss for the given input.

    Args:
        c_dis (torch.Tensor): The discrete latent code sampled from the prior.
        c_cont (torch.Tensor): The continuous latent code sampled from the prior.
        dist_dis (torch.distributions.Distribution): The auxiliary distribution :math:`Q(c|x)` over the
            discrete latent code output by the discriminator.
        dist_cont (torch.distributions.Distribution): The auxiliary distribution :math:`Q(c|x)` over the
            continuous latent code output by the discriminator.

    Returns:
        A scalar if reduction is applied, else a Tensor with dimensions (N, \*).
    """
    # Delegate to the functional form, passing along the reduction mode stored on the module.
    return mutual_information_penalty(
        c_dis, c_cont, dist_dis, dist_cont, reduction=self.reduction
    )
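
The snippet above does not show the body of `mutual_information_penalty`, so the following is only a minimal sketch of what such a penalty might compute. It assumes the penalty is the negative log-likelihood of the sampled codes under the auxiliary distributions. The function name `mi_penalty_sketch`, the `"none"` reduction value, and the toy shapes are hypothetical and chosen for illustration; this is not the library's implementation.

```python
import torch
from torch.distributions import Categorical, Normal


def mi_penalty_sketch(c_dis, c_cont, dist_dis, dist_cont, reduction="mean"):
    """Hypothetical stand-in for mutual_information_penalty (an assumption, not library code)."""
    # Negative log-likelihood of each sampled code under its auxiliary distribution Q(c|x).
    nll_dis = -dist_dis.log_prob(c_dis)                 # shape (N,)
    nll_cont = -dist_cont.log_prob(c_cont).sum(dim=-1)  # sum over code dimensions, shape (N,)
    loss = nll_dis + nll_cont
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # unreduced, shape (N,)


# Toy batch of N = 4 samples: a 10-way discrete code and a 2-dimensional continuous code.
torch.manual_seed(0)
c_dis = torch.randint(0, 10, (4,))
c_cont = torch.randn(4, 2)
dist_dis = Categorical(logits=torch.zeros(4, 10))        # uniform Q(c|x), for illustration only
dist_cont = Normal(torch.zeros(4, 2), torch.ones(4, 2))  # unit Gaussian Q(c|x), for illustration only

loss = mi_penalty_sketch(c_dis, c_cont, dist_dis, dist_cont)
per_sample = mi_penalty_sketch(c_dis, c_cont, dist_dis, dist_cont, reduction="none")
```

In a real model, `dist_dis` and `dist_cont` would be built from the discriminator's outputs rather than fixed parameters, so minimizing this quantity pushes the discriminator to recover the codes that were fed to the generator.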