How to use the fairseq.utils.get_incremental_state function in fairseq

To help you get started, we’ve selected a few fairseq examples based on popular ways this function is used in public projects.
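
Before the project examples, here is a minimal, hedged sketch of the usual pattern (assuming the classic fairseq utils API that keys the cache by module instance): a decoder-side module pairs get_incremental_state with set_incremental_state to read and write a per-module cache during incremental generation. The module name, the 'input_buffer' key and the concatenation logic below are illustrative assumptions, not code from the projects listed.

import torch
import torch.nn as nn
from fairseq import utils


class CachingModule(nn.Module):
    # Illustrative module, not taken from any of the projects below.

    def _get_input_buffer(self, incremental_state):
        # None on the first decoding step, the cached tensor afterwards.
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')

    def _set_input_buffer(self, incremental_state, new_buffer):
        utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

    def forward(self, x, incremental_state=None):
        # x: B x T x C, holding only the current step(s) during incremental decoding
        if incremental_state is not None:
            buffered = self._get_input_buffer(incremental_state)
            if buffered is not None:
                # Prepend the inputs cached from previous steps along the time dimension.
                x = torch.cat([buffered, x], dim=1)
            self._set_input_buffer(incremental_state, x)
        return x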


github freewym / espresso / fairseq / modules / lightweight_convolution.py View on Github external
def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
github lancopku / Prime / fairseq / modules / multihead_attention.py View on Github external
def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}
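
The `or {}` fallback above keeps the first decoding step simple; it is normally paired with a setter that writes the (possibly updated) attention-state dict back under the same key. A hedged sketch of that matching setter, where only the 'attn_state' key is taken from the snippet above and the rest is an assumption:

def _set_input_buffer(self, incremental_state, buffer):
    # Sketch: store the updated attention state for the next incremental step.
    # Assumes the same `from fairseq import utils` as the getter above.
    utils.set_incremental_state(
        self,
        incremental_state,
        'attn_state',
        buffer,
    )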
github pytorch / translate / pytorch_translate / rnn.py View on Github external
(
            encoder_outs,
            final_hidden,
            final_cell,
            src_lengths,
            src_tokens,
            _,
        ) = encoder_out

        # embed tokens
        x = self.embed_tokens(input_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        input_feed = None
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            # first time step, initialize previous states
            init_prev_states = self._init_prev_states(encoder_out)
            prev_hiddens = []
            prev_cells = []

            # init_prev_states may or may not include initial attention context
            for (h, c) in zip(init_prev_states[0::2], init_prev_states[1::2]):
                prev_hiddens.append(h)
                prev_cells.append(c)
            if self.attention.context_dim:
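
The snippet above reads the cache at the start of the forward pass; after the recurrent layers have run for the current step, the updated states are typically written back so the next step can resume from them. A hedged sketch of that write-back (the variable names follow the snippet, but the exact call site and tuple contents are assumptions):

# Sketch: cache the states produced by this step for the next incremental step.
utils.set_incremental_state(
    self,
    incremental_state,
    "cached_state",
    (prev_hiddens, prev_cells, input_feed),
)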
github StillKeepTry / Transformer-PyTorch / Papers / Double Path Networks for Sequence to Sequence Learning / dualpath.py View on Github external
def _get_input_buffer(self, incremental_state, name):
        return utils.get_incremental_state(self, incremental_state, name)
github StillKeepTry / Transformer-PyTorch / fairseq / models / dpn.py View on Github external
def _get_input_buffer(self, incremental_state, name):
        return utils.get_incremental_state(self, incremental_state, name)
github freewym / espresso / fairseq / modules / dynamicconv_layer / dynamicconv_layer.py View on Github external
def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
github freewym / espresso / espresso / models / speech_lstm.py View on Github external
encoder_outs = encoder_out[0]
            srclen = encoder_outs.size(0)

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
            prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
            input_feed = x.new_zeros(bsz, self.encoder_output_units) \
                if self.attention is not None else None

        if self.attention is not None:
            attn_scores = x.new_zeros(srclen, seqlen, bsz)
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1) \
                if input_feed is not None else x[j, :, :]
github freewym / espresso / fairseq / modules / dynamic_convolution.py View on Github external
def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
github freewym / espresso / espresso / models / speech_fconv.py View on Github external
def masked_copy_incremental_state(self, incremental_state, another_state, mask):
        state = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if state is None:
            assert another_state is None
            return

        def mask_copy_state(state, another_state):
            if isinstance(state, list):
                assert isinstance(another_state, list) and len(state) == len(another_state)
                return [
                    mask_copy_state(state_i, another_state_i)
                    for state_i, another_state_i in zip(state, another_state)
                ]
            if state is not None:
                assert state.size(0) == mask.size(0) and another_state is not None and \
                    state.size() == another_state.size()
                for _ in range(1, len(state.size())):
                    mask_unsqueezed = mask.unsqueeze(-1)
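
The masked copy above broadcasts a one-dimensional per-sentence mask over each cached tensor (and, as the snippet shows, recurses over nested lists of states). A self-contained sketch of the underlying operation with toy shapes and made-up values:

import torch

# Where mask is True keep the row from `state`, elsewhere take it from
# `another_state`; unsqueezing lets the 1-D mask broadcast over the rest.
state = torch.ones(3, 4)
another_state = torch.zeros(3, 4)
mask = torch.tensor([True, False, True])
merged = torch.where(mask.unsqueeze(-1), state, another_state)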
github freewym / espresso / fairseq / models / lstm.py View on Github external
def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
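
The reorder helper above boils down to an index_select along the batch (beam) dimension of every cached tensor. A self-contained sketch of that operation with toy values:

import torch

# After beam re-ranking, beam i should carry the state previously held by
# beam new_order[i]; index_select(0, new_order) performs exactly that.
cached = torch.arange(6.0).view(3, 2)   # pretend cached state for 3 beams
new_order = torch.tensor([2, 2, 0])     # beams 0 and 1 now copy beam 2
reordered = cached.index_select(0, new_order)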