How to use the fairseq.options.parse_args_and_arch function in fairseq

To help you get started, we've selected a few fairseq examples based on popular ways parse_args_and_arch is used in public projects.

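parse_args_and_arch(parser, input_args=None, ...) first parses the generic fairseq command-line options, then adds the extra arguments registered by the selected --arch (and by the task, criterion, etc.) and re-parses, so the returned namespace carries both. A minimal sketch of the usual entry-point pattern, assuming fairseq is installed:

from fairseq import options

parser = options.get_training_parser()
# Invoked as e.g.: python train.py data-bin/my_data --arch transformer
args = options.parse_args_and_arch(parser)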

github freewym / espresso / multiprocessing_train.py

import os
import signal

from fairseq import options


class ErrorHandler(object):
    # (constructor storing self.error_queue and self.children_pids elided)

    def error_listener(self):
        # Re-queue the worker's traceback and wake the parent with SIGUSR1.
        (rank, original_trace) = self.error_queue.get()
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)  # kill children processes
        (rank, original_trace) = self.error_queue.get()
        msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
        msg += original_trace
        raise Exception(msg)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)  # main() is defined earlier in this script

github freewym / espresso / examples / noisychannel / rerank_tune.py

def cli_main():
    # rerank_options is a sibling module in this package; random_search is
    # defined elsewhere in this script. options is fairseq.options.
    parser = rerank_options.get_tuning_parser()
    args = options.parse_args_and_arch(parser)

    random_search(args)

github freewym / espresso / examples / noisychannel / rerank_utils.py

        # (The snippet begins mid-function: the opening elements of
        # preprocess_lm_param are truncated in the original excerpt.)
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_dir]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_dir,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--max-tokens", "1024",
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        # eval_lm prints word-level scores to stdout; capture them into
        # lm_score_file via contextlib.redirect_stdout.
        with open(lm_score_file, 'w') as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)
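
Note how this example passes an explicit list as the second argument: parse_args_and_arch(parser, input_args) parses that list instead of sys.argv, which is what lets the rerank scripts drive fairseq's preprocessing and language-model evaluation programmatically. A minimal standalone sketch with placeholder paths:

from fairseq import options

# Placeholder data directory and checkpoint; substitute your own.
eval_lm_param = ["data-bin/my_lm", "--path", "checkpoints/lm.pt",
                 "--max-tokens", "1024", "--gen-subset", "train"]
eval_lm_parser = options.get_eval_lm_parser()
eval_lm_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)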

github freewym / espresso / examples / noisychannel / rerank_generate.py

            # Build an argv-style list for the generation parser.
            param1 = [args.data,
                      "--path", args.gen_model,
                      "--shard-id", str(args.shard_id),
                      "--num-shards", str(args.num_shards),
                      "--nbest", str(args.num_rescore),
                      "--batch-size", str(args.batch_size),
                      "--beam", str(args.num_rescore),
                      "--max-sentences", str(args.num_rescore),
                      "--gen-subset", args.gen_subset,
                      "--source-lang", args.source_lang,
                      "--target-lang", args.target_lang]
            if args.sampling:
                param1 += ["--sampling"]

            gen_parser = options.get_generation_parser()
            input_args = options.parse_args_and_arch(gen_parser, param1)

            print(input_args)
            with open(predictions_bpe_file, 'w') as f:
                with redirect_stdout(f):
                    generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                  nbest=using_nbest, prefix_len=args.prefix_len,
                                                  target_prefix_frac=args.target_prefix_frac)

    if args.diff_bpe:
        rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                                       gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
                                       pre_gen+"/target_gen_bpe."+args.target_lang,
                                       pre_gen+"/reference_gen_bpe."+args.target_lang)
        bitext_bpe = args.rescore_bpe_code

github ecchochan / roberta-squad / fairseq_train_embed.py

import torch
from fairseq import distributed_utils, options
# distributed_main() is defined earlier in this script.


def cli_main():
    parser = options.get_training_parser()
    # --- arguments added on top of the stock fairseq training parser ---
    parser.add_argument('--lr_decay', default=1, type=float,
                        help='Learning rate decay factor, 1.0 = no decay')
    parser.add_argument('--lr_decay_layers', default=24, type=int,
                        help='Number of layers for learning rate decay')
    parser.add_argument('--freeze_transformer', dest='freeze_transformer', action='store_true',
                        help='Whether to freeze the weights in transformer')
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
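
The pattern above works because get_training_parser() returns an ordinary argparse parser: project-specific flags registered on it before the call survive into the namespace that parse_args_and_arch returns. A minimal sketch (--my-custom-flag is an invented example, not a fairseq option):

from fairseq import options

parser = options.get_training_parser()
parser.add_argument('--my-custom-flag', action='store_true',
                    help='project-specific switch, unknown to stock fairseq')
args = options.parse_args_and_arch(parser)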

github KelleyYin / Cross-lingual-Summarization / Teacher-Student / multiprocessing_train.py

(This excerpt is identical to the freewym/espresso multiprocessing_train.py example above: the same ErrorHandler methods, followed by the same get_training_parser() / parse_args_and_arch() entry point.)

github ecchochan / roberta-squad / fairseq_train2.py

# Same imports as fairseq_train_embed.py above.

def cli_main():
    parser = options.get_training_parser()
    # --- arguments added on top of the stock fairseq training parser ---
    parser.add_argument('--lr_decay', default=1, type=float,
                        help='Learning rate decay factor, 1.0 = no decay')
    parser.add_argument('--lr_decay_layers', default=24, type=int,
                        help='Number of layers for learning rate decay')
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)

github krasserm / fairseq-image-captioning / demo.py
def cli_main():
    parser = options.get_generation_parser(interactive=True, default_task='captioning')
    args = options.parse_args_and_arch(parser)
    main(args)
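
Here get_generation_parser(interactive=True) builds the parser used for interactive, stdin-driven generation, and default_task='captioning' pre-selects the custom task this project registers, so --task can be omitted on the command line.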

github StillKeepTry / Transformer-PyTorch / train.py

# distributed_main and multiprocessing_main come from sibling modules;
# their imports are truncated in this excerpt.
from singleprocess_train import main as singleprocess_main


def main(args):
    # Choose an entry point from the parsed args: explicit distributed init,
    # multi-GPU spawning, or a plain single-process run.
    if args.distributed_port > 0 \
            or args.distributed_init_method is not None:
        distributed_main(args)
    elif args.distributed_world_size > 1:
        multiprocessing_main(args)
    else:
        singleprocess_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)