Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec
from ctranslate2.specs import transformer_spec
class OpenNMTPyConverter(Converter):
"""Converts models generated by OpenNMT-py."""
def __init__(self, model_path):
self._model_path = model_path
def _save_vocabulary(self, vocab, output_path):
with open(output_path, "wb") as output_file:
for word in vocab.itos:
word = word.encode("utf-8")
output_file.write(word)
output_file.write(b"\n")
def _load(self, model_spec):
# Load the OpenNMT-py checkpoint stored at self._model_path.
#
# NOTE(review): this method is corrupted in the current file and cannot run.
# Only the first three statements below look like the OpenNMT-py path. The
# remainder (from the `tf.train.latest_checkpoint` line to the `return`)
# appears to be a fragment of an OpenNMT-tf checkpoint loader pasted in by
# mistake. Its 4-tuple return has the same shape that
# OpenNMTTFConverter._load unpacks from `load_model(...)`.
# Visible problems:
#   * `else:` has no matching `if`, so the module fails to parse;
#   * `tf`, `six`, `os`, `model_path`, `src_vocab` and `tgt_vocab` are not
#     bound in this method, and `tf`/`six`/`os` are not among the visible
#     imports;
#   * the `checkpoint` and `variables` values obtained via torch.load are
#     overwritten by the TensorFlow code;
#   * `model_version` is only assigned after the "ckpt" file-name check;
#     when that check is false the `return` has no value for it;
#   * `model_spec` is never used.
# TODO: restore the real OpenNMT-py body from version control rather than
# patching this fragment.
#
# Function-scope import: torch is only required when this method runs.
import torch
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. map_location="cpu" places all tensors on the CPU.
checkpoint = torch.load(self._model_path, map_location="cpu")
# Weights are read from the checkpoint's "model" entry (presumably a
# name -> tensor mapping; verify against the OpenNMT-py checkpoint format).
variables = checkpoint["model"]
# ---- Start of the apparently foreign OpenNMT-tf fragment (see NOTE). ----
# Presumably the "model_path is a directory" branch: pick the newest
# checkpoint inside it. The matching `if` is missing.
checkpoint = tf.train.latest_checkpoint(model_path)
else:
# Presumably model_path already names a checkpoint and is used directly.
checkpoint = model_path
# Read every variable stored in the checkpoint into a name -> value dict.
reader = tf.train.load_checkpoint(checkpoint)
variables = {
name:reader.get_tensor(name)
for name in six.iterkeys(reader.get_variable_to_shape_map())}
# A checkpoint whose file name starts with "ckpt" is tagged as version 2
# (presumably TF2 object-based checkpoints; confirm).
if os.path.basename(checkpoint).startswith("ckpt"):
model_version = 2
# Remove the "/.ATTRIBUTES/VARIABLE_VALUE" marker from every variable name.
# Indentation was lost, so it is unclear whether this belongs to the
# "ckpt" branch only; TODO confirm.
variables = {
name.replace("/.ATTRIBUTES/VARIABLE_VALUE", ""):value
for name, value in six.iteritems(variables)}
return model_version, variables, src_vocab, tgt_vocab
class OpenNMTTFConverter(Converter):
"""Converts models generated by OpenNMT-tf."""
def __init__(self, model_path, src_vocab=None, tgt_vocab=None):
self._model_path = model_path
self._src_vocab = src_vocab
self._tgt_vocab = tgt_vocab
def _load(self, model_spec):
version, variables, src_vocab, tgt_vocab = load_model(
self._model_path,
src_vocab=self._src_vocab,
tgt_vocab=self._tgt_vocab)
if isinstance(model_spec, transformer_spec.TransformerSpec):
if version == 2:
set_transformer_spec_v2(model_spec, variables)
else: