setup_main_logger(file_logging=True,
console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
utils.log_basic_info(args)
with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
json.dump(vars(args), fp)
max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)
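# The two lines above reserve room for the BOS/EOS markers: the user-facing
# --max-seq-len counts only real tokens, so a small constant (C.SPACE_FOR_XOS) is
# added before the data iterators are built. Minimal illustration below; the value 1
# and the 99:99 lengths are assumptions, not read from this snippet.
_space_for_xos = 1                                   # assumed value, stands in for C.SPACE_FOR_XOS
_example_source_len, _example_target_len = 99, 99    # e.g. --max-seq-len 99:99
_example_source_len += _space_for_xos                # -> 100 positions including the marker
_example_target_len += _space_for_xos                # -> 100 positions including the marker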
with ExitStack() as exit_stack:
context = utils.determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)
if args.batch_type == C.BATCH_TYPE_SENTENCE:
check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be "
"divisible by the number of devices. Choose a batch "
"size that is a multiple of %d." % len(context))
logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))
# Read feature size
if args.image_preextracted_features:
_, args.source_image_size = read_feature_shape(args.source_root)
train_iter, eval_iter, config_data, target_vocab = create_data_iters_and_vocab(
args=args,
:return: A CheckpointDecoder if --decode-and-evaluate != 0, else None.
"""
sample_size = args.decode_and_evaluate
if args.optimized_metric == C.BLEU and sample_size == 0:
logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
"To control how many validation sentences are used for calculating bleu use "
"the --decode-and-evaluate argument.")
sample_size = -1
if sample_size == 0:
return None
if args.use_cpu or args.decode_and_evaluate_use_cpu:
context = mx.cpu()
elif args.decode_and_evaluate_device_id is not None:
context = utils.determine_context(device_ids=[args.decode_and_evaluate_device_id],
use_cpu=False,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)[0]
else:
# default decode context is the last training device
context = train_context[-1]
return checkpoint_decoder.CheckpointDecoder(context=context,
inputs=[args.validation_source] + args.validation_source_factors,
references=args.validation_target,
model=args.output,
sample_size=sample_size)
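# The sample-size convention used by the snippet above, restated as a standalone
# sketch: 0 disables checkpoint decoding (unless BLEU is the optimized metric, in
# which case monitoring is forced on), and -1 appears to mean "decode the full
# validation set". The function name, the Optional return, and the literal "bleu"
# (standing in for C.BLEU) are illustrative assumptions.
from typing import Optional

def _resolve_sample_size(decode_and_evaluate: int, optimized_metric: str) -> Optional[int]:
    sample_size = decode_and_evaluate
    if optimized_metric == "bleu" and sample_size == 0:
        sample_size = -1  # force BLEU monitoring over the whole validation set
    return None if sample_size == 0 else sample_size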
:return: A CheckpointDecoder if --decode-and-evaluate != 0, else None.
"""
sample_size = args.decode_and_evaluate
if args.optimized_metric == C.BLEU and sample_size == 0:
logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
"To control how many validation sentences are used for calculating bleu use "
"the --decode-and-evaluate argument.")
sample_size = -1
if sample_size == 0:
return None
if args.use_cpu or args.decode_and_evaluate_use_cpu:
context = mx.cpu()
elif args.decode_and_evaluate_device_id is not None:
context = utils.determine_context(device_ids=[args.decode_and_evaluate_device_id],  # determine_context expects a list of device ids
use_cpu=False,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)[0]
else:
# default decode context is the last training device
context = train_context[-1]
return checkpoint_decoder.CheckpointDecoder(context=context,
inputs=[args.validation_source] + args.validation_source_factors,
graph=args.val_source_graphs,
references=args.validation_target,
model=args.output,
edge_vocab=edge_vocab,
sample_size=sample_size)
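# The snippets above all call utils.determine_context the same way: it takes a list
# of device ids (plus CPU/locking options and an ExitStack used to release GPU locks)
# and returns a list of contexts, which is why single-device callers index [0].
# A hedged usage sketch; the device id and lock_dir below are placeholders.
from contextlib import ExitStack
from sockeye import utils

with ExitStack() as exit_stack:
    context = utils.determine_context(device_ids=[0],
                                      use_cpu=False,
                                      disable_device_locking=False,
                                      lock_dir="/tmp",
                                      exit_stack=exit_stack)[0]  # first (and only) acquired device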
:return: A CheckpointDecoder if --decode-and-evaluate != 0, else None.
"""
sample_size = args.decode_and_evaluate
if args.optimized_metric == C.BLEU and sample_size == 0:
logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
"To control how many validation sentences are used for calculating bleu use "
"the --decode-and-evaluate argument.")
sample_size = -1
if sample_size == 0:
return None
if args.use_cpu or args.decode_and_evaluate_use_cpu:
context = mx.cpu()
elif args.decode_and_evaluate_device_id is not None:
context = utils.determine_context(device_ids=[args.decode_and_evaluate_device_id],  # determine_context expects a list of device ids
use_cpu=False,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)[0]
else:
# default decode context is the last training device
context = train_context[-1]
return checkpoint_decoder.CheckpointDecoderImageModel(context=context,
inputs=[args.validation_source] + args.validation_source_factors,
references=args.validation_target,
model=args.output,
sample_size=sample_size,
source_image_size=args.source_image_size,
image_root=args.validation_source_root,
max_output_length=args.max_output_length,
logger.warning("'--checkpoint-frequency' is deprecated, and will be removed in the future. Please use '--checkpoint-interval'")
utils.log_basic_info(args)
arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))
max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)
check_condition(args.length_task is not None or C.LENRATIO_MSE not in args.metrics,
"%s metrics requires enabling length ratio prediction with --length-task." % C.LENRATIO_MSE)
with ExitStack() as exit_stack:
context = utils.determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)
if args.batch_type == C.BATCH_TYPE_SENTENCE:
check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be "
"divisible by the number of devices. Choose a batch "
"size that is a multiple of %d." % len(context))
logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))
train_iter, eval_iter, config_data, source_vocabs, target_vocab = create_data_iters_and_vocabs(
args=args,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=use_shared_vocab(args),
resume_training=resume_training,
setup_main_logger(file_logging=False, level=args.loglevel)
log_basic_info(args)
if args.nbest_size > 1:
if args.output_type != C.OUTPUT_HANDLER_JSON:
logger.warning("For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
C.OUTPUT_HANDLER_JSON, args.output_type)
args.output_type = C.OUTPUT_HANDLER_JSON
output_handler = get_output_handler(args.output_type,
args.output,
args.sure_align_threshold)
with ExitStack() as exit_stack:
check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
context = determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)[0]
logger.info("Translate Device: %s", context)
models, source_vocabs, target_vocab = inference.load_models(
context=context,
max_input_len=args.max_input_len,
beam_size=args.beam_size,
batch_size=args.batch_size,
model_folders=args.models,
checkpoints=args.checkpoints,
softmax_temperature=args.softmax_temperature,
max_output_length_num_stds=args.max_output_length_num_stds,
decoder_return_logit_inputs=args.restrict_lexicon is not None,
image_root = os.path.abspath(args.image_root)
output_root = os.path.abspath(args.output_root)
output_file = os.path.abspath(args.output)
size_out_file = os.path.join(output_root, "image_feature_sizes.pkl")
if os.path.exists(output_root):
logger.info("Overwriting provided path {}.".format(output_root))
else:
os.makedirs(output_root)
# read image list file
image_list = read_list_file(args.input)
# Get pretrained net module (already bound)
with ExitStack() as exit_stack:
check_condition(len(args.device_ids) == 1, "extract_features only supports single device for now")
context = determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
exit_stack=exit_stack)[0]
module, _ = get_pretrained_net(args, context)
# Extract features
with open(output_file, "w") as fout:
for i, im in enumerate(batching(image_list, args.batch_size)):
logger.info("Processing batch {}/{}".format(i + 1, int(np.ceil(len(image_list) / args.batch_size))))
# TODO: enable caching to reuse features and resume computation
feats, out_names = extract_features_forward(im, module,
image_root,
output_root,
args.batch_size,
args.source_image_size,
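# The batching() helper used in the loop above is not shown in this snippet. Below is
# a minimal generator with the behaviour the loop relies on (consecutive chunks of at
# most batch_size items); this is an assumption about its contract, not the actual helper.
from typing import Iterator, List, Sequence

def _batching(items: Sequence, batch_size: int) -> Iterator[List]:
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])

# e.g. list(_batching(["a.jpg", "b.jpg", "c.jpg"], 2)) -> [["a.jpg", "b.jpg"], ["c.jpg"]]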