Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

blockformer #1504

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions examples/aishell/s0/conf/train_blockformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6 or conv2d8
normalize_before: true
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
use_se_module: true
se_module_channel: 12 # must equal the number of encoder blocks (num_blocks)

# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
input_layer: 'rel_embed'
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
use_se_module: true
se_module_channel: 6 # must equal the number of decoder blocks (num_blocks)

# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset_conf:
filter_conf:
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample_conf:
resample_rate: 16000
speed_perturb: true
fbank_conf:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 0.1
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 50
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
sort: false
sort_conf:
sort_size: 500 # sort_size should be less than shuffle_size
batch_conf:
batch_type: 'static' # static or dynamic
batch_size: 16

grad_clip: 5
accum_grad: 4
max_epoch: 360
log_interval: 100

optim: adam
optim_conf:
lr: 0.002
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 50000
81 changes: 69 additions & 12 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
from typeguard import check_argument_types

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.se_layer import SELayer
from wenet.utils.mask import (subsequent_mask, make_pad_mask)


Expand Down Expand Up @@ -61,6 +64,8 @@ def __init__(
use_output_layer: bool = True,
normalize_before: bool = True,
concat_after: bool = False,
use_se_module: bool = False,
se_module_channel: int = 0
):
assert check_argument_types()
super().__init__()
Expand All @@ -71,19 +76,28 @@ def __init__(
torch.nn.Embedding(vocab_size, attention_dim),
PositionalEncoding(attention_dim, positional_dropout_rate),
)
elif input_layer == "rel_embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
RelPositionalEncoding(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")

self.normalize_before = normalize_before
self.input_layer = input_layer
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.use_output_layer = use_output_layer
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
self.num_blocks = num_blocks
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
RelPositionMultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate) \
if input_layer == "rel_embed" else \
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
Expand All @@ -93,6 +107,8 @@ def __init__(
concat_after,
) for _ in range(self.num_blocks)
])
self.use_se_module = use_se_module
self.se_class = SELayer(se_module_channel)

def forward(
self,
Expand Down Expand Up @@ -130,10 +146,40 @@ def forward(
device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
x, _ = self.embed(tgt)
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
x, pos_emb = self.embed(tgt)
if self.use_se_module:
x_list = []
for layer in self.decoders:
if self.input_layer == "rel_embed":
x, tgt_mask, memory, memory_mask = layer(x,
tgt_mask,
memory,
memory_mask,
pos_emb)
else:
x, tgt_mask, memory, memory_mask = layer(x,
tgt_mask,
memory,
memory_mask,
torch.empty(0))
x_list.append(x)
x_list = torch.stack(x_list).transpose(0, 1)
x_se_output = self.se_class(x_list)
x = torch.sum(x_se_output, dim=1)
else:
for layer in self.decoders:
if self.input_layer == "rel_embed":
x, tgt_mask, memory, memory_mask = layer(x,
tgt_mask,
memory,
memory_mask,
pos_emb)
else:
x, tgt_mask, memory, memory_mask = layer(x,
tgt_mask,
memory,
memory_mask,
torch.empty(0))
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
Expand Down Expand Up @@ -163,18 +209,27 @@ def forward_one_step(
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
x, _ = self.embed(tgt)
x, pos_emb = self.embed(tgt)
new_cache = []
for i, decoder in enumerate(self.decoders):
if cache is None:
c = None
else:
c = cache[i]
x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
cache=c)
if self.input_layer == "rel_embed":
x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
pos_emb,
cache=c)
else:
x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
torch.empty(0),
cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
Expand Down Expand Up @@ -222,6 +277,8 @@ def __init__(
use_output_layer: bool = True,
normalize_before: bool = True,
concat_after: bool = False,
use_se_module: bool = False,
se_module_channel: int = 0
):

assert check_argument_types()
Expand Down
5 changes: 3 additions & 2 deletions wenet/transformer/decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def forward(
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
pos_emb: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
Expand Down Expand Up @@ -117,11 +118,11 @@ def forward(

if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0])
if not self.normalize_before:
x = self.norm1(x)

Expand Down
24 changes: 21 additions & 3 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.transformer.se_layer import SELayer
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
Expand All @@ -57,6 +58,8 @@ def __init__(
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
use_se_module: bool = False,
se_module_channel: int = 0
):
"""
Args:
Expand Down Expand Up @@ -127,6 +130,8 @@ def __init__(
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
self.use_se_module = use_se_module
self.se_class = SELayer(se_module_channel)

def output_size(self) -> int:
return self._output_size
Expand Down Expand Up @@ -169,8 +174,18 @@ def forward(
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.use_se_module:
xs_list = []
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
xs_list.append(xs)
xs_list = torch.stack(xs_list).transpose(0, 1)
xs_se_output = self.se_class(xs_list)
xs = torch.sum(xs_se_output, dim=1)
else:
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)

if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
Expand Down Expand Up @@ -397,6 +412,8 @@ def __init__(
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
use_se_module: bool = False,
se_module_channel: int = 0
):
"""Construct ConformerEncoder

Expand All @@ -420,7 +437,8 @@ def __init__(
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
concat_after, static_chunk_size, use_dynamic_chunk,
global_cmvn, use_dynamic_left_chunk)
global_cmvn, use_dynamic_left_chunk, use_se_module,
se_module_channel)
activation = get_activation(activation_type)

# self-attention module definition
Expand Down
37 changes: 37 additions & 0 deletions wenet/transformer/se_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2022 Mininglamp Com (Liuwei Wei, Xiaoming Ren)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)


"""Squeeze-and-Excitation layer definition."""

from typing import Optional

import torch


class SELayer(torch.nn.Module):
    """Squeeze-and-Excitation gate over the channel axis of a 4-D tensor.

    Each channel is squeezed to a single scalar (mean over the last two
    axes), passed through a two-layer bottleneck MLP ending in a sigmoid,
    and the resulting per-channel gate rescales the input.

    Args:
        channel: Size of dim 1 of the input (number of channels to gate).
        reduction: Bottleneck divisor; the hidden width is
            ``channel // reduction``.
    """

    def __init__(self, channel: int, reduction: int = 1):
        super().__init__()
        # Kept for the unmasked path; it holds no parameters.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(channel, channel // reduction, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(channel // reduction, channel, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Rescale every channel of ``x`` by its learned gate.

        Args:
            x: 4-D tensor ``(batch, channel, T, D)``.
                NOTE(review): callers appear to stack per-layer outputs so
                that ``channel`` is the layer index - confirm at call sites.
            mask: Optional validity mask with ``batch * T`` elements
                (e.g. shape ``(batch, T)`` or ``(batch, 1, T)``), non-zero /
                True for real frames. When given, padded frames are excluded
                from the squeeze so batch padding cannot shift the gates.
                When ``None`` the plain mean over ``(T, D)`` is used.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, d = x.size()
        if mask is None:
            y = self.avg_pool(x).view(b, c)
        else:
            # (batch, 1, T, 1) so it broadcasts over channels and features.
            valid = mask.to(x.dtype).reshape(b, 1, -1, 1)
            summed = (x * valid).sum(dim=(2, 3))  # (batch, channel)
            # Each valid frame contributes D elements; clamp so that an
            # all-padding row divides by 1 instead of 0.
            count = (valid.sum(dim=(2, 3)) * d).clamp(min=1.0)  # (batch, 1)
            y = summed / count
        y = self.fc(y).view(b, c, 1, 1)
        # Broadcasting applies the (batch, channel, 1, 1) gate to all of x.
        return x * y