z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
# compute loss
y_all = self.output(z_all)
if LooseVersion(torch.__version__) < LooseVersion('1.0'):
    reduction_str = 'elementwise_mean'
else:
    reduction_str = 'mean'
self.loss = F.cross_entropy(y_all, ys_out_pad.view(-1),
                            ignore_index=self.ignore_id,
                            reduction=reduction_str)
# compute perplexity
ppl = math.exp(self.loss.item())
# -1: eos, which is removed in the loss computation
self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))
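
# --- Illustrative sketch (not part of the original source): the same masked
# --- cross-entropy on toy tensors. With reduction='mean', positions whose
# --- target equals ignore_index are excluded from the average, so exp(loss)
# --- is a per-token perplexity, as computed above.
import math
import torch
import torch.nn.functional as F

toy_logits = torch.randn(6, 10)                    # (batch * olength, odim)
toy_targets = torch.tensor([1, 4, 2, -1, -1, 3])   # -1 marks padded positions
toy_loss = F.cross_entropy(toy_logits, toy_targets,
                           ignore_index=-1, reduction='mean')
toy_ppl = math.exp(toy_loss.item())
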
# show predicted character sequence for debug
if self.verbose > 0 and self.char_list is not None:
    ys_hat = y_all.view(batch, olength, -1)
    ys_true = ys_out_pad
    for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
                                  ys_true.detach().cpu().numpy()):
        if i == MAX_DECODER_OUTPUT:
            break
        idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
        idx_true = y_true[y_true != self.ignore_id]
        seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
        seq_true = [self.char_list[int(idx)] for idx in idx_true]
        seq_hat = "".join(seq_hat)
        seq_true = "".join(seq_true)
"""
# 1. forward encoder
xs_pad = xs_pad[:, :max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
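
# --- Sketch of the teacher-forcing convention assumed above (an assumption
# --- about add_sos_eos/target_mask, not their actual implementations):
# --- the decoder input gets <sos> prepended, the scored target gets <eos>
# --- appended, and self-attention is restricted by a causal mask.
import torch

sos_id, eos_id = 1, 2
y = torch.tensor([7, 8, 9])                         # one label sequence
ys_in = torch.cat([torch.tensor([sos_id]), y])      # <sos> y1 y2 y3
ys_out = torch.cat([y, torch.tensor([eos_id])])     # y1 y2 y3 <eos>
causal = torch.ones(len(ys_in), len(ys_in)).tril().bool()
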
# 3. compute attention loss
loss_att = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                       ignore_label=self.ignore_id)
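
# --- Minimal padding-aware token accuracy in the spirit of th_accuracy
# --- (assumed behavior, not the library code): count argmax hits only where
# --- the reference label differs from ignore_label.
import torch

def toy_token_accuracy(logits, targets, ignore_label=-1):
    # logits: (batch * length, vocab), targets: (batch, length)
    pred = logits.argmax(dim=-1).view(targets.shape)
    mask = targets != ignore_label
    return (pred[mask] == targets[mask]).float().mean().item()
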
# TODO(karita) show predicted text
# TODO(karita) calculate these stats
cer_ctc = None
if self.mtlalpha == 0.0:
    loss_ctc = None
else:
    batch_size = xs_pad.size(0)
    hs_len = hs_mask.view(batch_size, -1).sum(1)
    loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
    if self.error_calculator is not None:
        ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
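
# --- Sketch (assuming self.ctc wraps a standard CTC criterion) of how the
# --- mask-derived lengths and label sequences feed torch.nn.CTCLoss directly.
# --- log_probs must be (T, N, C); index 0 is used as the blank here.
import torch

T, N, C = 50, 2, 30                                      # frames, batch, vocab incl. blank
log_probs = torch.randn(T, N, C).log_softmax(-1)
targets = torch.randint(1, C, (N, 10))                   # toy label sequences (no blanks)
input_lengths = torch.full((N,), T, dtype=torch.long)    # e.g. hs_mask.sum(-1)
target_lengths = torch.full((N,), 10, dtype=torch.long)
toy_ctc = torch.nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
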
# 5. compute cer/wer
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    ys_in_pad, ys_out_pad = add_sos_eos(
        ys_pad, self.sos, self.eos, self.ignore_id
    )
    ys_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
    )

    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )

    # Compute cer/wer using attention-decoder
    if self.error_calculator is None:
        cer_att, wer_att = None, None
    else:
        ys_hat = decoder_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(
            ys_hat.cpu(), ys_pad.cpu()
        )

    return loss_att, acc_att, cer_att, wer_att
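
# --- Sketch of how an attention loss and a CTC loss are typically combined
# --- in hybrid CTC/attention training (an assumption about the caller, not
# --- code from this file): alpha = 0 disables CTC, matching the
# --- `mtlalpha == 0.0` branch above.
def toy_hybrid_loss(loss_att, loss_ctc, alpha):
    if loss_ctc is None or alpha == 0.0:
        return loss_att
    return alpha * loss_ctc + (1.0 - alpha) * loss_att
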
if self.mt_weight > 0:
    # forward MT encoder
    ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
    # NOTE: ys_pad_src is padded with -1
    ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
    ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
    ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
    src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(ys_zero_pad_src.device).unsqueeze(-2)
    # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
    hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
    # forward MT decoder
    pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
    # compute loss
    loss_mt = self.criterion(pred_pad_mt, ys_out_pad)
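
# --- Sketch of the padding/masking semantics assumed above (not the actual
# --- pad_list/make_pad_mask implementations): pad variable-length sequences
# --- to a common length, then build a (batch, 1, maxlen) mask that is True
# --- on real tokens and False on padding.
import torch

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
lens = torch.tensor([len(s) for s in seqs])
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
non_pad = (torch.arange(padded.size(1))[None, :] < lens[:, None]).unsqueeze(-2)
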
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                       ignore_label=self.ignore_id)
if pred_pad_asr is not None:
    self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim), ys_out_pad_asr,
                               ignore_label=self.ignore_id)
else:
    self.acc_asr = 0.0
if pred_pad_mt is not None:
    self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                              ignore_label=self.ignore_id)
else:
    self.acc_mt = 0.0

# TODO(karita) show predicted text
# TODO(karita) calculate these stats
cer_ctc = None
if self.mtlalpha == 0.0 or self.asr_weight == 0:
    loss_ctc = 0.0
else:
    batch_size = xs_pad.size(0)
    hs_len = hs_mask.view(batch_size, -1).sum(1)
    loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad_src)
    if self.error_calculator is not None:
        ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(), is_ctc=True)
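
# --- Sketch (an assumption about the CTC-based CER path, not the ESPnet
# --- error_calculator): greedy decoding of the frame-wise argmax by collapsing
# --- repeats and dropping the blank symbol before scoring against the reference.
def toy_ctc_greedy_decode(frame_ids, blank=0):
    out, prev = [], None
    for t in frame_ids:
        if t != blank and t != prev:
            out.append(t)
        prev = t
    return out

# e.g. toy_ctc_greedy_decode([0, 3, 3, 0, 0, 5, 5, 5, 0]) -> [3, 5]
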
# 1. forward encoder
xs_pad = xs_pad[:, :max(ilens)] # for data parallel
src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
# 3. compute attention loss
loss = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                       ignore_label=self.ignore_id)
# TODO(karita) show predicted text
# TODO(karita) calculate these stats
# 5. compute bleu
if self.training or self.error_calculator is None:
    bleu = 0.0
else:
    ys_hat = pred_pad.argmax(dim=-1)
    bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
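
# --- Sketch (not the ESPnet error_calculator) of scoring the hypotheses above
# --- with corpus BLEU via NLTK, mapping ids back to tokens with a char_list
# --- table like the one used in the earlier snippet.
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def toy_bleu(hyp_ids, ref_ids, char_list):
    hyps = [[char_list[i] for i in h] for h in hyp_ids]
    refs = [[[char_list[i] for i in r]] for r in ref_ids]   # one reference per hypothesis
    return corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)
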
# copied from e2e_mt
self.loss = loss
loss_data = float(self.loss)