How to use the ctranslate2.specs.transformer_spec module in CTranslate2

To help you get started, we’ve selected a few ctranslate2 examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: OpenNMT/CTranslate2 — python/ctranslate2/converters/opennmt_tf.py (view on GitHub)
def _load(self, model_spec):
        """Load the OpenNMT-tf model at ``self._model_path`` into ``model_spec``.

        Args:
          model_spec: The specification to fill with the checkpoint variables.
            Only ``transformer_spec.TransformerSpec`` instances are supported.

        Returns:
          A ``(src_vocab, tgt_vocab)`` tuple, as returned by ``load_model``.

        Raises:
          NotImplementedError: If ``model_spec`` is not a ``TransformerSpec``.
        """
        # Reject unsupported specs before loading the checkpoint so the
        # (potentially expensive) load is not wasted on a guaranteed failure.
        if not isinstance(model_spec, transformer_spec.TransformerSpec):
            raise NotImplementedError(
                "Unsupported model specification: %s" % type(model_spec).__name__)
        version, variables, src_vocab, tgt_vocab = load_model(
            self._model_path,
            src_vocab=self._src_vocab,
            tgt_vocab=self._tgt_vocab)
        # ``version`` (reported by ``load_model``) selects which setter maps
        # the loaded ``variables`` onto the spec.
        if version == 2:
            set_transformer_spec_v2(model_spec, variables)
        else:
            set_transformer_spec(model_spec, variables)
        return src_vocab, tgt_vocab
GitHub: OpenNMT/CTranslate2 — python/ctranslate2/specs/catalog.py (view on GitHub)
"""Catalog of model specifications."""

from ctranslate2.specs import transformer_spec


class TransformerBase(transformer_spec.TransformerSpec):
    """Transformer "base" preset built on ``TransformerSpec``.

    Calls the parent constructor with the fixed arguments ``(6, 8)``.
    """

    def __init__(self):
        # NOTE(review): presumably (num_layers, num_heads) — confirm against
        # TransformerSpec.__init__.
        preset = (6, 8)
        super(TransformerBase, self).__init__(*preset)

class TransformerBig(transformer_spec.TransformerSpec):
    """Transformer "big" preset built on ``TransformerSpec``.

    Calls the parent constructor with the fixed arguments ``(6, 16)``.
    """

    def __init__(self):
        # NOTE(review): presumably (num_layers, num_heads) — confirm against
        # TransformerSpec.__init__.
        preset = (6, 16)
        super(TransformerBig, self).__init__(*preset)
GitHub: OpenNMT/CTranslate2 — python/ctranslate2/converters/opennmt_py.py (view on GitHub)
def _load(self, model_spec):
        """Load the OpenNMT-py checkpoint at ``self._model_path`` into ``model_spec``.

        Args:
          model_spec: The specification to fill with the checkpoint variables.
            Only ``transformer_spec.TransformerSpec`` instances are supported.

        Returns:
          A ``(src_vocab, tgt_vocab)`` tuple extracted from the checkpoint's
          ``"vocab"`` entry.

        Raises:
          NotImplementedError: If ``model_spec`` is not a ``TransformerSpec``.
        """
        # Reject unsupported specs before importing torch and loading the
        # checkpoint, so the load is not wasted on a guaranteed failure.
        if not isinstance(model_spec, transformer_spec.TransformerSpec):
            raise NotImplementedError(
                "Unsupported model specification: %s" % type(model_spec).__name__)
        import torch
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # convert checkpoints from trusted sources. map_location="cpu" keeps
        # tensors on the CPU regardless of the device they were saved from.
        checkpoint = torch.load(self._model_path, map_location="cpu")
        variables = checkpoint["model"]
        # The generator parameters live in a separate checkpoint entry; merge
        # them into the model variables under a "generator." prefix
        # (presumably the names set_transformer_spec expects — confirm).
        variables["generator.weight"] = checkpoint["generator"]["0.weight"]
        variables["generator.bias"] = checkpoint["generator"]["0.bias"]
        set_transformer_spec(model_spec, variables)
        vocab = checkpoint["vocab"]
        # Newer checkpoints store a dict with "src"/"tgt" entries whose
        # ``.fields[0][1].vocab`` holds the vocabulary.
        if isinstance(vocab, dict) and "src" in vocab:
            return vocab["src"].fields[0][1].vocab, vocab["tgt"].fields[0][1].vocab
        # Compatibility with older models: an indexable sequence of pairs.
        return vocab[0][1], vocab[1][1]