How to use the espnet.nets.pytorch_backend.nets_utils.to_device function in espnet

To help you get started, we’ve selected a few espnet examples based on popular ways it is used in public projects.
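All of the examples below follow the same pattern: to_device(self, x) returns the tensor x placed on the same device as the module self, so freshly built CPU tensors end up next to the model's parameters. A minimal sketch of that behavior (the helper name to_device_sketch and the toy module are illustrative, not espnet's actual implementation):

import torch
import torch.nn as nn


def to_device_sketch(m: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Move tensor x onto the device that module m's parameters live on.
    device = next(m.parameters()).device
    return x.to(device)


# Illustrative usage: the tensor follows the module to CPU or GPU.
model = nn.Linear(4, 2)
x = torch.zeros(1, 4)
x = to_device_sketch(model, x)  # x now lives on model's device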


From espnet/espnet, espnet/nets/pytorch_backend/rnn/decoders_transducer.py (view on GitHub):
        if rnnlm:
            kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
                          'c_prev': c_list, 'lm_state': None}]
        else:
            kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
                          'c_prev': c_list}]

        for i, hi in enumerate(h):
            hyps = kept_hyps
            kept_hyps = []

            while True:
                new_hyp = max(hyps, key=lambda x: x['score'])
                hyps.remove(new_hyp)

                vy = to_device(self, torch.full((1, 1), new_hyp['yseq'][-1], dtype=torch.long))
                ey = self.dropout_embed(self.embed(vy))

                y, (z_list, c_list) = self.rnn_forward(ey[0], (new_hyp['z_prev'],
                                                               new_hyp['c_prev']))

                ytu = F.log_softmax(self.joint(hi, y[0]), dim=0)

                if rnnlm:
                    rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0])

                for k in six.moves.range(self.odim):
                    beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                                'yseq': new_hyp['yseq'][:],
                                'z_prev': new_hyp['z_prev'],
                                'c_prev': new_hyp['c_prev']}
                    if rnnlm:
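The snippet above is cut off mid-loop, but the to_device call it demonstrates is the key pattern: every beam-search step rebuilds the last emitted label as a (1, 1) long tensor on the decoder's device before embedding it. A standalone sketch, with nn.Embedding standing in for self.embed and a hard-coded label standing in for new_hyp['yseq'][-1]:

import torch
import torch.nn as nn

embed = nn.Embedding(10, 8)      # stand-in for self.embed
last_label = 3                   # stand-in for new_hyp['yseq'][-1]

# Same shape and dtype as in the snippet: a (1, 1) long tensor on embed's device.
vy = torch.full((1, 1), last_label, dtype=torch.long)
vy = vy.to(next(embed.parameters()).device)  # what to_device(self, vy) achieves
ey = embed(vy)                   # (1, 1, 8) embedded decoder input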
From espnet/espnet, espnet/lm/pytorch_backend/extlm.py (view on GitHub):
                    w = to_device(self, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move to the tree root
                clm_logprob = 0.
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(self, torch.full((1, self.subword_dict_size), self.logzero))
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilities for <space> and <eos> labels
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:  # if new node is word end
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero
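Both to_device calls above create small CPU tensors on the fly: a one-element LongTensor holding a word id, and, in the closed-vocabulary fallback, a vector filled with self.logzero so every subword continuation becomes impossible. A sketch of that fallback with illustrative constants in place of self.subword_dict_size and self.logzero:

import torch

SUBWORD_DICT_SIZE = 100   # illustrative; espnet takes this from the subword dictionary
LOGZERO = -1.0e10         # a very negative stand-in for log(0)

# All subword continuations get (effectively) zero probability.
log_y = torch.full((1, SUBWORD_DICT_SIZE), LOGZERO)
# In the snippet this tensor is then moved with to_device(self, ...).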
From espnet/espnet, espnet/nets/pytorch_backend/mt_decoders.py (view on GitHub):
        max_hlen = int(max(hlens))
        if recog_args.maxlenratio == 0:
            maxlen = max_hlen
        else:
            maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
        minlen = int(recog_args.minlenratio * max_hlen)
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialization
        c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        vscores = to_device(self, torch.zeros(batch, beam))

        a_prev = None
        rnnlm_prev = None

        self.att.reset()  # reset pre-computation of h

        # yseq = [[self.sos] for _ in six.moves.range(n_bb)]
        if recog_args.sos:
            logging.info('sos index: ' + str(char_list.index(recog_args.sos)))
            logging.info('sos mark: ' + recog_args.sos)
            yseq = [[char_list.index(recog_args.sos)] for _ in six.moves.range(n_bb)]
        else:
            logging.info('sos index: ' + str(self.sos))
            logging.info('sos mark: ' + char_list[self.sos])
            yseq = [[self.sos] for _ in six.moves.range(n_bb)]
        accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
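The initialization block above allocates one zero hidden/cell state per decoder layer for all n_bb = batch * beam hypotheses at once, moving each tensor to the model's device as it is created. A sketch of the same layout with illustrative sizes:

import torch

dlayers, dunits = 2, 320   # illustrative decoder depth and cell width
batch, beam = 4, 5
n_bb = batch * beam        # one row per (utterance, hypothesis) pair

c_prev = [torch.zeros(n_bb, dunits) for _ in range(dlayers)]
z_prev = [torch.zeros(n_bb, dunits) for _ in range(dlayers)]
vscores = torch.zeros(batch, beam)  # accumulated score per hypothesis
# In the snippet every one of these is wrapped in to_device(self, ...).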
From espnet/espnet, espnet/nets/pytorch_backend/e2e_asr_transducer.py (view on GitHub):
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters
            rnnlm (torch.nn.Module): language model module

        Returns:
           y (list): n-best decoding results

        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        h, _, _ = self.enc(hs, hlens)

        # 2. Decoder
        if recog_args.beam_size == 1:
            y = self.dec.recognize(h[0], recog_args)
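Feature preparation here is the typical recognize() front half: subsample frames, convert the numpy features to a float tensor on the model's device, and add a batch axis so a single utterance can reuse the batched encoder interface. A sketch with illustrative shapes (subsample_factor stands in for self.subsample[0]):

import numpy as np
import torch

x = np.random.randn(200, 83).astype(np.float32)  # (T, D) acoustic features
subsample_factor = 1                             # illustrative; self.subsample[0]

x = x[::subsample_factor, :]                     # subsample frames
h = torch.from_numpy(x).float()                  # plus to_device(self, ...) in espnet
hs = h.contiguous().unsqueeze(0)                 # (1, T, D) one-utterance batch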
From espnet/espnet, espnet/nets/pytorch_backend/rnn/attentions.py (view on GitHub):
                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        c = []
        w = []
        for h in six.moves.range(self.aheads):
            e = self.gvec[h](torch.tanh(
                self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k))).squeeze(2)

            # NOTE consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(self, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float('inf'))
            w += [F.softmax(self.scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
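The to_device call here places the pad mask from make_pad_mask on the same device as the attention energies before masked_fill_. Padded positions get -inf so the softmax assigns them zero weight. A self-contained sketch of a pad mask equivalent to what make_pad_mask returns (shapes are illustrative):

import torch
import torch.nn.functional as F

lengths = [5, 3]                     # true frame counts per utterance
T = max(lengths)
e = torch.randn(len(lengths), T)     # attention energies (batch, T)

# True where a position is padding, as in make_pad_mask(lengths).
mask = torch.arange(T).unsqueeze(0) >= torch.tensor(lengths).unsqueeze(1)
e.masked_fill_(mask, -float('inf'))  # padded frames get zero attention weight
w = F.softmax(e, dim=1)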
From espnet/espnet, espnet/nets/pytorch_backend/rnn/decoders_transducer.py (view on GitHub):
            ey = torch.cat((eys[:, i, :], att_c), dim=1)
            y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
            z_all.append(y)

        h_dec = torch.stack(z_all, dim=1)

        h_enc = hs_pad.unsqueeze(2)
        h_dec = h_dec.unsqueeze(1)

        z = self.joint(h_enc, h_dec)
        y = pad_list(ys, self.blank).type(torch.int32)

        z_len = to_device(self, torch.IntTensor(hlens))
        y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))

        loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))

        return loss
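Transducer loss implementations generally expect int32 length tensors on the same device as the joint-network output, which is exactly what the two to_device(self, torch.IntTensor(...)) calls arrange. A sketch of that length bookkeeping with illustrative inputs:

import torch

hlens = [120, 95]                                   # encoder output lengths
ys = [torch.tensor([7, 2, 9]), torch.tensor([4])]   # label sequences

z_len = torch.IntTensor(hlens)                      # int32 encoder lengths
y_len = torch.IntTensor([_y.size(0) for _y in ys])  # int32 label lengths
# In the snippet both are wrapped in to_device(self, ...) before self.rnnt_loss.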
From espnet/espnet, espnet/lm/pytorch_backend/extlm.py (view on GitHub):
def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and cumulative probability vector
            self.var_word_eos = to_device(self, self.var_word_eos)
            self.var_word_unk = to_device(self, self.var_word_unk)
            self.zero_tensor = to_device(self, self.zero_tensor)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = to_device(self, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
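The state is None branch above illustrates a common idiom: constant tensors are built once on the CPU and migrated to the module's device lazily, on the first forward call, when the final device is known. A sketch of the idiom (the module and attribute names are illustrative):

import torch
import torch.nn as nn


class LazyConstModule(nn.Module):
    # Illustrative module that relocates a cached constant on first use.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.var_word_eos = torch.LongTensor([0])  # created on CPU

    def forward(self, x, state=None):
        if state is None:  # first call: move constants next to the parameters
            device = next(self.parameters()).device
            self.var_word_eos = self.var_word_eos.to(device)
        return self.linear(x)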
From espnet/espnet, espnet/nets/pytorch_backend/e2e_mt.py (view on GitHub):
        :param ndarray x: input source text feature (B, T, D)
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()

        # 1. encoder
        # make a utt list (1) to use the same interface for encoder
        if self.multilingual:
            ilen = [len(x[0][1:])]
            h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
        else:
            ilen = [len(x[0])]
            h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))
        hs, _, _ = self.enc(self.dropout(self.embed(h.unsqueeze(0))), ilen)

        # 2. decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm)

        if prev:
            self.train()
        return y
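Because the encoder input is a sequence of token ids rather than acoustic frames, the conversion goes through np.fromiter with dtype=np.int64 (yielding a LongTensor), and in the multilingual case the leading language-id symbol is stripped with x[0][1:]. A sketch of the conversion with an illustrative utterance:

import numpy as np
import torch

x = [[2, 17, 5, 9]]   # one utterance of token ids (illustrative)
ids = np.fromiter(map(int, x[0]), dtype=np.int64)
h = torch.from_numpy(ids)   # LongTensor; to_device(self, ...) in espnet
hs = h.unsqueeze(0)         # (1, T) batch for the encoder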
From espnet/espnet, espnet/nets/pytorch_backend/mt_decoders.py (view on GitHub):
            for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
                                          ys_true.detach().cpu().numpy()):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
            loss_reg = - torch.sum((F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0) / len(ys_in)
            self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl
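The label-smoothing branch converts the precomputed label distribution to a tensor exactly once (cached as self.vlabeldist on the model's device) and then interpolates the cross-entropy loss with the regularizer. A sketch of that interpolation with illustrative shapes and a stand-in cross-entropy value:

import torch
import torch.nn.functional as F

odim, n_tokens = 50, 8
y_all = torch.randn(n_tokens, odim)          # decoder logits
labeldist = torch.full((odim,), 1.0 / odim)  # illustrative smoothing prior
lsm_weight = 0.1

loss_ce = torch.tensor(1.234)                # stand-in for the main loss
loss_reg = -torch.sum((F.log_softmax(y_all, dim=1) * labeldist).view(-1), dim=0) / n_tokens
loss = (1.0 - lsm_weight) * loss_ce + lsm_weight * loss_reg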