"""
# (batch, length, num_hidden * 2)
source_hidden = mx.sym.FullyConnected(data=source,
weight=self.s2h_weight,
bias=self.s2h_bias,
num_hidden=self.num_hidden * 2,
flatten=False,
name="%ssource_hidden_fc" % self.prefix)
# split keys and values
# (batch, length, num_hidden)
# pylint: disable=unbalanced-tuple-unpacking
keys, values = mx.sym.split(data=source_hidden, num_outputs=2, axis=2)
# (batch*heads, length, num_hidden/head)
keys = layers.split_heads(keys, self.num_hidden_per_head, self.heads)
values = layers.split_heads(values, self.num_hidden_per_head, self.heads)
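# Illustrative sketch, not part of the original source: judging from the shape
# comments above, layers.split_heads folds the head dimension into the batch
# dimension via reshape + transpose. A NumPy equivalent under that assumption
# (split_heads_np is a hypothetical name):
import numpy as np

def split_heads_np(x: np.ndarray, depth_per_head: int, heads: int) -> np.ndarray:
    # (batch, length, heads * depth_per_head) -> (batch * heads, length, depth_per_head)
    batch, length, _ = x.shape
    x = x.reshape(batch, length, heads, depth_per_head)  # (batch, length, heads, depth/head)
    x = x.transpose(0, 2, 1, 3)                          # (batch, heads, length, depth/head)
    return x.reshape(batch * heads, length, depth_per_head)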
def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState:
"""
Returns updated attention state given attention input and current attention state.
:param att_input: Attention input as returned by make_input().
:param att_state: Current attention state
:return: Updated attention state.
"""
# (batch, num_hidden)
query = mx.sym.FullyConnected(data=att_input.query,
weight=self.t2h_weight, bias=self.t2h_bias,
num_hidden=self.num_hidden, name="%squery_hidden_fc" % self.prefix)
# (batch, length, heads, num_hidden/head)
query = mx.sym.reshape(query, shape=(0, 1, self.heads, self.num_hidden_per_head))
# (batch, heads, num_hidden/head, length)
def _create_layer_configs(default_kwargs_filler, parsed_layers: List[Dict]) -> Tuple[List[layers.LayerConfig], bool]:
source_attention_present = False
layer_configs = []
for layer in parsed_layers:
name = layer['name']
# TODO: can we simplify this? Maybe have LayerConfigs register themselves
if name == 'ff':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, layers.FeedForwardLayerConfig, layer['params']))
elif name == 'linear':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, layers.LinearLayerConfig, layer['params']))
elif name == 'id':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, layers.IdentityLayerConfig, layer['params']))
elif name == 'mh_dot_att':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, layers.MultiHeadSourceAttentionLayerConfig, layer['params']))
source_attention_present = True
elif name == 'mh_dot_self_att':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, layers.MultiHeadSelfAttentionLayerConfig, layer['params']))
elif name == 'cnn':
layer_configs.append(_fill_and_create(default_kwargs_filler,
name, convolution.ConvolutionalLayerConfig, layer['params']))
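# Sketch, not part of the original source, addressing the TODO above: a simple
# name -> config-class registry would let new layer types register themselves
# instead of growing this if/elif chain. It reuses the layers/convolution
# modules already imported here; _LAYER_CONFIG_REGISTRY and _lookup_layer_config
# are hypothetical names.
_LAYER_CONFIG_REGISTRY = {
    'ff': layers.FeedForwardLayerConfig,
    'linear': layers.LinearLayerConfig,
    'id': layers.IdentityLayerConfig,
    'mh_dot_att': layers.MultiHeadSourceAttentionLayerConfig,
    'mh_dot_self_att': layers.MultiHeadSelfAttentionLayerConfig,
    'cnn': convolution.ConvolutionalLayerConfig,
}

def _lookup_layer_config(name: str):
    # Fail loudly on unknown layer names instead of silently skipping them.
    if name not in _LAYER_CONFIG_REGISTRY:
        raise ValueError("Unknown layer name '%s'" % name)
    return _LAYER_CONFIG_REGISTRY[name]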
self.pre_ff = TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
prefix="ff_pre_")
self.ff = TransformerFeedForward(num_hidden=config.feed_forward_num_hidden,
num_model=config.model_size,
act_type=config.act_type,
dropout=config.dropout_act,
prefix="ff_")
self.post_ff = TransformerProcessBlock(sequence=config.postprocess_sequence,
dropout=config.dropout_prepost,
prefix="ff_post_")
self.lhuc = None
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size)
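# Sketch, not part of the original source: the pre-/post-process blocks above
# are meant to wrap a sublayer as "pre-process -> sublayer -> post-process with
# residual", optionally followed by LHUC scaling. A minimal standalone
# illustration of that wiring (apply_sublayer and its (data, prev) calling
# convention are assumptions mirroring the process-block usage):
from typing import Callable, Optional

def apply_sublayer(data,
                   pre: Callable,
                   sublayer: Callable,
                   post: Callable,
                   lhuc: Optional[Callable] = None):
    out = sublayer(pre(data, None))  # pre-processing (e.g. layer norm), then the sublayer itself
    out = post(out, data)            # post-processing combines the sublayer output with the residual input
    return lhuc(out) if lhuc is not None else out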
num_embed=config.num_embed,
max_seq_len=config.max_seq_len_target,
fixed_pos_embed_scale_up_input=False,
fixed_pos_embed_scale_down_positions=True,
prefix=C.TARGET_POSITIONAL_EMBEDDING_PREFIX)
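# Sketch, not part of the original source: the fixed positional embeddings
# configured above are presumably sinusoidal encodings, with
# fixed_pos_embed_scale_down_positions scaling the table by num_embed ** -0.5
# instead of scaling the input up. A NumPy sketch of such a table under those
# assumptions (sinusoidal_table is a hypothetical name):
import numpy as np

def sinusoidal_table(max_seq_len: int, num_embed: int, scale_down_positions: bool = True) -> np.ndarray:
    positions = np.arange(max_seq_len)[:, None]                       # (max_seq_len, 1)
    channels = np.arange(num_embed // 2)[None, :]                     # (1, num_embed / 2)
    angles = positions / np.power(10000.0, 2 * channels / num_embed)  # (max_seq_len, num_embed / 2)
    table = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (max_seq_len, num_embed)
    return table * (num_embed ** -0.5) if scale_down_positions else table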
self.layers = [convolution.ConvolutionBlock(
config.cnn_config,
pad_type='left',
prefix="%s%d_" % (prefix, i)) for i in range(config.num_layers)]
if self.config.project_qkv:
self.attention_layers = [layers.ProjectedDotAttention("%s%d_" % (prefix, i),
self.config.cnn_config.num_hidden)
for i in range(config.num_layers)]
else:
self.attention_layers = [layers.PlainDotAttention() for _ in range(config.num_layers)] # type: ignore
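# Note (assumption, not from the original source): ProjectedDotAttention
# presumably applies learned linear projections to queries, keys and values
# before the scaled dot product, while PlainDotAttention attends over the
# convolutional states directly; config.project_qkv toggles that extra
# parameterization.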
self.i2h_weight = mx.sym.Variable('%si2h_weight' % prefix)
# source & target embeddings
embed_weight_source, embed_weight_target, out_weight_target = self._get_embed_weights(self.prefix)
if isinstance(self.config.config_embed_source, encoder.PassThroughEmbeddingConfig):
self.embedding_source = encoder.PassThroughEmbedding(self.config.config_embed_source) # type: encoder.Encoder
else:
self.embedding_source = encoder.Embedding(self.config.config_embed_source,
prefix=self.prefix + C.SOURCE_EMBEDDING_PREFIX,
embed_weight=embed_weight_source,
is_source=True) # type: encoder.Encoder
self.embedding_target = encoder.Embedding(self.config.config_embed_target,
prefix=self.prefix + C.TARGET_EMBEDDING_PREFIX,
embed_weight=embed_weight_target)
# output layer
self.output_layer = layers.OutputLayer(hidden_size=self.decoder.get_num_hidden(),
vocab_size=self.config.vocab_target_size - self.config.num_pointers,
weight=out_weight_target,
weight_normalization=self.config.weight_normalization,
prefix=self.prefix + C.DEFAULT_OUTPUT_LAYER_PREFIX)
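# Sketch, not part of the original source: the output layer above is an affine
# map from decoder hidden states to target-vocabulary logits (optionally with
# weight normalization and a weight shared with the target embedding). A NumPy
# rendering of the core step, assuming a (vocab_size, hidden_size) weight as
# suggested by out_weight_target (output_logits is a hypothetical name):
import numpy as np

def output_logits(hidden: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # hidden: (batch, hidden_size), weight: (vocab_size, hidden_size), bias: (vocab_size,)
    return hidden @ weight.T + bias  # (batch, vocab_size)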
# create length ratio prediction layer(s)
self.length_ratio = None
if self.config.config_length_task is not None:
if self.config.config_length_task.weight > 0.0:
self.length_ratio = layers.LengthRatio(hidden_size=self.encoder.get_num_hidden(),
num_layers=self.config.config_length_task.num_layers,
prefix=self.prefix + C.LENRATIOS_OUTPUT_LAYER_PREFIX)
else:
logger.warning("Auxiliary length task requested, but its loss weight is zero -- this will have no effect.")
self.params = None # type: Optional[Dict]
def __init__(self,
config: TransformerConfig,
prefix: str) -> None:
super().__init__(prefix=prefix)
with self.name_scope():
self.pre_self_attention = TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
prefix="att_self_pre_")
self.self_attention = layers.MultiHeadSelfAttention(depth_att=config.model_size,
heads=config.attention_heads,
depth_out=config.model_size,
dropout=config.dropout_attention,
prefix="att_self_")
self.post_self_attention = TransformerProcessBlock(sequence=config.postprocess_sequence,
dropout=config.dropout_prepost,
prefix="att_self_post_")
self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
prefix="att_enc_pre_")
self.enc_attention = layers.MultiHeadAttention(depth_att=config.model_size,
heads=config.attention_heads,
depth_out=config.model_size,
dropout=config.dropout_attention,
prefix="att_enc_")
prefix="att_self_post_")
self.pre_ff = TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
prefix="ff_pre_")
self.ff = TransformerFeedForward(num_hidden=config.feed_forward_num_hidden,
num_model=config.model_size,
act_type=config.act_type,
dropout=config.dropout_act,
prefix="ff_")
self.post_ff = TransformerProcessBlock(sequence=config.postprocess_sequence,
dropout=config.dropout_prepost,
prefix="ff_post_")
self.lhuc = None
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size)
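# Note (assumption, not from the original source): in the forward pass these
# blocks are conventionally chained as
#   target -> pre / self-attention (auto-regressive) / post
#          -> pre / encoder attention over source / post
#          -> pre / feed-forward / post -> optional LHUC,
# with each sublayer wrapped by its process blocks and a residual connection.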
# (batch, length, heads, num_hidden/head)
query = mx.sym.reshape(query, shape=(0, 1, self.heads, self.num_hidden_per_head))
# (batch, heads, num_hidden/head, length)
query = mx.sym.transpose(query, axes=(0, 2, 3, 1))
# (batch * heads, num_hidden/head, 1)
query = mx.sym.reshape(query, shape=(-3, self.num_hidden_per_head, 1))
# scale dot product
query = query * (self.num_hidden_per_head ** -0.5)
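# Note (added): scaling the query by num_hidden_per_head ** -0.5 keeps the
# variance of the dot products roughly independent of the per-head width.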
# (batch*heads, length, num_hidden/head) X (batch*heads, num_hidden/head, 1)
# -> (batch*heads, length, 1)
attention_scores = mx.sym.batch_dot(lhs=keys, rhs=query, name="%sdot" % self.prefix)
# (batch*heads, 1)
lengths = layers.broadcast_to_heads(mx.sym, source_length, self.heads, ndim=1, fold_heads=True)
# context: (batch*heads, num_hidden/head)
# attention_probs: (batch*heads, length)
context, attention_probs = get_context_and_attention_probs(values, lengths, attention_scores, self.dtype)
# combine heads
# (batch*heads, 1, num_hidden/head)
context = mx.sym.expand_dims(context, axis=1)
# (batch, 1, num_hidden)
context = layers.combine_heads(mx.sym, context, self.num_hidden_per_head, heads=self.heads)
# (batch, num_hidden)
context = mx.sym.reshape(context, shape=(-3, -1))
# (batch, heads, length)
attention_probs = mx.sym.reshape(data=attention_probs, shape=(-4, -1, self.heads, source_seq_len))
# just average over distributions
attention_probs = mx.sym.mean(attention_probs, axis=1, keepdims=False)
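# Sketch, not part of the original source: get_context_and_attention_probs is
# used above to turn raw scores into a masked softmax over source positions and
# a weighted sum of the values. A NumPy rendering of that idea (the function
# name and exact masking scheme are assumptions):
import numpy as np

def masked_attention_np(values: np.ndarray, lengths: np.ndarray, scores: np.ndarray):
    # values: (batch*heads, length, depth), lengths: (batch*heads,), scores: (batch*heads, length, 1)
    scores = scores.squeeze(-1)                                       # (batch*heads, length)
    positions = np.arange(scores.shape[1])[None, :]                   # (1, length)
    scores = np.where(positions < lengths[:, None], scores, -np.inf)  # mask padded source positions
    scores = scores - scores.max(axis=1, keepdims=True)               # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)                         # (batch*heads, length)
    context = np.einsum('bl,bld->bd', probs, values)                  # (batch*heads, depth)
    return context, probs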
self.attention_num_hidden = num_hidden
# input (encoder) to hidden
self.att_e2h_weight = mx.sym.Variable("%se2h_weight" % self.prefix)
# input (query) to hidden
self.att_q2h_weight = mx.sym.Variable("%sq2h_weight" % self.prefix)
# hidden to score
self.att_h2s_weight = mx.sym.Variable("%sh2s_weight" % self.prefix)
# coverage
self.coverage = None # type: Optional[coverage.Coverage]
# dynamic source (coverage) weights and settings
# input (coverage) to hidden
self.att_c2h_weight = None
# layer normalization
self._ln = None
if layer_normalization:
self._ln = layers.LayerNormalization(prefix="%snorm" % self.prefix)
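# Sketch, not part of the original source: the weight names above suggest the
# usual additive (MLP) attention score, roughly
#   score(enc_t, query) = h2s^T . tanh(e2h @ enc_t + q2h @ query [+ c2h @ coverage_t]),
# optionally layer-normalized. A NumPy rendering under that assumption, without
# the coverage term (mlp_attention_scores is a hypothetical name):
import numpy as np

def mlp_attention_scores(enc: np.ndarray, query: np.ndarray,
                         e2h: np.ndarray, q2h: np.ndarray, h2s: np.ndarray) -> np.ndarray:
    # enc: (length, enc_dim), query: (dec_dim,), e2h: (hidden, enc_dim),
    # q2h: (hidden, dec_dim), h2s: (hidden,)
    hidden = np.tanh(enc @ e2h.T + query @ q2h.T)  # (length, hidden); query term broadcast over positions
    return hidden @ h2s                            # (length,) unnormalized attention scores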
init = mx.sym.broadcast_div(mx.sym.sum(source_masked, axis=1, keepdims=False),
mx.sym.expand_dims(source_encoded_length, axis=1))
else:
raise ValueError("Unknown decoder state init type '%s'" % self.config.state_init)
init = mx.sym.FullyConnected(data=init,
num_hidden=init_num_hidden,
weight=self.init_ws[state_idx],
bias=self.init_bs[state_idx],
name="%senc2decinit_%d" % (self.prefix, state_idx))
if self.config.layer_normalization:
init = self.init_norms[state_idx](init)
init = mx.sym.Activation(data=init, act_type="tanh",
name="%senc2dec_inittanh_%d" % (self.prefix, state_idx))
if self.config.state_init_lhuc:
lhuc = layers.LHUC(init_num_hidden, prefix="%senc2decinit_%d_" % (self.prefix, state_idx))
init = lhuc(init)
layer_states.append(init)
return RecurrentDecoderState(hidden, layer_states)
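# Sketch, not part of the original source: the broadcast_div/sum above computes
# the average encoder state over the source positions (padded positions are
# zeroed), which is then fed through the per-state FullyConnected + tanh to
# initialize each decoder layer. NumPy equivalent of the averaging step
# (average_encoder_states is a hypothetical name):
import numpy as np

def average_encoder_states(source_masked: np.ndarray, source_length: np.ndarray) -> np.ndarray:
    # source_masked: (batch, length, hidden) with padded positions zeroed,
    # source_length: (batch,) true source lengths
    return source_masked.sum(axis=1) / source_length[:, None]  # (batch, hidden)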