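# Shared setup the snippets below assume. The import paths follow NeMo 0.x
# conventions and should be treated as assumptions:
import os
from functools import partial

import nemo
from nemo.utils.lr_policies import CosineAnnealing, get_lr_policy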
# ----- BERT pre-training example -----
# callback which prints training loss once in a while
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[train_loss],
    print_func=lambda x: print("Loss: {:.3f}".format(x[0].item())),
    get_tb_values=lambda x: [["loss", x[0]]],
    tb_writer=nf.tb_writer)

# callback which computes evaluation loss on the dev set
eval_callback = nemo.core.EvaluatorCallback(
    eval_tensors=eval_tensors,
    user_iter_callback=eval_iter_callback,
    user_epochs_done_callback=eval_epochs_done_callback,
    eval_step=steps_per_epoch,
    tb_writer=nf.tb_writer)

# callback which saves checkpoints once in a while
ckpt_callback = nemo.core.CheckpointCallback(
    folder=nf.checkpoint_dir,
    epoch_freq=args.save_epoch_freq,
    step_freq=args.save_step_freq)
# define learning rate decay policy
lr_policy_fn = get_lr_policy(args.lr_policy,
                             total_steps=args.num_epochs * steps_per_epoch,
                             warmup_ratio=args.lr_warmup_proportion)

# save the BERT config next to the checkpoints so the model can be restored
config_path = f'{nf.checkpoint_dir}/bert-config.json'
if not os.path.exists(config_path):
    bert_model.config.to_json_file(config_path)
# define and launch training algorithm (optimizer)
nf.train(tensors_to_optimize=[train_loss],
         lr_policy=lr_policy_fn,
         callbacks=[train_callback, eval_callback, ckpt_callback])
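# For reference, EvaluatorCallback drives two user hooks: one invoked per
# evaluation batch to accumulate values, and one invoked when evaluation
# finishes to reduce them. The pair below is a minimal, hypothetical sketch;
# the dict keys and the reduction are illustrative, not the NeMo API:
def eval_iter_callback(tensors, global_vars):
    # accumulate the per-batch evaluation loss
    if 'eval_loss' not in global_vars:
        global_vars['eval_loss'] = []
    for name, values in tensors.items():
        if name.startswith('loss'):
            global_vars['eval_loss'].append(values[0].item())

def eval_epochs_done_callback(global_vars):
    # average the accumulated losses and reset the accumulator
    avg_loss = sum(global_vars['eval_loss']) / len(global_vars['eval_loss'])
    global_vars['eval_loss'] = []
    print(f'Evaluation loss: {avg_loss:.3f}')
    return {'eval_loss': avg_loss}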
# ----- Tacotron 2 training example -----
# loss for the training DAG (the loss module's leading mel/gate arguments
# are elided in the source; `t2_loss` is a placeholder name for the
# Tacotron 2 loss module)
loss_t = t2_loss(
    gate_target=gate_target,
    target_len=spec_target_len,
    seq_len=audio_len)

# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss_t, spec_target, mel_postnet, gate, gate_target,
             alignments],
    print_func=lambda x: print(f"Loss: {x[0].data}"),
    log_to_tb_func=partial(
        tacotron2_log_to_tb_func, log_images=True,
        log_images_freq=log_freq),
    tb_writer=neural_factory.tb_writer,
)

chpt_callback = nemo.core.CheckpointCallback(
    folder=neural_factory.checkpoint_dir,
    step_freq=checkpoint_save_freq)

callbacks = [train_callback, chpt_callback]
return loss_t, callbacks, steps_per_epoch
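# A minimal sketch of how the (loss, callbacks, steps_per_epoch) triple
# returned above is typically consumed; create_train_dag is a hypothetical
# wrapper name, and the train() signature follows the other snippets here:
loss_t, callbacks, steps_per_epoch = create_train_dag(neural_factory, args)
neural_factory.train(
    tensors_to_optimize=[loss_t],
    callbacks=callbacks,
    lr_policy=CosineAnnealing(args.num_epochs * steps_per_epoch),
    optimization_params={"num_epochs": args.num_epochs, "lr": args.lr})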
# ----- Jasper ASR training example -----
# CTC loss for the training DAG (the log_probs argument is elided in the
# source)
loss_t = ctc_loss(
    targets=transcript_t,
    input_length=encoded_len_t,
    target_length=transcript_len_t)

# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
    print_func=partial(
        monitor_asr_train_progress,
        labels=vocab,
        logger=logger),
    get_tb_values=lambda x: [("loss", x[0])],
    tb_writer=neural_factory.tb_writer,
)

chpt_callback = nemo.core.CheckpointCallback(
    folder=neural_factory.checkpoint_dir,
    load_from_folder=args.load_dir,
    step_freq=args.checkpoint_save_freq)

callbacks = [train_callback, chpt_callback]
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
    audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
        eval_dl()
    processed_signal_e, p_length_e = data_preprocessor(
        input_signal=audio_signal_e,
        length=a_sig_length_e)
    encoded_e, encoded_len_e = jasper_encoder(
        audio_signal=processed_signal_e,
        length=p_length_e)
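    # A sketch of how each eval DAG is usually finished and wired to an
    # EvaluatorCallback. jasper_decoder, greedy_decoder, args.eval_freq and
    # the process_evaluation_* helpers come from elided parts of the script;
    # treat the exact names and signatures as assumptions:
    log_probs_e = jasper_decoder(encoder_output=encoded_e)
    predictions_e = greedy_decoder(log_probs=log_probs_e)
    loss_e = ctc_loss(
        log_probs=log_probs_e,
        targets=transcript_e,
        input_length=encoded_len_e,
        target_length=transcript_len_e)
    eval_callback = nemo.core.EvaluatorCallback(
        eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
        user_iter_callback=partial(process_evaluation_batch, labels=vocab),
        user_epochs_done_callback=partial(process_evaluation_epoch,
                                          tag=f'DEV {i}'),
        eval_step=args.eval_freq,
        tb_writer=neural_factory.tb_writer)
    callbacks.append(eval_callback)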
# ----- training example with optional interactive mode -----
# callback which prints training loss once in a while
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[train_loss],
    step_freq=100,
    print_func=lambda x: print(str(x[0].item())),
    get_tb_values=lambda x: [["loss", x[0]]],
    tb_writer=nf.tb_writer)
# callback which calculates evaluation loss
eval_callback = nemo.core.EvaluatorCallback(
    eval_tensors=[eval_loss],
    user_iter_callback=eval_iter_callback,
    user_epochs_done_callback=eval_epochs_done_callback,
    eval_step=args.eval_freq,
    tb_writer=nf.tb_writer)

# callback which saves checkpoints once in a while
callback_ckpt = nemo.core.CheckpointCallback(
    folder=nf.checkpoint_dir,
    epoch_freq=args.save_epoch_freq,
    step_freq=args.save_step_freq,
    checkpoints_to_keep=-1)
# define learning rate decay policy
lr_policy_fn = CosineAnnealing(args.max_steps, warmup_steps=args.warmup_steps)

# define and launch training algorithm (optimizer); in interactive mode no
# training epochs are run, so the train and eval callbacks are skipped
max_num_epochs = 0 if args.interactive else args.num_epochs

callbacks = [callback_ckpt]
if not args.interactive:
    callbacks.extend([train_callback, eval_callback])
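# A minimal sketch of the training launch this setup leads up to, assuming
# the same factory API as the other snippets on this page:
nf.train(
    tensors_to_optimize=[train_loss],
    callbacks=callbacks,
    lr_policy=lr_policy_fn,
    optimization_params={"num_epochs": max_num_epochs, "lr": args.lr})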
# ----- GAN training example -----
# print_func hook for the logger callback below (the log lines for the
# generator/interpolated/real losses are elided in the source)
def print_losses(tensors):
    g_loss, i_loss, r_loss, grad_p = tensors
    neural_factory.logger.info(f"Grad Penalty: {grad_p}")

# get_tb_values hook: report the generator loss to Tensorboard
def get_tb_name_value(tensors):
    g_loss = tensors[0]
    return [["G_LOSS", g_loss]]

logger_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[generator_loss, interpolated_loss, real_loss, grad_penalty],
    print_func=print_losses,
    get_tb_values=get_tb_name_value,
    step_freq=500,
    tb_writer=neural_factory.tb_writer)

checkpoint_callback = nemo.core.CheckpointCallback(
    folder=neural_factory.checkpoint_dir, step_freq=1000)
# three discriminator updates are run for every generator update, hence the
# repeated (optimizer_D, losses_D) entries
tensors_to_optimize = [
    (optimizer_D, losses_D),
    (optimizer_D, losses_D),
    (optimizer_D, losses_D),
    (optimizer_G, losses_G),
]
neural_factory.train(
    tensors_to_optimize=tensors_to_optimize,
    callbacks=[eval_callback, logger_callback, checkpoint_callback],
    optimization_params={"num_epochs": args.num_epochs})
# ----- classification training example -----
# Create callbacks, then execute the training action
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=train_tensors,
    print_func=lambda x: print("Loss: {:.3f}".format(x[0].item())),
    get_tb_values=lambda x: [["loss", x[0]]],
    tb_writer=nf.tb_writer)

eval_callback = nemo.core.EvaluatorCallback(
    eval_tensors=eval_tensors,
    user_iter_callback=lambda x, y: eval_iter_callback(x, y),
    user_epochs_done_callback=lambda x:
        eval_epochs_done_callback(x, label_ids, f'{nf.work_dir}/graphs'),
    tb_writer=nf.tb_writer,
    eval_step=steps_per_epoch)

ckpt_callback = nemo.core.CheckpointCallback(
    folder=nf.checkpoint_dir,
    epoch_freq=args.save_epoch_freq,
    step_freq=args.save_step_freq)
# define learning rate decay policy
lr_policy_fn = get_lr_policy(args.lr_policy,
                             total_steps=args.num_epochs * steps_per_epoch,
                             warmup_ratio=args.lr_warmup_proportion)

# define and launch training algorithm (optimizer)
nf.train(tensors_to_optimize=[train_loss],
         callbacks=[train_callback, eval_callback, ckpt_callback],
         lr_policy=lr_policy_fn,
         optimizer=args.optimizer_kind,
         optimization_params={"num_epochs": args.num_epochs,
                              "lr": args.lr})
# ----- ASR training example with optional checkpointing -----
# CTC loss for the training DAG (log_probs argument elided in the source;
# remaining arguments follow the ctc_loss call above)
loss_t = ctc_loss(
    targets=transcript_t,
    input_length=encoded_len_t,
    target_length=transcript_len_t)

# create train callbacks
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
    print_func=partial(
        monitor_asr_train_progress,
        labels=vocab,
        logger=neural_factory.logger),
    get_tb_values=lambda x: [["loss", x[0]]],
    tb_writer=neural_factory.tb_writer)

callbacks = [train_callback]

# only checkpoint when there is somewhere to save to or load from
if args.checkpoint_dir or args.load_dir:
    chpt_callback = nemo.core.CheckpointCallback(
        folder=args.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq)
    callbacks.append(chpt_callback)
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
    audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
        eval_dl()
    processed_signal_e, p_length_e = data_preprocessor(
        input_signal=audio_signal_e,
        length=a_sig_length_e)
    encoded_e, encoded_len_e = encoder(
        audio_signal=processed_signal_e,
        length=p_length_e)
# ----- ASR training example evaluated with CER -----
# CTC loss for the training DAG (log_probs argument elided in the source;
# remaining arguments follow the ctc_loss calls above)
loss_t = ctc_loss(
    targets=transcript_t,
    input_length=encoded_len_t,
    target_length=transcript_len_t)

# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
    print_func=partial(
        monitor_asr_train_progress,
        labels=vocab,
        eval_metric='CER',
        logger=logger),
    step_freq=args.train_eval_freq,
    get_tb_values=lambda x: [("loss", x[0])],
    tb_writer=neural_factory.tb_writer,
)

chpt_callback = nemo.core.CheckpointCallback(
    folder=neural_factory.checkpoint_dir,
    step_freq=args.checkpoint_save_freq)

callbacks = [train_callback, chpt_callback]
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
    audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
        eval_dl()
    processed_signal_e, p_length_e = data_preprocessor(
        input_signal=audio_signal_e,
        length=a_sig_length_e)
    encoded_e, encoded_len_e = jasper_encoder(
        audio_signal=processed_signal_e,
        length=p_length_e)
    log_probs_e = jasper_decoder(encoder_output=encoded_e)
# ----- WaveGlow training example -----
loss_t = waveglow_loss(
    z=z,
    log_s_list=log_s_list,
    log_det_W_list=log_det_W_list)

# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss_t, z, spec_target, spec_target_len],
    print_func=lambda x: print(f"Loss: {x[0].data}"),
    log_to_tb_func=partial(
        waveglow_log_to_tb_func,
        log_images=False),
    tb_writer=neural_factory.tb_writer,
)

chpt_callback = nemo.core.CheckpointCallback(
    folder=neural_factory.checkpoint_dir,
    step_freq=checkpoint_save_freq)

callbacks = [train_callback, chpt_callback]
return loss_t, callbacks, steps_per_epoch
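# Both TTS snippets above pass a custom log_to_tb_func. A minimal
# hypothetical version, assuming the (writer, tensors, step) calling
# convention suggested by the helpers used above:
def my_log_to_tb_func(swriter, tensors, step):
    # log only the scalar loss; spectrogram/audio logging would go here
    swriter.add_scalar('train/loss', tensors[0].item(), step)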