Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _load(self, model_spec):
    """Load the model at ``self._model_path`` and fill ``model_spec`` with it.

    Args:
        model_spec: The specification to populate. Only
            ``transformer_spec.TransformerSpec`` instances are supported.

    Returns:
        A ``(src_vocab, tgt_vocab)`` pair as returned by ``load_model``.

    Raises:
        NotImplementedError: If ``model_spec`` is not a ``TransformerSpec``.
    """
    # Reject unsupported specifications before loading anything, so callers
    # do not pay the cost of reading the model only to hit an error.
    if not isinstance(model_spec, transformer_spec.TransformerSpec):
        raise NotImplementedError(
            "Unsupported model specification: %s" % type(model_spec).__name__)

    version, variables, src_vocab, tgt_vocab = load_model(
        self._model_path,
        src_vocab=self._src_vocab,
        tgt_vocab=self._tgt_vocab)
    # load_model reports a format version; version 2 is mapped with a
    # dedicated setter (presumably a different variable layout -- confirm
    # against set_transformer_spec_v2).
    if version == 2:
        set_transformer_spec_v2(model_spec, variables)
    else:
        set_transformer_spec(model_spec, variables)
    return src_vocab, tgt_vocab
"""Catalog of model specifications."""
from ctranslate2.specs import transformer_spec
class TransformerBase(transformer_spec.TransformerSpec):
    """Predefined Transformer "base" specification.

    Forwards the fixed positional arguments ``(6, 8)`` to ``TransformerSpec``.
    Presumably these are the number of layers and attention heads -- confirm
    against ``TransformerSpec.__init__``.
    """

    def __init__(self):
        parent_init = super(TransformerBase, self).__init__
        parent_init(6, 8)
class TransformerBig(transformer_spec.TransformerSpec):
    """Predefined Transformer "big" specification.

    Forwards the fixed positional arguments ``(6, 16)`` to ``TransformerSpec``.
    Presumably these are the number of layers and attention heads -- confirm
    against ``TransformerSpec.__init__``.
    """

    def __init__(self):
        parent_init = super(TransformerBig, self).__init__
        parent_init(6, 16)
def _load(self, model_spec):
    """Load the PyTorch checkpoint at ``self._model_path`` into ``model_spec``.

    Args:
        model_spec: The specification to populate. Only
            ``transformer_spec.TransformerSpec`` instances are supported.

    Returns:
        A ``(src_vocab, tgt_vocab)`` pair extracted from the checkpoint's
        ``"vocab"`` entry.

    Raises:
        NotImplementedError: If ``model_spec`` is not a ``TransformerSpec``.
    """
    # Fail fast, before importing torch and deserializing a potentially large
    # checkpoint that would be discarded anyway.
    if not isinstance(model_spec, transformer_spec.TransformerSpec):
        raise NotImplementedError(
            "Unsupported model specification: %s" % type(model_spec).__name__)

    import torch

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from trusted sources.
    # map_location="cpu" keeps tensors off any GPU during conversion.
    checkpoint = torch.load(self._model_path, map_location="cpu")
    variables = checkpoint["model"]
    # The generator weights live in a separate checkpoint entry (presumably
    # the output projection, stored under index "0" of a sequential module);
    # merge them under "generator.*" so set_transformer_spec can find them.
    generator = checkpoint["generator"]
    variables["generator.weight"] = generator["0.weight"]
    variables["generator.bias"] = generator["0.bias"]
    set_transformer_spec(model_spec, variables)

    vocab = checkpoint["vocab"]
    if isinstance(vocab, dict) and "src" in vocab:
        return vocab["src"].fields[0][1].vocab, vocab["tgt"].fields[0][1].vocab
    # Compatibility with older models: vocab is indexed positionally and the
    # vocabulary is the second element of each of the first two entries.
    return vocab[0][1], vocab[1][1]