Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_io_args(test_params, expected_params):
    """
    Run the shared argument-parsing check against the training I/O option group.

    NOTE(review): presumably `_test_args` builds a parser with the given registrar,
    parses `test_params` and compares the result with `expected_params` -- confirm.
    The two parameters are presumably supplied by a test parametrization outside this view.
    """
    register_io_args = arguments.add_training_io_args
    _test_args(test_params, expected_params, register_io_args)
Check if we should resume a broken training run.
:param args: Arguments as returned by argparse.
:param output_folder: Main output folder for the model.
:return: Flag signaling if we are resuming training and the directory with
the training status.
"""
# NOTE(review): these lines are the body of a function whose `def` line and
# opening docstring quotes are not visible here, and no `return resume_training`
# appears in the visible lines -- confirm against the complete file.
# Default to a fresh run; only flipped once a saved state is judged compatible.
resume_training = False
# Directory inside the output folder where an earlier run keeps its training state.
training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
if os.path.exists(output_folder):
# Overwrite requested: delete the existing output folder and recreate it empty.
if args.overwrite_output:
logger.info("Removing existing output folder %s.", output_folder)
shutil.rmtree(output_folder)
os.makedirs(output_folder)
# A saved training state exists: compare the current arguments with the ones
# stored by the earlier run (loaded from C.ARGS_STATE_NAME in the output folder).
elif os.path.exists(training_state_dir):
old_args = vars(arguments.load_args(os.path.join(output_folder, C.ARGS_STATE_NAME)))
# NOTE(review): presumably `_dict_difference(a, b)` yields the keys of `a` whose
# value is missing from or different in `b`; the union of both directions then
# also catches keys present on only one side -- confirm against its definition.
arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
# Remove args that may differ without affecting the training.
arg_diffs -= set(C.ARGS_MAY_DIFFER)
# allow different device-ids provided their total count is the same
if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
arg_diffs.discard('device_ids')
# Only resume when no relevant argument changed.
if not arg_diffs:
resume_training = True
else:
# NOTE(review): an earlier comment here read "We do not have the logger yet",
# but `logger` is used directly below; presumably only file logging is not set
# up at this point -- confirm. Any remaining difference aborts with exit status 1.
logger.error("Mismatch in arguments for training continuation.")
logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
sys.exit(1)
# No training state, but a best-parameters file exists: refuse to clobber what
# looks like a trained model and exit with status 1.
elif os.path.exists(os.path.join(output_folder, C.PARAMS_BEST_NAME)):
logger.error("Refusing to overwrite model folder %s as it seems to contain a trained model.", output_folder)
sys.exit(1)
def main():
    """Training CLI entry point: build the parser, parse the arguments and start training."""
    cli = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(cli)
    parsed = cli.parse_args()
    train(parsed)
def main():
    """Vocabulary-building CLI: register vocab and logging options, parse them, run `prepare_vocab`."""
    # Function-local import. NOTE(review): presumably done to avoid an import cycle -- confirm.
    from . import arguments

    parser = argparse.ArgumentParser(description='CLI to build source and target vocab(s).')
    for register in (arguments.add_build_vocab_args, arguments.add_logging_args):
        register(parser)
    prepare_vocab(parser.parse_args())
def main():
    """Parse the Sockeye training command line and hand the result to `train`."""
    parser = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(parser)
    train(parser.parse_args())
def main():
    """
    Command-line entry point for extracting selected parameters.

    Configures console-only logging (no log file), parses the extraction options
    and passes them to `extract_parameters`.
    """
    setup_main_logger(console=True, file_logging=False)
    parser = argparse.ArgumentParser(description="Extract specific parameters.")
    arguments.add_extract_args(parser)
    extract_parameters(parser.parse_args())
def main():
    """Entry point for training: register the training options, parse them, then call `train`."""
    train_parser = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    register_train_args = arguments.add_train_cli_args
    register_train_args(train_parser)
    namespace = train_parser.parse_args()
    train(namespace)
def main():
    """Image-captioning CLI entry point: parse the captioning options and run `caption`."""
    parser = arguments.ConfigArgumentParser(description='Image Captioning CLI')
    arguments_image.add_image_caption_cli_args(parser)
    cli_args = parser.parse_args()
    caption(cli_args)
def main():
    """Data-preparation CLI entry point: parse the options and pass them to `prepare_data`."""
    parser = argparse.ArgumentParser(description='Preprocesses and shards training data.')
    arguments.add_prepare_data_cli_args(parser)
    prepare_data(parser.parse_args())
def main():
params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
'respect to a reference set. If multiple hypotheses files are given'
'the mean and standard deviation of the metrics are reported.')
arguments.add_evaluate_args(params)
arguments.add_logging_args(params)
args = params.parse_args()
setup_main_logger(file_logging=False)
if args.quiet:
logger.setLevel(logging.ERROR)
utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
log_sockeye_version(logger)
logger.info("Command: %s", " ".join(sys.argv))
logger.info("Arguments: %s", args)
references = [' '.join(e) for e in data_io.read_content(args.references)]
all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
if not args.not_strict: