How to use the `torchgan.logging.visualize.Visualize` class in torchgan

To help you get started, we’ve selected a few torchgan examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: torchgan/torchgan on GitHub — torchgan/logging/visualize.py (view on GitHub)
running_generator_loss = (
            trainer.loss_information["generator_losses"]
            / trainer.loss_information["generator_iters"]
        )
        running_discriminator_loss = (
            trainer.loss_information["discriminator_losses"]
            / trainer.loss_information["discriminator_iters"]
        )
        running_losses = {
            "Running Discriminator Loss": running_discriminator_loss,
            "Running Generator Loss": running_generator_loss,
        }
        super(LossVisualize, self).__call__(running_losses, **kwargs)


class MetricVisualize(Visualize):
    r"""This class provides the Visualizations for Metrics.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
            don't want to start a new SummaryWriter.
    """

    def log_tensorboard(self):
        r"""Tensorboard logging function. This function logs the values of the individual metrics.
        """
Source: torchgan/torchgan on GitHub — torchgan/logging/visualize.py (view on GitHub)
*args,
        lock_console=False,
        lock_tensorboard=False,
        lock_visdom=False,
        **kwargs
    ):
        if not lock_console and CONSOLE_LOGGING == 1:
            self.log_console(*args, **kwargs)
        if not lock_tensorboard and TENSORBOARD_LOGGING == 1:
            self.log_tensorboard(*args, **kwargs)
        if not lock_visdom and VISDOM_LOGGING == 1:
            self.log_visdom(*args, **kwargs)
        self.step_update()


class LossVisualize(Visualize):
    r"""This class provides the Visualizations for Generator and Discriminator Losses.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
            don't want to start a new SummaryWriter.
    """

    def log_tensorboard(self, running_losses):
        r"""Tensorboard logging function. This function logs the following:
Source: torchgan/torchgan on GitHub — torchgan/logging/visualize.py (view on GitHub)
model.zero_grad()

    def report_end_epoch(self):
        r"""Print the mean of every logged gradient history at the end of an epoch.

        Does nothing unless ``CONSOLE_LOGGING`` is ``1``. Otherwise, for each
        ``(name, history)`` pair in ``self.logs`` one line of the form
        ``"<name> Mean Gradients : <mean>"`` is printed, where ``<mean>`` is
        ``sum(history) / len(history)``.

        Raises:
            ZeroDivisionError: If any history list in ``self.logs`` is empty.
        """
        # Guard clause: console reporting is globally switched off.
        if CONSOLE_LOGGING != 1:
            return
        for model_name, history in self.logs.items():
            mean_gradient = sum(history) / len(history)
            print("{} Mean Gradients : {}".format(model_name, mean_gradient))

    def __call__(self, trainer, **kwargs):
        r"""Run the parent logging dispatch once per model tracked by ``trainer``.

        For every name in ``trainer.model_names`` the parent class's
        ``__call__`` is invoked with that name (forwarding ``kwargs``), and then
        a new ``0.0`` entry is appended to ``self.logs[name]``.

        Args:
            trainer (torchgan.trainer.Trainer): The trainer whose ``model_names``
                select which entries of ``self.logs`` are logged.
        """
        parent_call = super(GradientVisualize, self).__call__
        for model_name in trainer.model_names:
            parent_call(model_name, **kwargs)
            # Open a fresh 0.0 slot for this model; presumably later gradient
            # updates accumulate into it (NOTE(review): confirm against update code).
            self.logs[model_name].append(0.0)


class ImageVisualize(Visualize):
    r"""This class provides the Logging for the Images.

    Args:
        trainer (torchgan.trainer.Trainer): The base trainer used for training.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
            don't want to start a new SummaryWriter.
        test_noise (torch.Tensor, optional): If provided then it will be used as the noise for image
            sampling.
        nrow (int, optional): Number of rows in which the image is to be stored.
    """
Source: torchgan/torchgan on GitHub — torchgan/logging/visualize.py (view on GitHub)
print("{} : {}".format(name, val[-1]))

    def log_visdom(self):
        r"""Visdom logging function. Plots the newest value of every logged metric.

        Each key of ``self.logs`` gets its own visdom window (``win`` is the key).
        The last value recorded for that key is appended to the window's line
        plot at x-coordinate ``self.step``.
        """
        for metric_name in self.logs:
            latest_value = self.logs[metric_name][-1]
            plot_opts = {
                "title": metric_name,
                "xlabel": "Time Step",
                "ylabel": "Metric Value",
            }
            self.vis.line(
                [latest_value],
                [self.step],
                win=metric_name,
                update="append",
                opts=plot_opts,
            )


class GradientVisualize(Visualize):
    r"""This class provides the Visualizations for the Gradients.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
            don't want to start a new SummaryWriter.
    """

    def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None):
        if visualize_list is None or len(visualize_list) == 0:
            raise Exception("Gradient Visualizer requires list of model names")