if resume_training:
    # Load the existing vocabs created when starting the training run.
    source_vocabs = vocab.load_source_vocabs(output_folder)
    target_vocab = vocab.load_target_vocab(output_folder)

    # Recover the vocabulary paths from the data info file:
    data_info = cast(data_io.DataInfo, Config.load(os.path.join(output_folder, C.DATA_INFO)))
    source_vocab_paths = data_info.source_vocabs
    target_vocab_path = data_info.target_vocab

    edge_vocab = vocab.load_or_create_vocab(args.source_graphs, args.edge_vocab, num_words_source,
                                            word_min_count_source,
                                            args.pad_vocab_to_multiple_of)
else:
    # GRN: AMR Generation
    # Load or create vocabs
    source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
                                 else None for i in range(len(args.source_factors))]
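
# For reference, a minimal sketch of the load-or-create behavior relied on above,
# assuming Sockeye's vocab helpers (vocab_from_json, build_from_paths); the body is
# illustrative, not the library's exact implementation.
def load_or_create_vocab_sketch(data_path, vocab_path, num_words, word_min_count,
                                pad_vocab_to_multiple_of=None):
    if vocab_path is not None:
        # An explicit vocabulary file wins: load it as-is.
        return vocab.vocab_from_json(vocab_path)
    # Otherwise build a fresh vocabulary from the data with the given size/frequency limits.
    return vocab.build_from_paths(paths=[data_path], num_words=num_words,
                                  min_count=word_min_count,
                                  pad_to_multiple_of=pad_vocab_to_multiple_of)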
# Use the prepared data folder if one was given; otherwise fall back to raw source/target files.
if args.prepared_data is not None:
    train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters(
        prepared_data_dir=args.prepared_data,
        validation_sources=validation_sources,
        validation_target=validation_target,
        shared_vocab=shared_vocab,
        batch_size=args.batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices)
    check_condition(args.source_factors_combine == C.SOURCE_FACTORS_COMBINE_SUM
                    or len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                    "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                        len(source_vocabs), len(args.source_factors_num_embed) + 1))
    if resume_training:
        # Resuming training: make sure the vocabs in the model and in the prepared data match up.
        model_source_vocabs = vocab.load_source_vocabs(output_folder)
        for i, (v, mv) in enumerate(zip(source_vocabs, model_source_vocabs)):
            utils.check_condition(vocab.are_identical(v, mv),
                                  "Prepared data and resumed model source vocab %d do not match." % i)
        model_target_vocab = vocab.load_target_vocab(output_folder)
        utils.check_condition(vocab.are_identical(target_vocab, model_target_vocab),
                              "Prepared data and resumed model target vocabs do not match.")

    check_condition(data_config.num_source_factors == len(validation_sources),
                    'Training and validation data must have the same number of factors, but found %d and %d.' % (
                        data_config.num_source_factors, len(validation_sources)))

    return train_iter, validation_iter, data_config, source_vocabs, target_vocab
else:
    utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                          either_raw_or_prepared_error_msg)
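
# check_condition, used throughout these snippets, is a small guard helper. A minimal
# sketch of its behavior, assuming Sockeye's SockeyeError exception type:
def check_condition_sketch(condition, error_message):
    # Raise a real exception rather than assert, so the check survives python -O.
    if not condition:
        raise utils.SockeyeError(error_message)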
# Reconstructed signature (the snippet began mid-docstring); the function name is an
# assumption based on Sockeye's scoring setup.
def get_data_iters_and_vocabs(args, model_folder):
    """
    Loads the scoring data iterator and the model's vocabularies.

    :param args: Arguments as returned by argparse.
    :param model_folder: Output folder.
    :return: The scoring data iterator as well as the source and target vocabularies.
    """
    model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))

    # Sequence length limits: take them from the model unless overridden on the command line.
    if args.max_seq_len is None:
        max_seq_len_source = model_config.config_data.max_seq_len_source
        max_seq_len_target = model_config.config_data.max_seq_len_target
    else:
        max_seq_len_source, max_seq_len_target = args.max_seq_len

    # One device on CPU; otherwise count devices (a negative id requests that many automatically).
    batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)

    # Load the existing vocabs created when starting the training run.
    source_vocabs = vocab.load_source_vocabs(model_folder)
    target_vocab = vocab.load_target_vocab(model_folder)

    sources = [args.source] + args.source_factors
    sources = [str(os.path.abspath(source)) for source in sources]

    score_iter = data_io.get_scoring_data_iters(
        sources=sources,
        target=os.path.abspath(args.target),
        source_vocabs=source_vocabs,
        target_vocab=target_vocab,
        batch_size=args.batch_size,
        batch_num_devices=batch_num_devices,
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target)

    return score_iter, source_vocabs, target_vocab, model_config
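
# A hypothetical call site for the function above; `scoring_args` stands in for a parsed
# argparse namespace and is an assumption, not part of the original code.
def _score_example(scoring_args):
    score_iter, source_vocabs, target_vocab, model_config = get_data_iters_and_vocabs(
        args=scoring_args, model_folder=scoring_args.model)
    for batch in score_iter:
        pass  # e.g. hand each batch to the scoring model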
def inspect(args):
    # Interactively inspect a top-k lexicon: for each line read from STDIN, print the
    # source token ids and the target tokens the lexicon would allow for them.
    setup_main_logger(console=True, file_logging=False)
    global logger
    logger = logging.getLogger('inspect')
    log_sockeye_version(logger)
    logger.info("Inspecting top-k lexicon at \"%s\"", args.lexicon)
    vocab_source = vocab.load_source_vocabs(args.model)[0]
    vocab_target = vocab.vocab_from_json(os.path.join(args.model, C.VOCAB_TRG_NAME))
    vocab_target_inv = vocab.reverse_vocab(vocab_target)
    lexicon = TopKLexicon(vocab_source, vocab_target)
    lexicon.load(args.lexicon, args.k)
    logger.info("Reading from STDIN...")
    for line in sys.stdin:
        tokens = list(get_tokens(line))
        if not tokens:
            continue
        ids = tokens2ids(tokens, vocab_source)
        print("Input: n=%d" % len(tokens), " ".join("%s(%d)" % (tok, i) for tok, i in zip(tokens, ids)))
        trg_ids = lexicon.get_trg_ids(np.array(ids))
        tokens_trg = [vocab_target_inv.get(trg_id, C.UNK_SYMBOL) for trg_id in trg_ids]
        print("Output: n=%d" % len(tokens_trg), " ".join("%s(%d)" % (tok, i) for tok, i in zip(tokens_trg, trg_ids)))
        print()
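
# tokens2ids, used by inspect() above, maps tokens to vocabulary ids. A minimal sketch
# assuming Sockeye's convention of falling back to the <unk> id for unknown tokens:
def tokens2ids_sketch(tokens, vocabulary):
    return [vocabulary.get(token, vocabulary[C.UNK_SYMBOL]) for token in tokens]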
# A variant of the prepared-data call above that also passes fill_up and resolves the
# validation target to an absolute path.
train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters(
    prepared_data_dir=args.prepared_data,
    validation_sources=validation_sources,
    validation_target=str(os.path.abspath(args.validation_target)),
    shared_vocab=shared_vocab,
    batch_size=args.batch_size,
    batch_by_words=batch_by_words,
    batch_num_devices=batch_num_devices,
    fill_up=args.fill_up)

check_condition(len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                    len(source_vocabs), len(args.source_factors_num_embed) + 1))
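
# vocab.are_identical, used by the resume-training checks above, treats vocabularies as
# mappings and compares their item sets. A minimal, illustrative sketch:
def are_identical_sketch(*vocabs):
    assert len(vocabs) > 0, "At least one vocabulary needed."
    return all(set(v.items()) == set(vocabs[0].items()) for v in vocabs)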