diff --git a/scripts/machine_translation/gnmt.py b/scripts/machine_translation/gnmt.py
index cf3c82aa13..7fb7d744d5 100644
--- a/scripts/machine_translation/gnmt.py
+++ b/scripts/machine_translation/gnmt.py
@@ -22,7 +22,8 @@
 from mxnet.gluon import nn, rnn
 from mxnet.gluon.block import HybridBlock
 from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \
-    Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last
+    Seq2SeqOneStepDecoder, _get_cell_type, _nested_sequence_last
+from gluonnlp.model.attention_cell import _get_attention_cell
 
 
 class GNMTEncoder(Seq2SeqEncoder):
diff --git a/src/gluonnlp/model/attention_cell.py b/src/gluonnlp/model/attention_cell.py
index 6701020f69..53d4a04b2a 100644
--- a/src/gluonnlp/model/attention_cell.py
+++ b/src/gluonnlp/model/attention_cell.py
@@ -89,8 +89,9 @@ class AttentionCell(HybridBlock):
         out = cell(query, key, value, mask)
 
     """
-    def __init__(self, prefix=None, params=None):
+    def __init__(self, unnormalized_score=False, prefix=None, params=None):
         self._dtype = np.float32
+        self._unnormalized_score = unnormalized_score
         super(AttentionCell, self).__init__(prefix=prefix, params=params)
 
     def cast(self, dtype):
@@ -117,6 +118,10 @@ def _compute_weight(self, F, query, key, mask=None):
         att_weights : Symbol or NDArray
             For single-head attention, Shape (batch_size, query_length, memory_length)
             For multi-head attention, Shape (batch_size, num_heads, query_length, memory_length)
+        att_score : Symbol or NDArray
+            The unnormalized attention scores.
+            For single-head attention, Shape (batch_size, query_length, memory_length)
+            For multi-head attention, Shape (batch_size, num_heads, query_length, memory_length)
         """
         raise NotImplementedError
 
@@ -172,9 +177,9 @@ def __call__(self, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
     def hybrid_forward(self, F, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
         if value is None:
             value = key
-        att_weights = self._compute_weight(F, query, key, mask)
+        att_weights, att_score = self._compute_weight(F, query, key, mask)
         context_vec = self._read_by_weight(F, att_weights, value)
-        return context_vec, att_weights
+        return context_vec, att_score if self._unnormalized_score else att_weights
 
 
 class MultiHeadAttentionCell(AttentionCell):
@@ -206,17 +211,22 @@ class MultiHeadAttentionCell(AttentionCell):
         Initializer of the weights.
     bias_initializer : str or `Initializer`, default 'zeros'
         Initializer of the bias.
+    unnormalized_score : bool, default False
+        Whether to return the unnormalized attention scores instead of the normalized weights.
     prefix : str or None, default None
         See document of `Block`.
     params : str or None, default None
         See document of `Block`.
""" def __init__(self, base_cell, query_units, key_units, value_units, num_heads, use_bias=True, - weight_initializer=None, bias_initializer='zeros', prefix=None, params=None): - super(MultiHeadAttentionCell, self).__init__(prefix=prefix, params=params) + weight_initializer=None, bias_initializer='zeros', + unnormalized_score=False, prefix=None, params=None): + super(MultiHeadAttentionCell, + self).__init__(unnormalized_score=unnormalized_score, prefix=prefix, params=params) self._base_cell = base_cell self._num_heads = num_heads self._use_bias = use_bias + self._unnormalized_score = unnormalized_score units = {'query': query_units, 'key': key_units, 'value': value_units} for name, unit in units.items(): if unit % self._num_heads != 0: @@ -275,8 +285,11 @@ def _compute_weight(self, F, query, key, mask=None): mask = F.broadcast_axis(F.expand_dims(mask, axis=1), axis=1, size=self._num_heads)\ .reshape(shape=(-1, 0, 0), reverse=True) - att_weights = self._base_cell._compute_weight(F, query, key, mask) - return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True) + att_weights, att_score = self._base_cell._compute_weight(F, query, key, mask) + return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), + reverse=True), att_score.reshape( + shape=(-1, self._num_heads, 0, 0), + reverse=True) def _read_by_weight(self, F, att_weights, value): att_weights = att_weights.reshape(shape=(-1, 0, 0), reverse=True) @@ -319,6 +332,8 @@ class MLPAttentionCell(AttentionCell): Initializer of the weights. bias_initializer : str or `Initializer`, default 'zeros' Initializer of the bias. + unnormalized_score: bool, default False + Whether to return an unnormalized weight matrix prefix : str or None, default None See document of `Block`. params : ParameterDict or None, default None @@ -326,7 +341,8 @@ class MLPAttentionCell(AttentionCell): """ def __init__(self, units, act=nn.Activation('tanh'), normalized=False, dropout=0.0, - weight_initializer=None, bias_initializer='zeros', prefix=None, params=None): + weight_initializer=None, bias_initializer='zeros', + unnormalized_score=False, prefix=None, params=None): # Define a temporary class to implement the normalized version # TODO(sxjscience) Find a better solution class _NormalizedScoreProj(HybridBlock): @@ -346,11 +362,15 @@ def hybrid_forward(self, F, x, g, v): # pylint: disable=arguments-differ flatten=False, name='fwd') return out - super(MLPAttentionCell, self).__init__(prefix=prefix, params=params) + super(MLPAttentionCell, + self).__init__(unnormalized_score=unnormalized_score, + prefix=prefix, + params=params) self._units = units self._act = act self._normalized = normalized self._dropout = dropout + self._unnormalized_score = unnormalized_score with self.name_scope(): self._dropout_layer = nn.Dropout(dropout) self._query_mid_layer = nn.Dense(units=self._units, flatten=False, use_bias=True, @@ -388,7 +408,7 @@ def _compute_weight(self, F, query, key, mask=None): if mask is not None: att_weights = att_weights * mask att_weights = self._dropout_layer(att_weights) - return att_weights + return att_weights, att_score class DotProductAttentionCell(AttentionCell): @@ -443,6 +463,8 @@ class DotProductAttentionCell(AttentionCell): Initializer of the weights bias_initializer : str or `Initializer`, default 'zeros' Initializer of the bias + unnormalized_score: bool, default False + Whether to return an unnormalized weight matrix prefix : str or None, default None See document of `Block`. 
     params : str or None, default None
@@ -450,14 +472,18 @@
     """
     def __init__(self, units=None, luong_style=False, scaled=True, normalized=False, use_bias=True,
                  dropout=0.0, weight_initializer=None, bias_initializer='zeros',
-                 prefix=None, params=None):
-        super(DotProductAttentionCell, self).__init__(prefix=prefix, params=params)
+                 unnormalized_score=False, prefix=None, params=None):
+        super(DotProductAttentionCell,
+              self).__init__(unnormalized_score=unnormalized_score,
+                             prefix=prefix,
+                             params=params)
         self._units = units
         self._scaled = scaled
         self._normalized = normalized
         self._use_bias = use_bias
         self._luong_style = luong_style
         self._dropout = dropout
+        self._unnormalized_score = unnormalized_score
         if self._luong_style:
             assert units is not None, 'Luong style attention is not available without explicitly ' \
                                       'setting the units'
@@ -503,4 +529,76 @@ def _compute_weight(self, F, query, key, mask=None):
         if mask is not None:
             att_weights = att_weights * mask
         att_weights = self._dropout_layer(att_weights)
-        return att_weights
+        return att_weights, att_score
+
+
+def _get_attention_cell(attention_cell,
+                        units=None,
+                        scaled=True,
+                        num_heads=None,
+                        use_bias=False,
+                        dropout=0.0,
+                        unnormalized_score=False):
+    """Construct an AttentionCell from a type string, or pass an existing AttentionCell through.
+
+    Parameters
+    ----------
+    attention_cell : AttentionCell or str
+    units : int or None
+
+    Returns
+    -------
+    attention_cell : AttentionCell
+    """
+    if isinstance(attention_cell, str):
+        if attention_cell == 'scaled_luong':
+            return DotProductAttentionCell(units=units,
+                                           scaled=True,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=True)
+        elif attention_cell == 'scaled_dot':
+            return DotProductAttentionCell(units=units,
+                                           scaled=True,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=False)
+        elif attention_cell == 'dot':
+            return DotProductAttentionCell(units=units,
+                                           scaled=False,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=False)
+        elif attention_cell == 'cosine':
+            return DotProductAttentionCell(units=units,
+                                           scaled=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           normalized=True)
+        elif attention_cell == 'mlp':
+            return MLPAttentionCell(units=units, normalized=False)
+        elif attention_cell == 'normed_mlp':
+            return MLPAttentionCell(units=units, normalized=True)
+        elif attention_cell == 'multi_head':
+            base_cell = DotProductAttentionCell(
+                scaled=scaled,
+                dropout=dropout,
+                unnormalized_score=unnormalized_score)
+            return MultiHeadAttentionCell(
+                base_cell=base_cell,
+                query_units=units,
+                use_bias=use_bias,
+                key_units=units,
+                value_units=units,
+                num_heads=num_heads,
+                unnormalized_score=unnormalized_score)
+        else:
+            raise NotImplementedError
+    else:
+        assert isinstance(attention_cell, AttentionCell),\
+            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
+            .format(attention_cell)
+        return attention_cell
diff --git a/src/gluonnlp/model/seq2seq_encoder_decoder.py b/src/gluonnlp/model/seq2seq_encoder_decoder.py
index fc2c7a1da9..2f988d8020 100644
--- a/src/gluonnlp/model/seq2seq_encoder_decoder.py
+++ b/src/gluonnlp/model/seq2seq_encoder_decoder.py
@@ -23,9 +23,6 @@
 from mxnet.gluon import rnn
 from mxnet.gluon.block import Block
 
-from .attention_cell import (AttentionCell, DotProductAttentionCell,
-                             MLPAttentionCell, MultiHeadAttentionCell)
-
 
 def _get_cell_type(cell_type):
     """Get the object type of the cell by parsing the input
@@ -53,51 +50,6 @@
     else:
         return cell_type
 
-
-def _get_attention_cell(attention_cell, units=None,
-                        scaled=True, num_heads=None,
-                        use_bias=False, dropout=0.0):
-    """
-
-    Parameters
-    ----------
-    attention_cell : AttentionCell or str
-    units : int or None
-
-    Returns
-    -------
-    attention_cell : AttentionCell
-    """
-    if isinstance(attention_cell, str):
-        if attention_cell == 'scaled_luong':
-            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=True)
-        elif attention_cell == 'scaled_dot':
-            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=False)
-        elif attention_cell == 'dot':
-            return DotProductAttentionCell(units=units, scaled=False, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=False)
-        elif attention_cell == 'cosine':
-            return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias,
-                                           dropout=dropout, normalized=True)
-        elif attention_cell == 'mlp':
-            return MLPAttentionCell(units=units, normalized=False)
-        elif attention_cell == 'normed_mlp':
-            return MLPAttentionCell(units=units, normalized=True)
-        elif attention_cell == 'multi_head':
-            base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout)
-            return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias,
-                                          key_units=units, value_units=units, num_heads=num_heads)
-        else:
-            raise NotImplementedError
-    else:
-        assert isinstance(attention_cell, AttentionCell),\
-            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
-            .format(attention_cell)
-        return attention_cell
-
-
 def _nested_sequence_last(data, valid_length):
     """
diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py
index 7afc1f432d..301427d38c 100644
--- a/src/gluonnlp/model/transformer.py
+++ b/src/gluonnlp/model/transformer.py
@@ -34,8 +34,8 @@
 from ..utils.parallel import Parallelizable
 from .block import GELU
 from .seq2seq_encoder_decoder import (Seq2SeqDecoder, Seq2SeqEncoder,
-                                      Seq2SeqOneStepDecoder,
-                                      _get_attention_cell)
+                                      Seq2SeqOneStepDecoder)
+from .attention_cell import _get_attention_cell
 from .translation import NMTModel
 from .utils import _load_pretrained_params, _load_vocab
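
Below is a small usage sketch, not part of the patch itself, that exercises the relocated _get_attention_cell helper and the new unnormalized_score flag. It assumes the patched gluonnlp and MXNet are importable; the tensor shapes and hyper-parameters are illustrative only, and _get_attention_cell remains a private helper.

# Usage sketch for this patch (illustrative only, not part of the diff).
import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell, _get_attention_cell

batch_size, query_len, mem_len, units = 2, 4, 6, 8
query = mx.nd.random.normal(shape=(batch_size, query_len, units))
key = mx.nd.random.normal(shape=(batch_size, mem_len, units))

# Default behaviour is unchanged: __call__ returns the softmax-normalized weights.
cell = DotProductAttentionCell(scaled=True)
context, att_weights = cell(query, key)
print(att_weights.sum(axis=-1))  # each row sums to ~1

# With unnormalized_score=True the same call returns the raw pre-softmax scores.
raw_cell = DotProductAttentionCell(scaled=True, unnormalized_score=True)
context, att_score = raw_cell(query, key)

# The relocated helper also forwards the flag, e.g. for a multi-head cell.
mh_cell = _get_attention_cell('multi_head', units=units, num_heads=2,
                              unnormalized_score=True)
mh_cell.initialize()
context, att_score = mh_cell(query, key)
print(att_score.shape)  # (batch_size, num_heads, query_len, mem_len)

Since unnormalized_score defaults to False throughout, existing callers of the attention cells and of _get_attention_cell keep their current return values; only callers that opt in receive the raw scores.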