From 7ef65a9744e626fe9d88e973dc5a959cbfe9d391 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 26 Jun 2023 13:26:53 -0700
Subject: [PATCH] add ability to use t5 relative positional bias, addressing
 https://github.com/lucidrains/soundstorm-pytorch/issues/15

---
 setup.py                         |  2 +-
 soundstorm_pytorch/attend.py     |  3 +-
 soundstorm_pytorch/soundstorm.py | 83 +++++++++++++++++++++++++++++---
 3 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/setup.py b/setup.py
index b143937..2f3e79d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.17',
+  version = '0.0.18',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
diff --git a/soundstorm_pytorch/attend.py b/soundstorm_pytorch/attend.py
index ddfebb8..27cb93e 100644
--- a/soundstorm_pytorch/attend.py
+++ b/soundstorm_pytorch/attend.py
@@ -132,7 +132,8 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

         if self.flash:
-            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            assert not exists(attn_bias)
+            return self.flash_attn(q, k, v, mask = mask)

         # similarity

diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py
index 9b174bd..d1ad559 100644
--- a/soundstorm_pytorch/soundstorm.py
+++ b/soundstorm_pytorch/soundstorm.py
@@ -126,6 +126,64 @@ def rotate_half(x):
 def apply_rotary_pos_emb(pos, t):
     return (t * pos.cos()) + (rotate_half(t) * pos.sin())

+# t5 relative positional bias
+
+class T5RelativePositionBias(nn.Module):
+    def __init__(
+        self,
+        scale = 1.,
+        num_buckets = 32,
+        max_distance = 128,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = scale
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        ret = 0
+        n = -relative_position
+
+        num_buckets //= 2
+        ret += (n < 0).long() * num_buckets
+        n = torch.abs(n)
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (
+            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+        ).long()
+
+        val_if_large = torch.min(
+            val_if_large,
+            torch.full_like(val_if_large, num_buckets - 1)
+        )
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, n):
+        pos = torch.arange(n, device = self.device).long()
+        rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
+
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+
+        bias = rearrange(values, 'i j h -> h i j')
+        return bias * self.scale
+
 # conformer

 class Swish(nn.Module):
@@ -213,7 +271,8 @@ def forward(
         x,
         context = None,
         mask = None,
-        rotary_emb = None
+        rotary_emb = None,
+        attn_bias = None
     ):
         n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
         context = default(context, x)
@@ -225,7 +284,7 @@
             q = apply_rotary_pos_emb(rotary_emb, q)
             k = apply_rotary_pos_emb(rotary_emb, k)

-        out = self.attend(q, k, v, mask = mask)
+        out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -313,10 +372,11 @@ def forward(
         self,
         x,
         mask = None,
-        rotary_emb = None
+        rotary_emb = None,
+        attn_bias = None
     ):
         x = self.ff1(x) + x
-        x = self.attn(x, mask = mask, rotary_emb = rotary_emb) + x
+        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
         x = self.conv(x) + x
         x = self.ff2(x) + x
         x = self.post_norm(x)
@@ -339,13 +399,18 @@ def __init__(
         ff_dropout = 0.,
         conv_dropout = 0.,
         conv_causal = False,
-        attn_flash = True
+        attn_flash = True,
+        t5_rel_pos_bias = False
     ):
         super().__init__()
+
+        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'
+
         self.dim = dim
         self.layers = nn.ModuleList([])

-        self.rotary_emb = RotaryEmbedding(dim_head)
+        self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None
+        self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

         for _ in range(depth):
             self.layers.append(ConformerBlock(
@@ -361,11 +426,13 @@ def __init__(
             ))

     def forward(self, x):
+        seq_len = x.shape[-2]

-        rotary_emb = self.rotary_emb(x.shape[-2])
+        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
+        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None

         for block in self.layers:
-            x = block(x, rotary_emb = rotary_emb)
+            x = block(x, rotary_emb = rotary_emb, attn_bias = attn_bias)

         return x
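
Usage sketch for the new flag, assuming the class modified above is `Conformer` and is importable
as `soundstorm_pytorch.soundstorm.Conformer`; the dimensions and depth below are illustrative only,
not values taken from this patch.

    import torch
    from soundstorm_pytorch.soundstorm import Conformer  # assumed import path

    conformer = Conformer(
        dim = 512,
        depth = 2,
        dim_head = 64,
        heads = 8,
        attn_flash = False,       # required: the assert above rejects flash attention with a learned bias
        t5_rel_pos_bias = True    # disables rotary embeddings in favor of the T5 relative position bias
    )

    x = torch.randn(1, 1024, 512)  # (batch, sequence, dim)
    out = conformer(x)             # same shape as the input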