How to use the ctranslate2.converters.converter.Converter class in ctranslate2

To help you get started, we’ve selected a few ctranslate2 examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github OpenNMT / CTranslate2 / python / ctranslate2 / converters / opennmt_py.py View on Github external
from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec
from ctranslate2.specs import transformer_spec


class OpenNMTPyConverter(Converter):
    """Converts models generated by OpenNMT-py."""

    def __init__(self, model_path):
        """Initializes the converter.

        Args:
          model_path: Location of the OpenNMT-py checkpoint. It is stored
            as-is and handed to ``torch.load`` when the model is loaded.
        """
        self._model_path = model_path

    def _save_vocabulary(self, vocab, output_path):
        """Writes the vocabulary tokens to a file, one per line.

        Each token is encoded as UTF-8 and terminated with a single ``\\n``
        byte; the file is opened in binary mode so no newline translation
        takes place.

        Args:
          vocab: Object exposing its tokens through an iterable ``itos``
            attribute.
          output_path: Path of the file to write (opened with mode ``"wb"``).
        """
        with open(output_path, "wb") as vocab_file:
            # The generator is consumed lazily, so tokens are encoded and
            # written one at a time, in the order given by vocab.itos.
            vocab_file.writelines(
                token.encode("utf-8") + b"\n" for token in vocab.itos)

    def _load(self, model_spec):
        """Loads the checkpoint found at ``self._model_path``."""
        # torch is imported locally; presumably so that it is only required
        # when a model is actually loaded -- TODO confirm.
        import torch
        # map_location="cpu" remaps the stored tensors to CPU memory.
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be passed here.
        checkpoint = torch.load(self._model_path, map_location="cpu")
        # The model variables are read from the "model" entry of the checkpoint.
        variables = checkpoint["model"]
github OpenNMT / CTranslate2 / python / ctranslate2 / converters / opennmt_tf.py View on Github external
checkpoint = tf.train.latest_checkpoint(model_path)
        else:
            checkpoint = model_path
        reader = tf.train.load_checkpoint(checkpoint)
        variables = {
            name:reader.get_tensor(name)
            for name in six.iterkeys(reader.get_variable_to_shape_map())}
        if os.path.basename(checkpoint).startswith("ckpt"):
            model_version = 2
            variables = {
                name.replace("/.ATTRIBUTES/VARIABLE_VALUE", ""):value
                for name, value in six.iteritems(variables)}
    return model_version, variables, src_vocab, tgt_vocab


class OpenNMTTFConverter(Converter):
    """Converts models generated by OpenNMT-tf."""

    def __init__(self, model_path, src_vocab=None, tgt_vocab=None):
        """Initializes the converter.

        The arguments are only stored here; ``_load`` later forwards them to
        ``load_model``.

        Args:
          model_path: Location of the OpenNMT-tf model, passed to
            ``load_model`` as its first argument.
          src_vocab: Value passed to ``load_model`` as ``src_vocab``.
            Defaults to ``None``.
          tgt_vocab: Value passed to ``load_model`` as ``tgt_vocab``.
            Defaults to ``None``.
        """
        self._model_path = model_path
        self._src_vocab = src_vocab
        self._tgt_vocab = tgt_vocab

    def _load(self, model_spec):
        version, variables, src_vocab, tgt_vocab = load_model(
            self._model_path,
            src_vocab=self._src_vocab,
            tgt_vocab=self._tgt_vocab)
        if isinstance(model_spec, transformer_spec.TransformerSpec):
            if version == 2:
                set_transformer_spec_v2(model_spec, variables)
            else: