def lexicon_iterator(path: str,
                     vocab_source: Dict[str, int],
                     vocab_target: Dict[str, int]) -> Generator[Tuple[int, int, float], None, None]:
    """
    Yields lines from a translation table of format: src, trg, logprob.

    :param path: Path to lexicon file.
    :param vocab_source: Source vocabulary.
    :param vocab_target: Target vocabulary.
    :return: Generator returning tuples (src_id, trg_id, prob).
    """
    assert C.UNK_SYMBOL in vocab_source
    assert C.UNK_SYMBOL in vocab_target
    src_unk_id = vocab_source[C.UNK_SYMBOL]
    trg_unk_id = vocab_target[C.UNK_SYMBOL]
    with smart_open(path) as fin:
        for line in fin:
            src, trg, logprob = line.rstrip("\n").split("\t")
            prob = np.exp(float(logprob))
            src_id = vocab_source.get(src, src_unk_id)
            trg_id = vocab_target.get(trg, trg_unk_id)
            yield src_id, trg_id, prob
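
# A minimal usage sketch: the lexicon path and the toy vocabularies below are
# hypothetical, but both vocabularies must contain C.UNK_SYMBOL (as asserted
# above), and the file is expected to be tab-separated: src, trg, logprob.
example_vocab_source = {C.UNK_SYMBOL: 0, "Haus": 1}
example_vocab_target = {C.UNK_SYMBOL: 0, "house": 1}
for src_id, trg_id, prob in lexicon_iterator("lexicon.tsv", example_vocab_source, example_vocab_target):
    print(src_id, trg_id, prob)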

def parse(path):
    if path is None or path == "-":
        return sys.stdin
    else:
        return data_io.smart_open(path)
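
# Usage sketch: parse() falls back to STDIN when the path is None or "-", and
# otherwise opens the given file through data_io.smart_open. "input.txt" is a
# hypothetical file name.
for line in parse("input.txt"):
    print(line.rstrip("\n"))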

translator = inference.Translator(context=self.context,
                                  ensemble_mode=self.ensemble_mode,
                                  bucket_source_width=self.bucket_width_source,
                                  length_penalty=inference.LengthPenalty(self.length_penalty_alpha, self.length_penalty_beta),
                                  brevity_penalty=inference.BrevityPenalty(weight=0.0),
                                  beam_prune=0.0,
                                  beam_search_stop='all',
                                  nbest_size=self.nbest_size,
                                  models=models,
                                  source_vocabs=source_vocabs,
                                  target_vocab=target_vocab,
                                  restrict_lexicon=None,
                                  store_beam=False)
trans_wall_time = 0.0
translations = []
with data_io.smart_open(output_name, 'w') as output:
    handler = sockeye.output_handler.StringOutputHandler(output)
    tic = time.time()
    trans_inputs = []  # type: List[inference.TranslatorInput]
    for i, inputs in enumerate(self.inputs_sentences):
        trans_inputs.append(sockeye.inference.make_input_from_multiple_strings(i, inputs))
    trans_outputs = translator.translate(trans_inputs)
    trans_wall_time = time.time() - tic
    for trans_input, trans_output in zip(trans_inputs, trans_outputs):
        handler.handle(trans_input, trans_output)
        translations.append(trans_output.translation)
avg_time = trans_wall_time / len(self.target_sentences)

# TODO(fhieber): eventually add more metrics (METEOR etc.)
return {C.BLEU_VAL: evaluate.raw_corpus_bleu(hypotheses=translations,
                                             references=self.target_sentences,
                                             offset=0.01)}
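
# Toy illustration of the metric call above (the sentences are made up; the
# keyword arguments mirror the call in the snippet): raw_corpus_bleu scores a
# list of hypothesis strings against a list of reference strings.
toy_hypotheses = ["the house is small", "a cat sat on the mat"]
toy_references = ["the house is small", "the cat sat on the mat"]
print(evaluate.raw_corpus_bleu(hypotheses=toy_hypotheses, references=toy_references, offset=0.01))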

if input_file is None:
    # No input file given: read sentences from STDIN.
    for sentence_id, line in enumerate(sys.stdin, 1):
        if input_is_json:
            yield inference.make_input_from_json_string(sentence_id=sentence_id,
                                                        json_string=line,
                                                        translator=translator)
        else:
            yield inference.make_input_from_factored_string(sentence_id=sentence_id,
                                                            factored_string=line,
                                                            translator=translator)
else:
    input_factors = [] if input_factors is None else input_factors
    inputs = [input_file] + input_factors
    if not input_is_json:
        check_condition(translator.num_source_factors == len(inputs),
                        "Model(s) require %d factors, but %d given (through --input and --input-factors)." % (
                            translator.num_source_factors, len(inputs)))
    with ExitStack() as exit_stack:
        streams = [exit_stack.enter_context(data_io.smart_open(i)) for i in inputs]  # pylint: disable=no-member
        for sentence_id, inputs in enumerate(zip(*streams), 1):
            if input_is_json:
                yield inference.make_input_from_json_string(sentence_id=sentence_id,
                                                            json_string=inputs[0],
                                                            translator=translator)
            else:
                yield inference.make_input_from_multiple_strings(sentence_id=sentence_id, strings=list(inputs))
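
# Sketch of how the zip(*streams) pattern above lines up parallel factor files:
# line i of every opened stream is combined into one multi-factor input. The
# StringIO objects below stand in for the real file streams.
import io
surface_stream = io.StringIO("the house\nthe cat\n")
pos_stream = io.StringIO("DET NOUN\nDET NOUN\n")
for sentence_id, lines in enumerate(zip(surface_stream, pos_stream), 1):
    print(sentence_id, [l.rstrip("\n") for l in lines])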

def get_output_handler(output_type: str,
                       output_fname: Optional[str] = None,
                       sure_align_threshold: float = 1.0) -> 'OutputHandler':
    """
    :param output_type: Type of output handler.
    :param output_fname: Output filename. If None, sys.stdout is used.
    :param sure_align_threshold: Threshold to consider an alignment link as 'sure'.
    :raises: ValueError for unknown output_type.
    :return: Output handler.
    """
    output_stream = sys.stdout if output_fname is None else data_io.smart_open(output_fname, mode='w')
    if output_type == C.OUTPUT_HANDLER_TRANSLATION:
        return StringOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_SCORE:
        return ScoreOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_PAIR_WITH_SCORE:
        return PairWithScoreOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_TRANSLATION_WITH_SCORE:
        return StringWithScoreOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS:
        return StringWithAlignmentsOutputHandler(output_stream, sure_align_threshold)
    elif output_type == C.OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENT_MATRIX:
        return StringWithAlignmentMatrixOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_BENCHMARK:
        return BenchmarkOutputHandler(output_stream)
    elif output_type == C.OUTPUT_HANDLER_ALIGN_PLOT:
        return AlignPlotHandler(plot_prefix="align" if output_fname is None else output_fname)
    else:
        raise ValueError("unknown output type")
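
# Usage sketch (the output file name is hypothetical): request a plain
# translation handler that writes to a file instead of sys.stdout.
handler = get_output_handler(C.OUTPUT_HANDLER_TRANSLATION, output_fname="translations.txt")
# handler.handle(trans_input, trans_output) would then be called with the pairs
# produced by a Translator, as in the benchmark snippet above.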

                 random_seed: int = 42) -> None:
    self.context = context
    self.max_input_len = max_input_len
    self.max_output_length_num_stds = max_output_length_num_stds
    self.ensemble_mode = ensemble_mode
    self.beam_size = beam_size
    self.nbest_size = nbest_size
    self.batch_size = batch_size
    self.bucket_width_source = bucket_width_source
    self.length_penalty_alpha = length_penalty_alpha
    self.length_penalty_beta = length_penalty_beta
    self.softmax_temperature = softmax_temperature
    self.model = model

    with ExitStack() as exit_stack:
        inputs_fins = [exit_stack.enter_context(data_io.smart_open(f)) for f in inputs]  # pylint: disable=no-member
        references_fin = exit_stack.enter_context(data_io.smart_open(references))  # pylint: disable=no-member

        inputs_sentences = [f.readlines() for f in inputs_fins]
        target_sentences = references_fin.readlines()

        utils.check_condition(all(len(l) == len(target_sentences) for l in inputs_sentences),
                              "Sentences differ in length")

        if sample_size <= 0:
            sample_size = len(inputs_sentences[0])
        if sample_size < len(inputs_sentences[0]):
            self.target_sentences, *self.inputs_sentences = parallel_subsample(
                [target_sentences] + inputs_sentences, sample_size, random_seed)
        else:
            self.inputs_sentences, self.target_sentences = inputs_sentences, target_sentences
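
# parallel_subsample itself is not shown above; a minimal sketch of the
# behaviour implied by the call site (the same seeded random indices applied to
# every parallel list) could look like this:
import random
from typing import Any, List

def parallel_subsample_sketch(parallel_lists: List[List[Any]], sample_size: int, seed: int) -> List[List[Any]]:
    indices = random.Random(seed).sample(range(len(parallel_lists[0])), sample_size)
    return [[lines[i] for i in indices] for lines in parallel_lists]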