check_condition(os.path.exists(config_file),
                "Could not find data config %s. Are you sure %s is a directory created with "
                "python -m sockeye.prepare_data?" % (config_file, prepared_data_dir))
config_data = cast(DataConfig, DataConfig.load(config_file))

shard_fnames = [os.path.join(prepared_data_dir,
                             C.SHARD_NAME % shard_idx) for shard_idx in range(data_info.num_shards)]
for shard_fname in shard_fnames:
    check_condition(os.path.exists(shard_fname), "Shard %s does not exist." % shard_fname)
check_condition(shared_vocab == data_info.shared_vocab, "Shared vocabulary settings need to match those "
                                                        "of the prepared data (e.g. for weight tying). "
                                                        "Specify or omit %s consistently when training "
                                                        "and preparing the data." % C.VOCAB_ARG_SHARED_VOCAB)
source_vocabs = vocab.load_source_vocabs(prepared_data_dir)
target_vocab = vocab.load_target_vocab(prepared_data_dir)
check_condition(len(source_vocabs) == len(data_info.sources),
                "Wrong number of source vocabularies. Found %d, need %d." % (len(source_vocabs),
                                                                             len(data_info.sources)))
buckets = config_data.data_statistics.buckets
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
bucket_batch_sizes = define_bucket_batch_sizes(buckets,
                                               batch_size,
                                               batch_by_words,
                                               batch_num_devices,
                                               config_data.data_statistics.average_len_target_per_bucket)
config_data.data_statistics.log(bucket_batch_sizes)
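# Note on define_bucket_batch_sizes above: when batch_by_words is True, batch_size is a token
# budget rather than a sentence count, so each bucket presumably ends up with a per-bucket
# batch size of roughly batch_size divided by that bucket's average target length, kept a
# multiple of batch_num_devices so batches split evenly across devices. Illustrative
# arithmetic only: with a 4096-word budget, a bucket averaging 16 target tokens would hold
# about 4096 / 16 = 256 sentences per batch, while a bucket averaging 64 tokens holds ~64.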
:return: The scoring data iterator as well as the source and target vocabularies.
"""
model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))
if args.max_seq_len is None:
    max_seq_len_source = model_config.config_data.max_seq_len_source
    max_seq_len_target = model_config.config_data.max_seq_len_target
else:
    max_seq_len_source, max_seq_len_target = args.max_seq_len
batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
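# The expression above counts how many devices each batch will be split across: every
# non-negative entry in --device-ids names one specific GPU and contributes 1, while a
# negative entry such as -4 (presumably "acquire this many free GPUs automatically")
# contributes its absolute value. For example, device_ids=[-4] yields 4 and
# device_ids=[0, 2] yields 2; on CPU a single device is assumed.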
# Load the existing vocabs created when starting the training run.
source_vocabs = vocab.load_source_vocabs(model_folder)
target_vocab = vocab.load_target_vocab(model_folder)
sources = [args.source] + args.source_factors
sources = [str(os.path.abspath(source)) for source in sources]
score_iter = data_io.get_scoring_data_iters(
    sources=sources,
    target=os.path.abspath(args.target),
    source_vocabs=source_vocabs,
    target_vocab=target_vocab,
    batch_size=args.batch_size,
    batch_num_devices=batch_num_devices,
    max_seq_len_source=max_seq_len_source,
    max_seq_len_target=max_seq_len_target)
return score_iter, source_vocabs, target_vocab, model_config
        batch_size=args.batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices)
    # source_vocabs holds one vocabulary for the primary source plus one per source factor,
    # hence the "+ 1" when comparing against the number of factor embedding sizes provided.
    check_condition(args.source_factors_combine == C.SOURCE_FACTORS_COMBINE_SUM
                    or len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                    "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                        len(source_vocabs), len(args.source_factors_num_embed) + 1))
    if resume_training:
        # Resuming training: make sure the vocabs in the model and in the prepared data match up.
        model_source_vocabs = vocab.load_source_vocabs(output_folder)
        for i, (v, mv) in enumerate(zip(source_vocabs, model_source_vocabs)):
            utils.check_condition(vocab.are_identical(v, mv),
                                  "Prepared data and resumed model source vocab %d do not match." % i)
        model_target_vocab = vocab.load_target_vocab(output_folder)
        utils.check_condition(vocab.are_identical(target_vocab, model_target_vocab),
                              "Prepared data and resumed model target vocabs do not match.")

    check_condition(data_config.num_source_factors == len(validation_sources),
                    'Training and validation data must have the same number of factors, but found %d and %d.' % (
                        data_config.num_source_factors, len(validation_sources)))

    return train_iter, validation_iter, data_config, source_vocabs, target_vocab
else:
    utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                          either_raw_or_prepared_error_msg)

    if resume_training:
        # Load the existing vocabs created when starting the training run.
        source_vocabs = vocab.load_source_vocabs(output_folder)
    batch_size=args.batch_size,
    batch_by_words=batch_by_words,
    batch_num_devices=batch_num_devices,
    fill_up=args.fill_up)

check_condition(len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                    len(source_vocabs), len(args.source_factors_num_embed) + 1))
target_vocabs = [] # type: List[vocab.Vocab]
if checkpoints is None:
    checkpoints = [None] * len(model_folders)
else:
    utils.check_condition(len(checkpoints) == len(model_folders), "Must provide checkpoints for each model")
skip_softmax = False
# Performance tweak: skip softmax for a single model, decoding with beam size 1,
# when not sampling and no scores are required in output.
if len(model_folders) == 1 and beam_size == 1 and not output_scores and not sampling:
    skip_softmax = True
    logger.info("Enabled skipping softmax for a single model and greedy decoding.")
for model_folder, checkpoint in zip(model_folders, checkpoints):
    model_source_vocabs = vocab.load_source_vocabs(model_folder)
    model_target_vocab = vocab.load_target_vocab(model_folder)
    source_vocabs.append(model_source_vocabs)
    target_vocabs.append(model_target_vocab)

    model_version = utils.load_version(os.path.join(model_folder, C.VERSION_NAME))
    logger.info("Model version: %s", model_version)
    utils.check_version(model_version)
    model_config = model.SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))

    logger.info("Disabling dropout layers for performance reasons")
    model_config.disable_dropout()

    if override_dtype is not None:
        model_config.config_encoder.dtype = override_dtype
        model_config.config_decoder.dtype = override_dtype
        if override_dtype == C.DTYPE_FP16:
            logger.warning('Experimental feature \'override_dtype=float16\' has been used. '
"Prepared data and resumed model target vocabs do not match.")
check_condition(data_config.num_source_factors == len(validation_sources),
'Training and validation data must have the same number of factors, but found %d and %d.' % (
data_config.num_source_factors, len(validation_sources)))
return train_iter, validation_iter, data_config, source_vocabs, target_vocab
else:
utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
either_raw_or_prepared_error_msg)
if resume_training:
    # Load the existing vocabs created when starting the training run.
    source_vocabs = vocab.load_source_vocabs(output_folder)
    target_vocab = vocab.load_target_vocab(output_folder)

    # Recover the vocabulary paths from the data info file:
    data_info = cast(data_io.DataInfo, Config.load(os.path.join(output_folder, C.DATA_INFO)))
    source_vocab_paths = data_info.source_vocabs
    target_vocab_path = data_info.target_vocab
else:
    # Load or create vocabs
    source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
                                 else None for i in range(len(args.source_factors))]
    source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths
    target_vocab_path = args.target_vocab

    num_pointers = max_seq_len_source if args.attention_based_copying else 0
    source_vocabs, target_vocab = vocab.load_or_create_vocabs(
        source_paths=[args.source] + args.source_factors,
        target_path=args.target,