How to use the sockeye.arguments module in sockeye

To help you get started, we’ve selected a few sockeye examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github awslabs / sockeye / test / unit / test_arguments.py View on Github external
def test_io_args(test_params, expected_params):
    """Run the shared ``_test_args`` check against ``arguments.add_training_io_args``.

    :param test_params: Forwarded unchanged as the first argument of ``_test_args``.
    :param expected_params: Forwarded unchanged as the second argument of ``_test_args``.
    """
    register_io_args = arguments.add_training_io_args
    _test_args(test_params, expected_params, register_io_args)
github awslabs / sockeye / sockeye / train.py View on Github external
Check if we should resume a broken training run.

    :param args: Arguments as returned by argparse.
    :param output_folder: Main output folder for the model.
    :return: Flag signaling if we are resuming training and the directory with
        the training status.
    """
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            old_args = vars(arguments.load_args(os.path.join(output_folder, C.ARGS_STATE_NAME)))
            arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # We do not have the logger yet
                logger.error("Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        elif os.path.exists(os.path.join(output_folder, C.PARAMS_BEST_NAME)):
            logger.error("Refusing to overwrite model folder %s as it seems to contain a trained model.", output_folder)
            sys.exit(1)
github awslabs / sockeye / sockeye / train.py View on Github external
def main():
    """
    Commandline entry point: register the training CLI arguments on a
    ``ConfigArgumentParser``, parse them and pass the result to ``train``.
    """
    parser = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(parser)
    train(parser.parse_args())
github awslabs / sockeye / sockeye / vocab.py View on Github external
def main():
    """
    Commandline entry point: register the vocabulary-building and logging
    arguments, parse them and pass the result to ``prepare_vocab``.
    """
    # Kept as a function-scope import, exactly as in the original.
    from . import arguments
    parser = argparse.ArgumentParser(description='CLI to build source and target vocab(s).')
    arguments.add_build_vocab_args(parser)
    arguments.add_logging_args(parser)
    cli_args = parser.parse_args()
    prepare_vocab(cli_args)
github Cartus / DCGCN / sockeye / train.py View on Github external
def main():
    """
    Commandline entry point: build the training argument parser, parse the
    arguments and call ``train`` with them.
    """
    cli_parser = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(cli_parser)
    parsed = cli_parser.parse_args()
    train(parsed)
github awslabs / sockeye / sockeye / extract_parameters.py View on Github external
def main():
    """
    Commandline interface to extract parameters.

    Sets up the main logger (console on, file logging off) before any
    argument handling, then parses the extraction arguments and passes
    them to ``extract_parameters``.
    """
    setup_main_logger(console=True, file_logging=False)
    parser = argparse.ArgumentParser(description="Extract specific parameters.")
    arguments.add_extract_args(parser)
    extract_parameters(parser.parse_args())
github awslabs / sockeye / sockeye / train.py View on Github external
def main():
    """
    Commandline entry point for training: add the train CLI arguments to a
    ``ConfigArgumentParser``, parse them and start ``train``.
    """
    arg_parser = arguments.ConfigArgumentParser(description='Train Sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(arg_parser)
    train(arg_parser.parse_args())
github awslabs / sockeye / sockeye / image_captioning / captioner.py View on Github external
def main():
    """
    Commandline entry point for image captioning: register the image caption
    CLI arguments, parse them and pass the result to ``caption``.
    """
    parser = arguments.ConfigArgumentParser(description='Image Captioning CLI')
    arguments_image.add_image_caption_cli_args(parser)
    caption(parser.parse_args())
github awslabs / sockeye / sockeye / prepare_data.py View on Github external
def main():
    """
    Commandline entry point for data preparation: register the prepare-data
    CLI arguments, parse them and pass the result to ``prepare_data``.
    """
    parser = argparse.ArgumentParser(description='Preprocesses and shards training data.')
    arguments.add_prepare_data_cli_args(parser)
    cli_args = parser.parse_args()
    prepare_data(cli_args)
github awslabs / sockeye / sockeye / evaluate.py View on Github external
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set. If multiple hypotheses files are given'
                                                 'the mean and standard deviation of the metrics are reported.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()
    setup_main_logger(file_logging=False)

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
    if not args.not_strict: