Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
if rnnlm:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
'c_prev': c_list, 'lm_state': None}]
else:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
'c_prev': c_list}]
for i, hi in enumerate(h):
hyps = kept_hyps
kept_hyps = []
while True:
new_hyp = max(hyps, key=lambda x: x['score'])
hyps.remove(new_hyp)
vy = to_device(self, torch.full((1, 1), new_hyp['yseq'][-1], dtype=torch.long))
ey = self.dropout_embed(self.embed(vy))
y, (z_list, c_list) = self.rnn_forward(ey[0], (new_hyp['z_prev'],
new_hyp['c_prev']))
ytu = F.log_softmax(self.joint(hi, y[0]), dim=0)
if rnnlm:
rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0])
for k in six.moves.range(self.odim):
beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
'yseq': new_hyp['yseq'][:],
'z_prev': new_hyp['z_prev'],
'c_prev': new_hyp['c_prev']}
if rnnlm:
w = to_device(self, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means
w = self.var_word_unk
# update wordlm state and log-prob vector
wlm_state, z_wlm = self.wordlm(wlm_state, w)
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
new_node = self.lexroot # move to the tree root
clm_logprob = 0.
elif node is not None and xi in node[0]: # intra-word transition
new_node = node[0][xi]
clm_logprob += log_y[0, xi]
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
new_node = None
clm_logprob += log_y[0, xi]
else: # if open_vocab flag is disabled, return 0 probabilities
log_y = to_device(self, torch.full((1, self.subword_dict_size), self.logzero))
return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.), log_y
clm_state, z_clm = self.subwordlm(clm_state, x)
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
# apply word-level probabilities for <space> and <eos> labels
if xi != self.space:
if new_node is not None and new_node[1] >= 0: # if new node is word end
wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
else:
wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
log_y[:, self.space] = wlm_logprob
log_y[:, self.eos] = wlm_logprob
else:
log_y[:, self.space] = self.logzero
log_y[:, self.eos] = self.logzero
max_hlen = int(max(hlens))
if recog_args.maxlenratio == 0:
maxlen = max_hlen
else:
maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
minlen = int(recog_args.minlenratio * max_hlen)
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialization
c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
vscores = to_device(self, torch.zeros(batch, beam))
a_prev = None
rnnlm_prev = None
self.att.reset() # reset pre-computation of h
# yseq = [[self.sos] for _ in six.moves.range(n_bb)]
if recog_args.sos:
logging.info('sos index: ' + str(char_list.index(recog_args.sos)))
logging.info('sos mark: ' + recog_args.sos)
yseq = [[char_list.index(recog_args.sos)] for _ in six.moves.range(n_bb)]
else:
logging.info('sos index: ' + str(self.sos))
logging.info('sos mark: ' + char_list[self.sos])
yseq = [[self.sos] for _ in six.moves.range(n_bb)]
accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
x (ndarray): input acoustic feature (T, D)
recog_args (namespace): argument Namespace containing options
char_list (list): list of characters
rnnlm (torch.nn.Module): language model module
Returns:
y (list): n-best decoding results
"""
prev = self.training
self.eval()
ilens = [x.shape[0]]
# subsample frame
x = x[::self.subsample[0], :]
h = to_device(self, to_torch_tensor(x).float())
# make a utt list (1) to use the same interface for encoder
hs = h.contiguous().unsqueeze(0)
# 0. Frontend
if self.frontend is not None:
enhanced, hlens, mask = self.frontend(hs, ilens)
hs, hlens = self.feature_transform(enhanced, hlens)
else:
hs, hlens = hs, ilens
# 1. Encoder
h, _, _ = self.enc(hs, hlens)
# 2. Decoder
if recog_args.beam_size == 1:
y = self.dec.recognize(h[0], recog_args)
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
c = []
w = []
for h in six.moves.range(self.aheads):
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when computing w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
ey = torch.cat((eys[:, i, :], att_c), dim=1)
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
z_all.append(y)
h_dec = torch.stack(z_all, dim=1)
h_enc = hs_pad.unsqueeze(2)
h_dec = h_dec.unsqueeze(1)
z = self.joint(h_enc, h_dec)
y = pad_list(ys, self.blank).type(torch.int32)
z_len = to_device(self, torch.IntTensor(hlens))
y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))
loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))
return loss
def forward(self, state, x):
    """Update the word-LM state with input label ``x``.

    NOTE: only the opening of this method is visible in this excerpt; the
    statements below reproduce that visible part and nothing beyond it.

    Args:
        state: ``None`` on the first call; otherwise a 3-tuple that is
            unpacked as ``(wlm_state, cumsum_probs, node)``.
        x: Input label; read via ``int(x)`` when ``state`` is given.
    """
    if state is None:
        # Initial call: build the initial states and the cumulative
        # probability vector.
        self.var_word_eos = to_device(self, self.var_word_eos)
        # Fix: the original passed ``self.var_word_eos`` here, which replaced
        # the unknown-word tensor with the end-of-sentence tensor.
        # Assumes ``var_word_unk`` is initialised before the first call, like
        # the two sibling attributes -- TODO confirm against ``__init__``.
        self.var_word_unk = to_device(self, self.var_word_unk)
        self.zero_tensor = to_device(self, self.zero_tensor)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
        new_node = self.lexroot
        xi = self.space
    else:
        wlm_state, cumsum_probs, node = state
        xi = int(x)
        # Nested under ``else`` because ``node`` is only bound on this path.
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # node marks a word end
                # node[1] is used as the word id fed to the word LM
                w = to_device(self, torch.LongTensor([node[1]]))
            else:  # not a word end: fall back to the unknown-word tensor
                w = self.var_word_unk
            # update wordlm state and cumulative probability vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
:param ndarray x: input source text feature (B, T, D)
:param Namespace trans_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
# 1. encoder
# make a utt list (1) to use the same interface for encoder
if self.multilingual:
ilen = [len(x[0][1:])]
h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
else:
ilen = [len(x[0])]
h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))
hs, _, _ = self.enc(self.dropout(self.embed(h.unsqueeze(0))), ilen)
# 2. decoder
# decode the first utterance
y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm)
if prev:
self.train()
return y
for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
ys_true.detach().cpu().numpy()):
if i == MAX_DECODER_OUTPUT:
break
idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
idx_true = y_true[y_true != self.ignore_id]
seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
seq_true = [self.char_list[int(idx)] for idx in idx_true]
seq_hat = "".join(seq_hat)
seq_true = "".join(seq_true)
logging.info("groundtruth[%d]: " % i + seq_true)
logging.info("prediction [%d]: " % i + seq_hat)
if self.labeldist is not None:
if self.vlabeldist is None:
self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
loss_reg = - torch.sum((F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0) / len(ys_in)
self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg
return self.loss, acc, ppl
att_c, att_w = self.att[0](hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
z_all.append(y)
h_dec = torch.stack(z_all, dim=1)
h_enc = hs_pad.unsqueeze(2)
h_dec = h_dec.unsqueeze(1)
z = self.joint(h_enc, h_dec)
y = pad_list(ys, self.blank).type(torch.int32)
z_len = to_device(self, torch.IntTensor(hlens))
y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))
loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))
return loss