import sockeye.data_io


def create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words):
    buckets = sockeye.data_io.define_parallel_buckets(max_len, max_len, 10)
    batch_num_devices = 1
    eos = 0
    pad = 1
    unk = 2
    bucket_iterator = sockeye.data_io.ParallelBucketSentenceIter(source_sentences,
                                                                 target_sentences,
                                                                 buckets,
                                                                 batch_size,
                                                                 batch_by_words,
                                                                 batch_num_devices,
                                                                 eos, pad, unk)
    return bucket_iterator
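# Rough usage sketch (illustrative values only, not part of the original example):
# source_sentences and target_sentences are parallel lists of token-id lists; the returned
# iterator yields bucketed, padded batches for training, e.g.
#
#   data_iter = create_parallel_sentence_iter(src_id_lists, trg_id_lists,
#                                             max_len=50, batch_size=16, batch_by_words=False)
#   for batch in data_iter:
#       ...  # pass the batch to the training module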
batch_by_words = args.batch_type == C.BATCH_TYPE_WORD

validation_sources = [args.validation_source] + args.validation_source_factors
validation_sources = [str(os.path.abspath(source)) for source in validation_sources]

either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \
                                   "with %s." % (C.TRAINING_ARG_SOURCE,
                                                 C.TRAINING_ARG_TARGET,
                                                 C.TRAINING_ARG_PREPARED_DATA)

if args.prepared_data is not None:
    utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
    if not resume_training:
        utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                              "You are using a prepared data folder, which is tied to a vocabulary. "
                              "To change it you need to rerun data preparation with a different vocabulary.")

    train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters(
        prepared_data_dir=args.prepared_data,
        validation_sources=validation_sources,
        validation_target=str(os.path.abspath(args.validation_target)),
        shared_vocab=shared_vocab,
        batch_size=args.batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        fill_up=args.fill_up)

    check_condition(len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                    "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                        len(source_vocabs), len(args.source_factors_num_embed) + 1))

    if resume_training:
        # resuming training. Making sure the vocabs in the model and in the prepared data match up
        model_source_vocabs = vocab.load_source_vocabs(output_folder)

logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
            max_seq_len_source, max_seq_len_target)

source_vocabs, target_vocab = vocab.load_or_create_vocabs(
    source_paths=source_paths,
    target_path=args.target,
    source_vocab_paths=source_vocab_paths,
    target_vocab_path=args.target_vocab,
    shared_vocab=args.shared_vocab,
    num_words_source=num_words_source,
    word_min_count_source=word_min_count_source,
    num_words_target=num_words_target,
    word_min_count_target=word_min_count_target,
    pad_to_multiple_of=args.pad_vocab_to_multiple_of)
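# Note: when shared_vocab is set, a single joint vocabulary is typically built and reused for
# both source and target, which is needed e.g. when source and target embeddings are tied.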

data_io.prepare_data(source_fnames=source_paths,
                     target_fname=args.target,
                     source_vocabs=source_vocabs,
                     target_vocab=target_vocab,
                     source_vocab_paths=source_vocab_paths,
                     target_vocab_path=args.target_vocab,
                     shared_vocab=args.shared_vocab,
                     max_seq_len_source=max_seq_len_source,
                     max_seq_len_target=max_seq_len_target,
                     bucketing=bucketing,
                     bucket_width=bucket_width,
                     samples_per_shard=samples_per_shard,
                     min_num_shards=minimum_num_shards,
                     output_prefix=output_folder)
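# prepare_data writes the preprocessed corpus under output_prefix; roughly speaking this
# covers the serialized data shards, the vocabularies and a data config, which training can
# later consume via get_prepared_data_iters (the exact layout depends on the sockeye version).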

:param trans_inputs: List of TranslatorInputs.
:return: NDArray of source ids (shape=(batch_size, bucket_key, num_factors)),
    bucket key, list of raw constraint lists, list of phrases to avoid,
    and an NDArray of maximum output lengths.
"""
batch_size = len(trans_inputs)
bucket_key = data_io.get_bucket(max(len(inp.tokens) for inp in trans_inputs), self.buckets_source)
source = mx.nd.zeros((batch_size, bucket_key, self.num_source_factors), ctx=self.context)
raw_constraints = [None] * batch_size  # type: List[Optional[constrained.RawConstraintList]]
raw_avoid_list = [None] * batch_size  # type: List[Optional[constrained.RawConstraintList]]
max_output_lengths = []  # type: List[int]

for j, trans_input in enumerate(trans_inputs):
    num_tokens = len(trans_input)
    max_output_lengths.append(self.models[0].get_max_output_length(data_io.get_bucket(num_tokens, self.buckets_source)))
    source[j, :num_tokens, 0] = data_io.tokens2ids(trans_input.tokens, self.source_vocabs[0])

    factors = trans_input.factors if trans_input.factors is not None else []
    num_factors = 1 + len(factors)
    if num_factors != self.num_source_factors:
        logger.warning("Input %d factors, but model(s) expect %d", num_factors,
                       self.num_source_factors)
    for i, factor in enumerate(factors[:self.num_source_factors - 1], start=1):
        # fill in as many factors as there are tokens
        source[j, :num_tokens, i] = data_io.tokens2ids(factor, self.source_vocabs[i])[:num_tokens]

    if trans_input.constraints is not None:
        raw_constraints[j] = [data_io.tokens2ids(phrase, self.vocab_target) for phrase in
                              trans_input.constraints]

    if trans_input.avoid_list is not None:
        raw_avoid_list[j] = [data_io.tokens2ids(phrase, self.vocab_target) for phrase in
                             trans_input.avoid_list]
        if any(self.unk_id in phrase for phrase in raw_avoid_list[j]):
            logger.warning("Sentence %s: %s was found in the list of phrases to avoid; "
                           "this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL)

return source, bucket_key, raw_constraints, raw_avoid_list, mx.nd.array(max_output_lengths, ctx=self.context, dtype='int32')

self.context = context
self.max_input_len = max_input_len
self.max_output_length_num_stds = max_output_length_num_stds
self.ensemble_mode = ensemble_mode
self.beam_size = beam_size
self.nbest_size = nbest_size
self.batch_size = batch_size
self.bucket_width_source = bucket_width_source
self.length_penalty_alpha = length_penalty_alpha
self.length_penalty_beta = length_penalty_beta
self.softmax_temperature = softmax_temperature
self.model = model

with ExitStack() as exit_stack:
    inputs_fins = [exit_stack.enter_context(data_io.smart_open(f)) for f in inputs]  # pylint: disable=no-member
    references_fin = exit_stack.enter_context(data_io.smart_open(references))  # pylint: disable=no-member

    inputs_sentences = [f.readlines() for f in inputs_fins]
    target_sentences = references_fin.readlines()

    utils.check_condition(all(len(l) == len(target_sentences) for l in inputs_sentences),
                          "Sentences differ in length")

    if sample_size <= 0:
        sample_size = len(inputs_sentences[0])
    if sample_size < len(inputs_sentences[0]):
        self.target_sentences, *self.inputs_sentences = parallel_subsample(
            [target_sentences] + inputs_sentences, sample_size, random_seed)
    else:
        self.inputs_sentences, self.target_sentences = inputs_sentences, target_sentences

if sample_size < self.batch_size:
    self.batch_size = sample_size  # assumed completion of the truncated snippet: cap the batch size at the sample size

utils.check_condition(len(self.models) == 1 and self.beam_size == 1,
                      "Skipping softmax cannot be enabled for ensembles or beam sizes > 1.")

self.skip_topk = skip_topk
if self.skip_topk:
    utils.check_condition(self.beam_size == 1, "skip_topk has no effect if beam size is larger than 1")
    utils.check_condition(len(self.models) == 1, "skip_topk has no effect for decoding with more than 1 model")

self.sample = sample
utils.check_condition(not self.sample or self.restrict_lexicon is None,
                      "Sampling is not available when working with a restricted lexicon.")

# After the models are loaded we have ensured that they agree on max_input_length, max_output_length and batch size.
self._max_input_length = self.models[0].max_input_length
if bucket_source_width > 0:
    self.buckets_source = data_io.define_buckets(self._max_input_length, step=bucket_source_width)
else:
    self.buckets_source = [self._max_input_length]

self._update_scores = UpdateScores()
self._update_scores.initialize(ctx=self.context)
self._update_scores.hybridize(static_alloc=True, static_shape=True)
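# hybridize(static_alloc=True, static_shape=True) lets Gluon cache the block as a symbolic
# graph with pre-allocated memory and fixed shapes, which pays off for the many small
# per-step calls made during beam search.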

# Vocabulary selection leads to different vocabulary sizes across requests. Hence, we cannot use a
# statically-shaped HybridBlock for the topk operation in that case and fall back to the imperative
# topk function instead.
if not self.restrict_lexicon:
    if self.skip_topk:
        self._top = Top1()  # type: mx.gluon.HybridBlock
    elif self.sample is not None:
        self._top = SampleK(k=self.beam_size,
                            n=self.sample,

validation_sources = [args.validation_source] + args.validation_source_factors
validation_sources = [str(os.path.abspath(source)) for source in validation_sources]
validation_target = str(os.path.abspath(args.validation_target))

either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \
                                   "with %s." % (C.TRAINING_ARG_SOURCE,
                                                 C.TRAINING_ARG_TARGET,
                                                 C.TRAINING_ARG_PREPARED_DATA)

if args.prepared_data is not None:
    utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
    if not resume_training:
        utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                              "You are using a prepared data folder, which is tied to a vocabulary. "
                              "To change it you need to rerun data preparation with a different vocabulary.")

    train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters(
        prepared_data_dir=args.prepared_data,
        validation_sources=validation_sources,
        validation_target=validation_target,
        shared_vocab=shared_vocab,
        batch_size=args.batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices)

    check_condition(args.source_factors_combine == C.SOURCE_FACTORS_COMBINE_SUM
                    or len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                    "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                        len(source_vocabs), len(args.source_factors_num_embed) + 1))

    if resume_training:
        # resuming training. Making sure the vocabs in the model and in the prepared data match up
        model_source_vocabs = vocab.load_source_vocabs(output_folder)

    pad_to_multiple_of=args.pad_vocab_to_multiple_of,
    num_pointers=num_pointers)

check_condition(args.source_factors_combine == C.SOURCE_FACTORS_COMBINE_SUM
                or len(args.source_factors) == len(args.source_factors_num_embed),
                "Number of source factor data (%d) differs from provided source factor dimensions (%d)" % (
                    len(args.source_factors), len(args.source_factors_num_embed)))

sources = [args.source] + args.source_factors
sources = [str(os.path.abspath(source)) for source in sources]

check_condition(len(sources) == len(validation_sources),
                'Training and validation data must have the same number of factors, but found %d and %d.' % (
                    len(source_vocabs), len(validation_sources)))

train_iter, validation_iter, config_data, data_info = data_io.get_training_data_iters(
    sources=sources,
    target=os.path.abspath(args.target),
    validation_sources=validation_sources,
    validation_target=validation_target,
    source_vocabs=source_vocabs,
    target_vocab=target_vocab,
    source_vocab_paths=source_vocab_paths,
    target_vocab_path=target_vocab_path,
    shared_vocab=shared_vocab,
    batch_size=args.batch_size,
    batch_by_words=batch_by_words,
    batch_num_devices=batch_num_devices,
    max_seq_len_source=max_seq_len_source,
    max_seq_len_target=max_seq_len_target,
    bucketing=not args.no_bucketing,
    bucket_width=args.bucket_width)
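# Rough sketch of how the returned objects are typically used downstream (an assumption,
# not shown in this example): train_iter drives the training loop, validation_iter is used
# for checkpoint evaluation, config_data is stored with the model so inference can restore
# vocabulary and bucketing settings, and data_info records the input files that were used.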