train_max_length = 30
dev_line_count = 20
dev_max_length = 30
expected_mean = 1.0
expected_std = 0.0
test_line_count = 20
test_line_count_empty = 0
test_max_length = 30
batch_size = 5
with tmp_digits_dataset("tmp_corpus",
train_line_count, train_line_count_empty, train_max_length - C.SPACE_FOR_XOS,
dev_line_count, dev_max_length - C.SPACE_FOR_XOS,
test_line_count, test_line_count_empty,
test_max_length - C.SPACE_FOR_XOS) as data:
# Build a temporary vocabulary shared between source and target
vcb = vocab.build_from_paths([data['train_source'], data['train_target']])
train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
sources=[data['train_source']],
target=data['train_target'],
validation_sources=[data['dev_source']],
validation_target=data['dev_target'],
source_vocabs=[vcb],
target_vocab=vcb,
source_vocab_paths=[None],
target_vocab_path=None,
shared_vocab=True,
batch_size=batch_size,
batch_by_words=False,
batch_num_devices=1,
max_seq_len_source=train_max_length,
max_seq_len_target=train_max_length,
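# A minimal, self-contained sketch of the shared-vocabulary step above (the test snippet
# assumes train_line_count / train_line_count_empty are defined earlier in the original
# file). Assumption: sockeye.vocab.build_from_paths takes a list of token files and
# returns a token-to-id dict that already contains the special PAD/UNK/BOS/EOS symbols.
import os
import tempfile
from sockeye import vocab

with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "train.src")
    trg = os.path.join(tmp, "train.trg")
    with open(src, "w") as f:
        f.write("1 2 3 4\n5 6 7\n")
    with open(trg, "w") as f:
        f.write("4 3 2 1\n7 6 5\n")
    # One vocabulary built over both sides can be shared between source and target.
    shared_vcb = vocab.build_from_paths([src, trg])
    print(len(shared_vcb), shared_vcb["7"])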
def __init__(self,
logdir: str,
source_vocab: Optional[vocab.Vocab] = None,
target_vocab: Optional[vocab.Vocab] = None) -> None:
self.logdir = logdir
self.source_labels = vocab.get_ordered_tokens_from_vocab(source_vocab) if source_vocab is not None else None
self.target_labels = vocab.get_ordered_tokens_from_vocab(target_vocab) if target_vocab is not None else None
try:
import mxboard
logger.info("Logging training events for Tensorboard at '%s'", self.logdir)
self.sw = mxboard.SummaryWriter(logdir=self.logdir, flush_secs=60, verbose=False)
except ImportError:
logger.info("mxboard not found. Consider 'pip install mxboard' to log events to Tensorboard.")
self.sw = None
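# Hedged sketch of vocab.get_ordered_tokens_from_vocab as used above: the assumption is
# that it returns the vocabulary's tokens sorted by id, which is what embedding label
# lists require. The toy vocabulary below is illustrative.
from sockeye import constants as C
from sockeye import vocab

toy_vocab = {C.PAD_SYMBOL: 0, C.UNK_SYMBOL: 1, C.BOS_SYMBOL: 2, C.EOS_SYMBOL: 3, "hello": 4}
labels = vocab.get_ordered_tokens_from_vocab(toy_vocab)
assert labels[4] == "hello"  # position i holds the token whose id is i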
:param decoder_return_logit_inputs: Model decoders return inputs to logit computation instead of softmax over target
vocabulary. Used when logits/softmax are handled separately.
:param cache_output_layer_w_b: Models cache weights and biases for logit computation as NumPy arrays (used with
restrict lexicon).
:param source_image_size: Size to which input images are resized. Used only for image-text models.
:param forced_max_output_len: An optional override of the maximum output length.
:return: List of models, target vocabulary, source factor vocabularies.
"""
models = [] # type: List[ImageInferenceModel]
target_vocabs = [] # type: List[vocab.Vocab]
if checkpoints is None:
checkpoints = [None] * len(model_folders)
for model_folder, checkpoint in zip(model_folders, checkpoints):
target_vocabs.append(vocab.vocab_from_json(os.path.join(model_folder, C.VOCAB_TRG_NAME)))
model_version = utils.load_version(os.path.join(model_folder, C.VERSION_NAME))
logger.info("Model version: %s", model_version)
utils.check_version(model_version)
model_config = model.SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))
if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
else:
params_fname = os.path.join(model_folder, C.PARAMS_NAME % checkpoint)
inference_model = ImageInferenceModel(config=model_config,
params_fname=params_fname,
context=context,
beam_size=beam_size,
softmax_temperature=softmax_temperature,
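# Small sketch of the checkpoint-to-parameter-file selection above, assuming the Sockeye
# constants resemble PARAMS_BEST_NAME == "params.best" and PARAMS_NAME == "params.%05d".
import os
from typing import Optional

from sockeye import constants as C


def params_path(model_folder: str, checkpoint: Optional[int] = None) -> str:
    # With no explicit checkpoint, fall back to the best-scoring parameters.
    if checkpoint is None:
        return os.path.join(model_folder, C.PARAMS_BEST_NAME)
    return os.path.join(model_folder, C.PARAMS_NAME % checkpoint)


print(params_path("my_model"))       # e.g. my_model/params.best
print(params_path("my_model", 42))   # e.g. my_model/params.00042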
def __init__(self,
source_vocab: vocab.Vocab,
target_vocab: vocab.Vocab,
window_size: int = 20,
min_word_length: int = 2):
self.source_vocab = vocab.reverse_vocab(source_vocab)
self.target_vocab = vocab.reverse_vocab(target_vocab)
self.vocab_offset = len(target_vocab)
self.window_size = window_size
self.min_word_length = min_word_length
self.banned_words = [C.EOS_SYMBOL, C.BOS_SYMBOL, C.PAD_SYMBOL]
self.num_pointed = 0
self.num_total = 0
logger.info("Pointer networks: window_size=%d min_word_len=%d", self.window_size, self.min_word_length)
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab,
nbest_size: int = 1,
restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
avoid_list: Optional[str] = None,
store_beam: bool = False,
strip_unknown_words: bool = False,
skip_topk: bool = False,
sample: int = None) -> None:
self.context = context
self.length_penalty = length_penalty
self.beam_prune = beam_prune
self.beam_search_stop = beam_search_stop
self.source_vocabs = source_vocabs
self.vocab_target = target_vocab
self.vocab_target_inv = vocab.reverse_vocab(self.vocab_target)
self.restrict_lexicon = restrict_lexicon
self.store_beam = store_beam
self.start_id = self.vocab_target[C.BOS_SYMBOL]
assert C.PAD_ID == 0, "pad id should be 0"
self.stop_ids = {self.vocab_target[C.EOS_SYMBOL], C.PAD_ID} # type: Set[int]
self.strip_ids = self.stop_ids.copy() # ids to strip from the output
self.unk_id = self.vocab_target[C.UNK_SYMBOL]
if strip_unknown_words:
self.strip_ids.add(self.unk_id)
self.models = models
utils.check_condition(all(models[0].source_with_eos == m.source_with_eos for m in models),
"The source_with_eos property must match across models.")
self.source_with_eos = models[0].source_with_eos
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.nbest_size = nbest_size
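# Toy illustration of the stop/strip-id handling above: EOS and PAD are always stripped
# from translator output, and UNK is additionally stripped when strip_unknown_words is
# set. The vocabulary below is illustrative.
from sockeye import constants as C
from sockeye import vocab

vcb = {C.PAD_SYMBOL: 0, C.UNK_SYMBOL: 1, C.BOS_SYMBOL: 2, C.EOS_SYMBOL: 3, "a": 4, "b": 5}
inv = vocab.reverse_vocab(vcb)
strip_ids = {vcb[C.EOS_SYMBOL], C.PAD_ID}
strip_unknown_words = True
if strip_unknown_words:
    strip_ids.add(vcb[C.UNK_SYMBOL])
token_ids = [4, 1, 5, 3, 0]
print(" ".join(inv[i] for i in token_ids if i not in strip_ids))  # -> "a b"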
"size that is a multiple of %d." % len(context))
logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))
train_iter, eval_iter, config_data, source_vocabs, target_vocab = create_data_iters_and_vocabs(
args=args,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=use_shared_vocab(args),
resume_training=resume_training,
output_folder=output_folder)
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
# Dump the vocabularies if we're just starting up
if not resume_training:
vocab.save_source_vocabs(source_vocabs, output_folder)
vocab.save_target_vocab(target_vocab, output_folder)
source_vocab_sizes = [len(v) for v in source_vocabs]
target_vocab_size = len(target_vocab)
logger.info('Vocabulary sizes: source=[%s] target=%d',
'|'.join([str(size) for size in source_vocab_sizes]),
target_vocab_size)
model_config = create_model_config(args=args,
source_vocab_sizes=source_vocab_sizes, target_vocab_size=target_vocab_size,
max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target,
config_data=config_data)
model_config.freeze()
training_model = create_training_model(config=model_config,
context=context,
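# Hedged round-trip sketch for the vocabulary dump above, assuming the save/load helpers
# round-trip a plain token-to-id dict: write the source and target vocabularies into the
# output folder and read them back, as a resumed run does later. The toy vocabulary is
# illustrative.
import tempfile
from sockeye import constants as C
from sockeye import vocab

with tempfile.TemporaryDirectory() as output_folder:
    toy = {C.PAD_SYMBOL: 0, C.UNK_SYMBOL: 1, C.BOS_SYMBOL: 2, C.EOS_SYMBOL: 3, "7": 4}
    vocab.save_source_vocabs([toy], output_folder)
    vocab.save_target_vocab(toy, output_folder)
    assert vocab.load_target_vocab(output_folder) == toy
    assert len(vocab.load_source_vocabs(output_folder)) == 1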
logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))
train_iter, eval_iter, config_data, source_vocabs, target_vocab, edge_vocab = create_data_iters_and_vocabs(args=args,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=use_shared_vocab(args),
resume_training=resume_training,
output_folder=output_folder)
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
vocab_edge_size = len(edge_vocab)
# Dump the vocabularies if we're just starting up
if not resume_training:
vocab.save_source_vocabs(source_vocabs, output_folder)
vocab.save_target_vocab(target_vocab, output_folder)
source_vocab_sizes = [len(v) for v in source_vocabs]
target_vocab_size = len(target_vocab)
logger.info('Vocabulary sizes: source=[%s] target=%d',
'|'.join([str(size) for size in source_vocab_sizes]),
target_vocab_size)
model_config = create_model_config(args=args,
source_vocab_sizes=source_vocab_sizes, target_vocab_size=target_vocab_size,
edge_vocab_size=vocab_edge_size,
max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target,
config_data=config_data)
model_config.freeze()
training_model = create_training_model(config=model_config,
context=context,
_, num_words_target = args.num_words
num_words_target = num_words_target if num_words_target > 0 else None
_, word_min_count_target = args.word_min_count
batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
batch_by_words = args.batch_type == C.BATCH_TYPE_WORD
either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s or a preprocessed corpus " \
"with %s." % (C.TRAINING_ARG_TARGET,
C.TRAINING_ARG_PREPARED_DATA)
# Note: ignore args.prepared_data for the moment
utils.check_condition(args.prepared_data is None and args.target is not None,
either_raw_or_prepared_error_msg)
if resume_training:
# Load the existing vocab created when starting the training run.
target_vocab = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME))
# Recover the vocabulary path from the existing config file:
data_info = cast(data_io.DataInfo, Config.load(os.path.join(output_folder, C.DATA_INFO)))
target_vocab_path = data_info.target_vocab
else:
# Load vocab:
target_vocab_path = args.target_vocab
# Note: the source vocabulary is not used for images, which is why some inputs are mocked
target_vocab = vocab.load_or_create_vocab(data=args.target,
vocab_path=target_vocab_path,
num_words=num_words_target,
word_min_count=word_min_count_target)
train_iter, validation_iter, config_data, data_info = data_io_image.get_training_image_text_data_iters(
source_root=args.source_root,
source=os.path.abspath(args.source),
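# Hedged sketch of vocab.load_or_create_vocab as used above: if vocab_path points to an
# existing JSON vocabulary it is loaded, otherwise a fresh vocabulary is built from the
# raw data with the given size and frequency limits. The corpus file name is hypothetical.
from sockeye import vocab

target_vocab = vocab.load_or_create_vocab(data="train.captions.tokens",  # hypothetical corpus file
                                          vocab_path=None,               # None forces creation from data
                                          num_words=50000,
                                          word_min_count=1)
print("target vocabulary size:", len(target_vocab))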
config_file = os.path.join(prepared_data_dir, C.DATA_CONFIG)
check_condition(os.path.exists(config_file),
"Could not find data config %s. Are you sure %s is a directory created with "
"python -m sockeye.prepare_data?" % (config_file, prepared_data_dir))
config_data = cast(DataConfig, DataConfig.load(config_file))
shard_fnames = [os.path.join(prepared_data_dir,
C.SHARD_NAME % shard_idx) for shard_idx in range(data_info.num_shards)]
for shard_fname in shard_fnames:
check_condition(os.path.exists(shard_fname), "Shard %s does not exist." % shard_fname)
check_condition(shared_vocab == data_info.shared_vocab, "Shared vocabulary settings need to match those "
"of the prepared data (e.g. for weight tying). "
"Specify or omit %s consistently when training "
"and preparing the data." % C.VOCAB_ARG_SHARED_VOCAB)
source_vocabs = vocab.load_source_vocabs(prepared_data_dir)
target_vocab = vocab.load_target_vocab(prepared_data_dir)
check_condition(len(source_vocabs) == len(data_info.sources),
"Wrong number of source vocabularies. Found %d, need %d." % (len(source_vocabs),
len(data_info.sources)))
buckets = config_data.data_statistics.buckets
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
bucket_batch_sizes = define_bucket_batch_sizes(buckets,
batch_size,
batch_by_words,
batch_num_devices,
config_data.data_statistics.average_len_target_per_bucket)
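# Hedged sketch of reading the vocabularies back from a prepared-data directory (created
# by `python -m sockeye.prepare_data`); the directory name below is illustrative.
from sockeye import vocab

prepared_data_dir = "train_data.prepared"
source_vocabs = vocab.load_source_vocabs(prepared_data_dir)
target_vocab = vocab.load_target_vocab(prepared_data_dir)
print("source factor vocabularies:", len(source_vocabs))
print("target vocabulary size:", len(target_vocab))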