    def __call__(self, trainer, **kwargs):
        # Average the accumulated losses over the number of updates performed so far.
        running_generator_loss = (
            trainer.loss_information["generator_losses"]
            / trainer.loss_information["generator_iters"]
        )
        running_discriminator_loss = (
            trainer.loss_information["discriminator_losses"]
            / trainer.loss_information["discriminator_iters"]
        )
        running_losses = {
            "Running Discriminator Loss": running_discriminator_loss,
            "Running Generator Loss": running_generator_loss,
        }
        super(LossVisualize, self).__call__(running_losses, **kwargs)
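# Illustrative sketch: the running losses computed above are just per-iteration averages of
# the sums the trainer accumulates in ``loss_information``. The dict below is a hypothetical
# stand-in for ``trainer.loss_information``, purely to show the arithmetic.
loss_information = {
    "generator_losses": 12.6,
    "generator_iters": 42,
    "discriminator_losses": 25.2,
    "discriminator_iters": 84,
}
running_losses = {
    "Running Discriminator Loss": loss_information["discriminator_losses"]
    / loss_information["discriminator_iters"],
    "Running Generator Loss": loss_information["generator_losses"]
    / loss_information["generator_iters"],
}
print(running_losses)  # both averages come out to 0.3 for these made-up numbers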
class MetricVisualize(Visualize):
r"""This class provides the Visualizations for Metrics.
Args:
visualize_list (list, optional): List of the functions needed for visualization.
visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
manually started at this port else an error will be thrown and the code will crash.
This is ignored if ``VISDOM_LOGGING`` is ``0``.
log_dir (str, optional): Directory where TensorboardX should store the logs. This is
ignored if ``TENSORBOARD_LOGGING`` is ``0``.
writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
don't want to start a new SummaryWriter.
"""
    def log_tensorboard(self):
        r"""Tensorboard logging function. This function logs the values of the individual metrics.
        """
    def __call__(
        self,
        *args,
        lock_console=False,
        lock_tensorboard=False,
        lock_visdom=False,
        **kwargs
    ):
        if not lock_console and CONSOLE_LOGGING == 1:
            self.log_console(*args, **kwargs)
        if not lock_tensorboard and TENSORBOARD_LOGGING == 1:
            self.log_tensorboard(*args, **kwargs)
        if not lock_visdom and VISDOM_LOGGING == 1:
            self.log_visdom(*args, **kwargs)
        self.step_update()
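# Usage sketch for the ``lock_*`` flags, assuming ``vis`` is an already constructed
# visualizer (e.g. a ``MetricVisualize`` instance). The flags suppress a backend for a
# single call, on top of the global CONSOLE_LOGGING / TENSORBOARD_LOGGING / VISDOM_LOGGING
# switches:
#
#     vis(lock_visdom=True)                            # console + tensorboard only
#     vis(lock_console=True, lock_tensorboard=True)    # visdom only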
class LossVisualize(Visualize):
r"""This class provides the Visualizations for Generator and Discriminator Losses.
Args:
visualize_list (list, optional): List of the functions needed for visualization.
visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
manually started at this port else an error will be thrown and the code will crash.
This is ignored if ``VISDOM_LOGGING`` is ``0``.
log_dir (str, optional): Directory where TensorboardX should store the logs. This is
ignored if ``TENSORBOARD_LOGGING`` is ``0``.
writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
don't want to start a new SummaryWriter.
"""
    def log_tensorboard(self, running_losses):
        r"""Tensorboard logging function. This function logs the following:

        - ``Running Discriminator Loss``
        - ``Running Generator Loss``
        """
model.zero_grad()
    def report_end_epoch(self):
        r"""Prints to the console at the end of the epoch.
        """
        if CONSOLE_LOGGING == 1:
            for key, val in self.logs.items():
                print("{} Mean Gradients : {}".format(key, sum(val) / len(val)))
    def __call__(self, trainer, **kwargs):
        for name in trainer.model_names:
            super(GradientVisualize, self).__call__(name, **kwargs)
            self.logs[name].append(0.0)
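# Illustrative sketch of how the per-model gradient logs are reported: ``logs`` maps each
# model name to one value per training step (the real values come from the trainer), and at
# the end of an epoch the mean of each list is printed, as in ``report_end_epoch`` above.
# The numbers below are made up.
logs = {"generator": [0.75, 0.5, 0.25], "discriminator": [1.5, 1.0, 0.5]}
for key, val in logs.items():
    print("{} Mean Gradients : {}".format(key, sum(val) / len(val)))
# generator Mean Gradients : 0.5
# discriminator Mean Gradients : 1.0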
class ImageVisualize(Visualize):
r"""This class provides the Logging for the Images.
Args:
trainer (torchgan.trainer.Trainer): The base trainer used for training.
visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
manually started at this port else an error will be thrown and the code will crash.
This is ignored if ``VISDOM_LOGGING`` is ``0``.
log_dir (str, optional): Directory where TensorboardX should store the logs. This is
ignored if ``TENSORBOARD_LOGGING`` is ``0``.
writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
don't want to start a new SummaryWriter.
test_noise (torch.Tensor, optional): If provided then it will be used as the noise for image
sampling.
nrow (int, optional): Number of rows in which the image is to be stored.
"""
print("{} : {}".format(name, val[-1]))
    def log_visdom(self):
        r"""Visdom logging function. This function logs the values of the individual metrics.
        """
        for name, value in self.logs.items():
            self.vis.line(
                [value[-1]],
                [self.step],
                win=name,
                update="append",
                opts=dict(title=name, xlabel="Time Step", ylabel="Metric Value"),
            )
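# The same append-style plotting done directly with the ``visdom`` client. This sketch
# assumes a visdom server is already running (e.g. started with ``python -m visdom.server``)
# on the default port 8097; the metric name and values are made up.
#
#     import visdom
#
#     vis = visdom.Visdom(port=8097)
#     for step, value in enumerate([0.42, 0.37, 0.31]):
#         vis.line(
#             [value],
#             [step],
#             win="Some Metric",
#             update="append",
#             opts=dict(title="Some Metric", xlabel="Time Step", ylabel="Metric Value"),
#         )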
class GradientVisualize(Visualize):
r"""This class provides the Visualizations for the Gradients.
Args:
visualize_list (list, optional): List of the functions needed for visualization.
visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
manually started at this port else an error will be thrown and the code will crash.
This is ignored if ``VISDOM_LOGGING`` is ``0``.
log_dir (str, optional): Directory where TensorboardX should store the logs. This is
ignored if ``TENSORBOARD_LOGGING`` is ``0``.
writer (tensorboardX.SummaryWriter, optonal): Send a `SummaryWriter` if you
don't want to start a new SummaryWriter.
"""
    def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None):
        if visualize_list is None or len(visualize_list) == 0:
            raise Exception("Gradient Visualizer requires list of model names")
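# Hypothetical construction, per the check above: the first argument must be a non-empty
# list of model names. That ``GradientVisualize`` is importable from ``torchgan.logging``
# is an assumption.
#
#     from torchgan.logging import GradientVisualize
#
#     grad_vis = GradientVisualize(["generator", "discriminator"], log_dir="./logs")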