
[Refactor] Add a switch for attention to return an unnormalized weight matrix. Move the _get_attention_cell function #1007

Open · wants to merge 5 commits into base: v0.x
3 changes: 2 additions & 1 deletion scripts/machine_translation/gnmt.py
@@ -22,7 +22,8 @@
 from mxnet.gluon import nn, rnn
 from mxnet.gluon.block import HybridBlock
 from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \
-    Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last
+    Seq2SeqOneStepDecoder, _get_cell_type, _nested_sequence_last
+from gluonnlp.model.attention_cell import _get_attention_cell


 class GNMTEncoder(Seq2SeqEncoder):
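For downstream code such as the gnmt.py script above, this part of the change is purely an import move. A minimal before/after sketch of the import path (illustrative note, not part of the diff):

```python
# Before this PR, the helper lived next to the encoder/decoder base classes:
# from gluonnlp.model.seq2seq_encoder_decoder import _get_attention_cell

# After this PR, it is imported from the attention cell module instead:
from gluonnlp.model.attention_cell import _get_attention_cell
```
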
124 changes: 111 additions & 13 deletions src/gluonnlp/model/attention_cell.py
@@ -89,8 +89,9 @@ class AttentionCell(HybridBlock):
         out = cell(query, key, value, mask)

     """
-    def __init__(self, prefix=None, params=None):
+    def __init__(self, unnormalized_score=False, prefix=None, params=None):
         self._dtype = np.float32
+        self._unnormalized_score = unnormalized_score
         super(AttentionCell, self).__init__(prefix=prefix, params=params)

     def cast(self, dtype):
@@ -117,6 +118,10 @@ def _compute_weight(self, F, query, key, mask=None):
         att_weights : Symbol or NDArray
             For single-head attention, Shape (batch_size, query_length, memory_length)
             For multi-head attention, Shape (batch_size, num_heads, query_length, memory_length)
+        att_score : Symbol or NDArray
+            Unnormalized score matrix, i.e. the attention weights before normalization.
+            For single-head attention, Shape (batch_size, query_length, memory_length)
+            For multi-head attention, Shape (batch_size, num_heads, query_length, memory_length)
         """
         raise NotImplementedError

@@ -172,9 +177,9 @@ def __call__(self, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
     def hybrid_forward(self, F, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
         if value is None:
             value = key
-        att_weights = self._compute_weight(F, query, key, mask)
+        att_weights, att_score = self._compute_weight(F, query, key, mask)
         context_vec = self._read_by_weight(F, att_weights, value)
-        return context_vec, att_weights
+        return context_vec, att_score if self._unnormalized_score else att_weights


 class MultiHeadAttentionCell(AttentionCell):
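Usage sketch for the new switch (illustrative aside, not part of the diff): with `unnormalized_score=True` the second return value of a cell is the raw score matrix rather than the softmax-normalized weights. The snippet assumes this branch of gluonnlp and MXNet 1.x are installed.

```python
import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell

# unnormalized_score=True makes the cell return the pre-softmax scores as its
# second output; the default False keeps the old behaviour and returns the
# normalized attention weights.
cell = DotProductAttentionCell(scaled=True, unnormalized_score=True)
cell.initialize()

query = mx.nd.random.uniform(shape=(2, 5, 8))  # (batch_size, query_length, query_dim)
key = mx.nd.random.uniform(shape=(2, 7, 8))    # (batch_size, memory_length, key_dim)

context_vec, att_score = cell(query, key)
print(context_vec.shape)  # (2, 5, 8)
print(att_score.shape)    # (2, 5, 7) -- raw scores, rows do not sum to 1
```
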
@@ -206,17 +211,22 @@ class MultiHeadAttentionCell(AttentionCell):
         Initializer of the weights.
     bias_initializer : str or `Initializer`, default 'zeros'
         Initializer of the bias.
+    unnormalized_score : bool, default False
+        Whether to return the unnormalized score matrix instead of the normalized weights.
     prefix : str or None, default None
         See document of `Block`.
     params : str or None, default None
         See document of `Block`.
     """
     def __init__(self, base_cell, query_units, key_units, value_units, num_heads, use_bias=True,
-                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None):
-        super(MultiHeadAttentionCell, self).__init__(prefix=prefix, params=params)
+                 weight_initializer=None, bias_initializer='zeros',
+                 unnormalized_score=False, prefix=None, params=None):
+        super(MultiHeadAttentionCell,
+              self).__init__(unnormalized_score=unnormalized_score, prefix=prefix, params=params)
         self._base_cell = base_cell
         self._num_heads = num_heads
         self._use_bias = use_bias
+        self._unnormalized_score = unnormalized_score
         units = {'query': query_units, 'key': key_units, 'value': value_units}
         for name, unit in units.items():
             if unit % self._num_heads != 0:
@@ -275,8 +285,11 @@ def _compute_weight(self, F, query, key, mask=None):
             mask = F.broadcast_axis(F.expand_dims(mask, axis=1),
                                     axis=1, size=self._num_heads)\
                    .reshape(shape=(-1, 0, 0), reverse=True)
-        att_weights = self._base_cell._compute_weight(F, query, key, mask)
-        return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
+        att_weights, att_score = self._base_cell._compute_weight(F, query, key, mask)
+        return att_weights.reshape(shape=(-1, self._num_heads, 0, 0),
+                                   reverse=True), att_score.reshape(
+                                       shape=(-1, self._num_heads, 0, 0),
+                                       reverse=True)

     def _read_by_weight(self, F, att_weights, value):
         att_weights = att_weights.reshape(shape=(-1, 0, 0), reverse=True)
@@ -319,14 +332,17 @@ class MLPAttentionCell(AttentionCell):
         Initializer of the weights.
     bias_initializer : str or `Initializer`, default 'zeros'
         Initializer of the bias.
+    unnormalized_score : bool, default False
+        Whether to return the unnormalized score matrix instead of the normalized weights.
     prefix : str or None, default None
         See document of `Block`.
     params : ParameterDict or None, default None
         See document of `Block`.
     """

     def __init__(self, units, act=nn.Activation('tanh'), normalized=False, dropout=0.0,
-                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None):
+                 weight_initializer=None, bias_initializer='zeros',
+                 unnormalized_score=False, prefix=None, params=None):
         # Define a temporary class to implement the normalized version
         # TODO(sxjscience) Find a better solution
         class _NormalizedScoreProj(HybridBlock):
@@ -346,11 +362,15 @@ def hybrid_forward(self, F, x, g, v):  # pylint: disable=arguments-differ
                                        flatten=False, name='fwd')
                 return out

-        super(MLPAttentionCell, self).__init__(prefix=prefix, params=params)
+        super(MLPAttentionCell,
+              self).__init__(unnormalized_score=unnormalized_score,
+                             prefix=prefix,
+                             params=params)
         self._units = units
         self._act = act
         self._normalized = normalized
         self._dropout = dropout
+        self._unnormalized_score = unnormalized_score
         with self.name_scope():
             self._dropout_layer = nn.Dropout(dropout)
             self._query_mid_layer = nn.Dense(units=self._units, flatten=False, use_bias=True,
@@ -388,7 +408,7 @@ def _compute_weight(self, F, query, key, mask=None):
         if mask is not None:
             att_weights = att_weights * mask
         att_weights = self._dropout_layer(att_weights)
-        return att_weights
+        return att_weights, att_score


 class DotProductAttentionCell(AttentionCell):
@@ -443,21 +463,27 @@ class DotProductAttentionCell(AttentionCell):
         Initializer of the weights
     bias_initializer : str or `Initializer`, default 'zeros'
         Initializer of the bias
+    unnormalized_score : bool, default False
+        Whether to return the unnormalized score matrix instead of the normalized weights.
     prefix : str or None, default None
         See document of `Block`.
     params : str or None, default None
         See document of `Block`.
     """
     def __init__(self, units=None, luong_style=False, scaled=True, normalized=False, use_bias=True,
                  dropout=0.0, weight_initializer=None, bias_initializer='zeros',
-                 prefix=None, params=None):
-        super(DotProductAttentionCell, self).__init__(prefix=prefix, params=params)
+                 unnormalized_score=False, prefix=None, params=None):
+        super(DotProductAttentionCell,
+              self).__init__(unnormalized_score=unnormalized_score,
+                             prefix=prefix,
+                             params=params)
         self._units = units
         self._scaled = scaled
         self._normalized = normalized
         self._use_bias = use_bias
         self._luong_style = luong_style
         self._dropout = dropout
+        self._unnormalized_score = unnormalized_score
         if self._luong_style:
             assert units is not None, 'Luong style attention is not available without explicitly ' \
                                       'setting the units'
@@ -503,4 +529,76 @@ def _compute_weight(self, F, query, key, mask=None):
         if mask is not None:
             att_weights = att_weights * mask
         att_weights = self._dropout_layer(att_weights)
-        return att_weights
+        return att_weights, att_score
+
+
+def _get_attention_cell(attention_cell,
+                        units=None,
+                        scaled=True,
+                        num_heads=None,
+                        use_bias=False,
+                        dropout=0.0,
+                        unnormalized_score=False):
+    """
+
+    Parameters
+    ----------
+    attention_cell : AttentionCell or str
+    units : int or None
+
+    Returns
+    -------
+    attention_cell : AttentionCell
+    """
+    if isinstance(attention_cell, str):
+        if attention_cell == 'scaled_luong':
+            return DotProductAttentionCell(units=units,
+                                           scaled=True,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=True)
+        elif attention_cell == 'scaled_dot':
+            return DotProductAttentionCell(units=units,
+                                           scaled=True,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=False)
+        elif attention_cell == 'dot':
+            return DotProductAttentionCell(units=units,
+                                           scaled=False,
+                                           normalized=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           luong_style=False)
+        elif attention_cell == 'cosine':
+            return DotProductAttentionCell(units=units,
+                                           scaled=False,
+                                           use_bias=use_bias,
+                                           dropout=dropout,
+                                           normalized=True)
+        elif attention_cell == 'mlp':
+            return MLPAttentionCell(units=units, normalized=False)
+        elif attention_cell == 'normed_mlp':
+            return MLPAttentionCell(units=units, normalized=True)
+        elif attention_cell == 'multi_head':
+            base_cell = DotProductAttentionCell(
+                scaled=scaled,
+                dropout=dropout,
+                unnormalized_score=unnormalized_score)
+            return MultiHeadAttentionCell(
+                base_cell=base_cell,
+                query_units=units,
+                use_bias=use_bias,
+                key_units=units,
+                value_units=units,
+                num_heads=num_heads,
+                unnormalized_score=unnormalized_score)
+        else:
+            raise NotImplementedError
+    else:
+        assert isinstance(attention_cell, AttentionCell),\
+            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
+            .format(attention_cell)
+        return attention_cell
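Constructing a cell through the relocated helper (illustrative aside; `_get_attention_cell` is a private helper, and the call assumes this branch is installed):

```python
import mxnet as mx
from gluonnlp.model.attention_cell import _get_attention_cell

# 'multi_head' builds a shared DotProductAttentionCell and wraps it in a
# MultiHeadAttentionCell; unnormalized_score is forwarded to both.
cell = _get_attention_cell('multi_head', units=8, num_heads=2,
                           dropout=0.0, unnormalized_score=True)
cell.initialize()

query = mx.nd.random.uniform(shape=(2, 5, 8))
key = mx.nd.random.uniform(shape=(2, 7, 8))
context_vec, att_score = cell(query, key)
print(att_score.shape)  # (2, 2, 5, 7): (batch_size, num_heads, query_length, memory_length)
```

Note that, as written, only the `'multi_head'` branch forwards `unnormalized_score`; the other string shortcuts construct their cells with the default (False).
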
48 changes: 0 additions & 48 deletions src/gluonnlp/model/seq2seq_encoder_decoder.py
@@ -23,9 +23,6 @@
 from mxnet.gluon import rnn
 from mxnet.gluon.block import Block

-from .attention_cell import (AttentionCell, DotProductAttentionCell,
-                             MLPAttentionCell, MultiHeadAttentionCell)
-

 def _get_cell_type(cell_type):
     """Get the object type of the cell by parsing the input
@@ -53,51 +50,6 @@ def _get_cell_type(cell_type):
     else:
         return cell_type

-
-def _get_attention_cell(attention_cell, units=None,
-                        scaled=True, num_heads=None,
-                        use_bias=False, dropout=0.0):
-    """
-
-    Parameters
-    ----------
-    attention_cell : AttentionCell or str
-    units : int or None
-
-    Returns
-    -------
-    attention_cell : AttentionCell
-    """
-    if isinstance(attention_cell, str):
-        if attention_cell == 'scaled_luong':
-            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=True)
-        elif attention_cell == 'scaled_dot':
-            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=False)
-        elif attention_cell == 'dot':
-            return DotProductAttentionCell(units=units, scaled=False, normalized=False,
-                                           use_bias=use_bias, dropout=dropout, luong_style=False)
-        elif attention_cell == 'cosine':
-            return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias,
-                                           dropout=dropout, normalized=True)
-        elif attention_cell == 'mlp':
-            return MLPAttentionCell(units=units, normalized=False)
-        elif attention_cell == 'normed_mlp':
-            return MLPAttentionCell(units=units, normalized=True)
-        elif attention_cell == 'multi_head':
-            base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout)
-            return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias,
-                                          key_units=units, value_units=units, num_heads=num_heads)
-        else:
-            raise NotImplementedError
-    else:
-        assert isinstance(attention_cell, AttentionCell),\
-            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
-            .format(attention_cell)
-        return attention_cell
-

 def _nested_sequence_last(data, valid_length):
     """

4 changes: 2 additions & 2 deletions src/gluonnlp/model/transformer.py
@@ -34,8 +34,8 @@
 from ..utils.parallel import Parallelizable
 from .block import GELU
 from .seq2seq_encoder_decoder import (Seq2SeqDecoder, Seq2SeqEncoder,
-                                      Seq2SeqOneStepDecoder,
-                                      _get_attention_cell)
+                                      Seq2SeqOneStepDecoder)
+from .attention_cell import _get_attention_cell
 from .translation import NMTModel
 from .utils import _load_pretrained_params, _load_vocab
