How to use the fairseq.models.register_model_architecture function in fairseq

To help you get started, we've selected a few fairseq examples, based on popular ways the function is used in public projects. register_model_architecture(model_name, arch_name) is a decorator that registers a named set of default hyperparameters for an already-registered model; the decorated function receives the parsed args and fills in any attributes the user has not set.

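Before diving into the examples, here is a minimal sketch of the pattern they all share. The model name ('my_transformer') and architecture name ('my_transformer_tiny') below are hypothetical placeholders; substitute the names of your own registered model:

from fairseq.models import register_model_architecture

# Hypothetical: registers the architecture 'my_transformer_tiny' for an
# already-registered model named 'my_transformer'. The architecture name
# becomes a valid value for the --arch command-line flag.
@register_model_architecture('my_transformer', 'my_transformer_tiny')
def my_transformer_tiny(args):
    # getattr(args, name, default) keeps any value the user already set
    # (e.g. on the command line) and only fills in the default otherwise.
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 64)
    args.encoder_layers = getattr(args, 'encoder_layers', 2)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 64)
    args.decoder_layers = getattr(args, 'decoder_layers', 2)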

From zhawe01/fairseq-gec (fairseq/models/transformer.py):
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
    args.dropout = getattr(args, 'dropout', 0.1)
    transformer_vaswani_wmt_en_de_big(args)
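
Note the layering pattern here: the architecture overrides a single default (dropout) and then delegates to transformer_vaswani_wmt_en_de_big for everything else, which keeps related architectures consistent. Several of the examples below use the same trick.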

From freewym/espresso (fairseq/models/masked_lm.py):
@register_model_architecture('masked_lm', 'xlm_base')
def xlm_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.share_encoder_input_output_embed = getattr(
        args, 'share_encoder_input_output_embed', True)
    args.no_token_positional_embeddings = getattr(
        args, 'no_token_positional_embeddings', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
    args.num_segment = getattr(args, 'num_segment', 1)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
    args.bias_kv = getattr(args, 'bias_kv', False)
    args.zero_attn = getattr(args, 'zero_attn', False)

From freewym/espresso (fairseq/models/nat/levenshtein_transformer.py):
@register_model_architecture(
    "levenshtein_transformer", "levenshtein_transformer_wmt_en_de"
)
def levenshtein_transformer_wmt_en_de(args):
    levenshtein_base_architecture(args)

From freewym/espresso (espresso/models/speech_lstm.py):
@register_model_architecture('lstm_lm', 'lstm_lm_wsj')
def lstm_lm_wsj(args):
    base_lm_architecture(args)

From StillKeepTry/Transformer-PyTorch (fairseq/models/fconv.py):
@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
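
The string-valued defaults here ('[(512, 3)] * 20', 'True') are intentional: the fconv model evaluates these strings when the model is built, turning encoder_layers and decoder_layers into lists of (out_channels, kernel_width) tuples.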

From freewym/espresso (fairseq/models/nat/nonautoregressive_transformer.py):
@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)

From freewym/espresso (espresso/models/speech_transformer.py):
@register_model_architecture('speech_transformer', 'speech_transformer')
def base_architecture(args):
    args.encoder_conv_channels = getattr(
        args, 'encoder_conv_channels', '[64, 64, 128, 128]',
    )
    args.encoder_conv_kernel_sizes = getattr(
        args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]',
    )
    args.encoder_conv_strides = getattr(
        args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]',
    )
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)

From krasserm/fairseq-image-captioning (model/caption.py):
@register_model_architecture('default-captioning-model', 'default-captioning-arch')
def default_captioning_arch(args):
    args.encoder_layers = getattr(args, 'encoder_layers', 3)

From StillKeepTry/Transformer-PyTorch (fairseq/models/fconv.py):
@register_model_architecture('fconv', 'fconv_wmt_en_de')
def fconv_wmt_en_de(args):
    base_architecture(args)
    convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
    convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
    convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
    args.encoder_embed_dim = 768
    args.encoder_layers = convs
    args.decoder_embed_dim = 768
    args.decoder_layers = convs
    args.decoder_out_embed_dim = 512
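
Unlike the getattr-based defaults in base_architecture, this version assigns the values unconditionally after calling it, so command-line overrides for these attributes are clobbered; current upstream fairseq wraps these assignments in getattr as well.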

From freewym/espresso (fairseq/models/lightconv_lm.py):
@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
def lightconv_lm_gbw(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    base_lm_architecture(args)
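
Once registered, any of these architecture names can be selected at training time via the --arch flag. A hypothetical invocation (the data path is a placeholder):

fairseq-train data-bin/my_dataset --task language_modeling --arch lightconv_lm_gbw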