def _forward(self, xs, ilens, ys=None, olens=None, spembs=None, is_inference=False):
    # forward encoder
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    # forward duration predictor and length regulator
    d_masks = make_pad_mask(ilens).to(xs.device)
    if is_inference:
        d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
    else:
        with torch.no_grad():
            ds = self.duration_calculator(xs, ilens, ys, olens, spembs)  # (B, Tmax)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

    # forward decoder
    if olens is not None:
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
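
# Illustrative sketch only (not ESPnet's LengthRegulator): the length-regulation step
# above repeats each phoneme-level hidden vector hs[t] by its duration d[t] so the
# sequence reaches frame resolution. Shapes and values below are made up.
import torch

def length_regulate_sketch(hs: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
    # hs: (Tmax, adim), ds: (Tmax,) integer durations for a single utterance
    return torch.repeat_interleave(hs, ds, dim=0)  # (sum(ds), adim)

_hs = torch.arange(6.0).view(3, 2)   # 3 phonemes, adim=2
_ds = torch.tensor([2, 1, 3])        # durations -> 6 output frames
assert length_regulate_sketch(_hs, _ds).shape == (6, 2)
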
"""
text = text[:, : text_lengths.max()]  # for data-parallel
speech = speech[:, : speech_lengths.max()]  # for data-parallel

batch_size = text.size(0)

# Add eos at the end of each sequence
xs = F.pad(text, [0, 1], "constant", 0.0)
for i, l in enumerate(text_lengths):
    xs[i, l] = self.eos
ilens = text_lengths + 1

ys = speech
olens = speech_lengths

# make labels for stop prediction
labels = make_pad_mask(olens).to(ys.device, ys.dtype)

# calculate tacotron2 outputs
hs, hlens = self.enc(xs, ilens)
if self.spk_embed_dim is not None:
    spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
    hs = torch.cat([hs, spembs], dim=-1)
after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

# trim targets so their lengths are multiples of the reduction factor
if self.reduction_factor > 1:
    olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
    max_out = max(olens)
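
# Toy illustration (not ESPnet code) of the reduction-factor trimming above: target
# lengths are rounded down to a multiple of reduction_factor, since the decoder emits
# reduction_factor frames per step. The value 2 below is assumed for the example.
import torch

_reduction_factor = 2
_olens = torch.tensor([7, 10, 5])
_olens_trimmed = _olens - _olens % _reduction_factor
assert _olens_trimmed.tolist() == [6, 10, 4]
_max_out = int(_olens_trimmed.max())  # frames beyond this index get cut from the targets
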
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
    t[i, l] = self.sos
x_lengths = text_lengths + 1

# 2. Forward Language model
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
y, _ = self.lm(x, None)

# 3. Calc negative log likelihood
# nll: (BxL,)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
# zero out the loss at padded positions; shape stays (BxL,)
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, x_lengths
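
# Hedged, self-contained illustration of the teacher-forcing pair built above: the LM
# input is the text right-shifted behind a start symbol, and the target is the text
# followed by an end symbol, with padded target positions ignored in the loss. The
# token ids and the shared <sos/eos> id below are made up for the example.
import torch
import torch.nn.functional as F

_sos = _eos = 1
_ignore_id = -1
_text = torch.tensor([[5, 6, 7]])                 # (Batch=1, Length=3)
_x = F.pad(_text, [1, 0], "constant", _eos)       # LM input:  [<sos>, 5, 6, 7]
_t = F.pad(_text, [0, 1], "constant", _ignore_id)
_t[0, 3] = _sos                                   # LM target: [5, 6, 7, <eos>]
assert _x.tolist() == [[1, 5, 6, 7]] and _t.tolist() == [[5, 6, 7, 1]]
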
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
# 0. Extract target language ID
# src_lang_ids = None
tgt_lang_ids = None
if self.multilingual:
    tgt_lang_ids = ys_pad[:, 0:1]
    ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

# 1. forward encoder
xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad

# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# replace <sos> with the target language ID
if self.replace_sos:
    ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
pred_pad_asr, pred_pad_mt = None, None

# 3. compute attention loss
loss_asr, loss_mt = 0.0, 0.0
loss_att = self.criterion(pred_pad, ys_out_pad)
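
# Rough sketch (not the ESPnet add_sos_eos helper itself) of the decoder input/target
# construction used above: inputs are <sos>-prefixed, targets are <eos>-suffixed, and
# padded target positions carry ignore_id so the attention loss skips them.
import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos_sketch(ys, sos, eos, ignore_id):
    ys_in = [torch.cat([y.new_tensor([sos]), y]) for y in ys]
    ys_out = [torch.cat([y, y.new_tensor([eos])]) for y in ys]
    ys_in = pad_sequence(ys_in, batch_first=True, padding_value=eos)
    ys_out = pad_sequence(ys_out, batch_first=True, padding_value=ignore_id)
    return ys_in, ys_out

_ys_in, _ys_out = add_sos_eos_sketch(
    [torch.tensor([4, 5]), torch.tensor([6])], sos=1, eos=2, ignore_id=-1
)
assert _ys_in.tolist() == [[1, 4, 5], [1, 6, 2]]
assert _ys_out.tolist() == [[4, 5, 2], [6, 2, -1]]
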
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)
mask = make_pad_mask(ilens, x, 1)

if x.is_leaf and x.requires_grad:
    x = x.masked_fill(mask, 0.0)
else:
    x.masked_fill_(mask, 0.0)

if norm_vars:
    x *= self.std

# feat: (B, T, D)
if norm_means:
    x += self.mean
    x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
return x, ilens
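
# Small round-trip check (illustrative only) for the de-normalization branch above: if
# features were normalized as (x - mean) / std, then multiplying by std and adding the
# mean back recovers the original values.
import torch

_x = torch.randn(2, 4, 3)                          # (B, T, D) toy features, no padding
_mean, _std = _x.mean(dim=(0, 1)), _x.std(dim=(0, 1))
_normalized = (_x - _mean) / _std
_restored = _normalized * _std + _mean             # scale back, then shift back
assert torch.allclose(_restored, _x, atol=1e-5)
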
ilens: (B,)
norm_means: whether to subtract the per-utterance mean
norm_vars: whether to divide by the per-utterance standard deviation
eps: lower bound on the standard deviation to avoid division by zero
"""
if ilens is None:
    ilens = x.new_full([x.size(0)], x.size(1))
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
# Zero padding
if x.is_leaf and x.requires_grad:
    x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
else:
    x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
# mean: (B, 1, D)
mean = x.sum(dim=1, keepdim=True) / ilens_

if norm_means:
    x -= mean
    if norm_vars:
        var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
        std = torch.clamp(var.sqrt(), min=eps)
        x = x / std
    return x, ilens
else:
    if norm_vars:
        y = x - mean
        y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
        var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
        std = torch.clamp(var.sqrt(), min=eps)
        x /= std
    return x, ilens
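
# Illustrative restatement (not the function above) of per-utterance MVN for a single
# unpadded utterance: subtract the utterance mean and divide by the utterance standard
# deviation per feature dimension; padded batches additionally need the length masking
# shown above.
import torch

_feat = torch.randn(50, 80)                                   # (T, D) one utterance
_mean = _feat.mean(dim=0, keepdim=True)                       # (1, D)
_std = _feat.std(dim=0, unbiased=False, keepdim=True).clamp(min=1e-20)
_normalized = (_feat - _mean) / _std
assert torch.allclose(_normalized.mean(dim=0), torch.zeros(80), atol=1e-5)
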
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
if prev_states is None:
    prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)

current_states = []
for module, prev_state in zip(self.enc, prev_states):
    xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
    current_states.append(states)

# make mask to remove bias value in padded part
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
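
# Toy stand-in (not ESPnet's make_pad_mask) for the padding mask applied above: True at
# padded time steps, so masked_fill zeroes whatever the encoder produced past each
# sequence's true length.
import torch

def pad_mask_sketch(lengths, max_len=None):
    max_len = int(max(lengths)) if max_len is None else max_len
    steps = torch.arange(max_len).unsqueeze(0)                   # (1, Tmax)
    return steps >= torch.as_tensor(lengths).unsqueeze(1)        # (B, Tmax), True = pad

_xs = torch.randn(2, 4, 3)                                       # (B, Tmax, D) outputs
_masked = _xs.masked_fill(pad_mask_sketch([4, 2]).unsqueeze(-1), 0.0)
assert _masked[1, 2:].abs().sum() == 0                           # padded frames zeroed
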
def forward(
    self, x: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # feat: (B, T, D)
    if self.norm_means:
        x += self.bias.type_as(x)
        # zero out padded frames in place
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    if self.norm_vars:
        x *= self.scale.type_as(x)
    return x, ilens
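
# Hedged sketch of the bias/scale form used above, assuming (not stated in the fragment)
# that bias = -mean and scale = 1 / std were precomputed from global statistics; then
# x = (x + bias) * scale is the usual (x - mean) / std cepstral mean/variance norm.
import torch

_x = torch.randn(2, 4, 3)                    # (B, T, D)
_mean, _std = _x.mean(dim=(0, 1)), _x.std(dim=(0, 1))
_bias, _scale = -_mean, 1.0 / _std           # assumed precomputed buffers
assert torch.allclose((_x + _bias) * _scale, (_x - _mean) / _std, atol=1e-5)
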