use t5 relative positional bias in prior network causal transformer, since it makes more sense than rotary embeddings

lucidrains committed Apr 14, 2022
1 parent 9f55c24 commit 6e27f61
Showing 2 changed files with 58 additions and 4 deletions.
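
For context, a minimal standalone sketch (not part of this diff; it assumes the same bucketing scheme as the new RelPosBias class, and names such as causal_relative_buckets are illustrative) of how T5-style relative position bias works: signed query/key offsets are bucketed, exactly for nearby positions and logarithmically for distant ones, and each (bucket, head) pair looks up a learned scalar that is later added to the attention logits.

import math
import torch
import torch.nn as nn

def causal_relative_buckets(rel_pos, num_buckets = 32, max_distance = 128):
    # rel_pos[i, j] = j - i; for causal attention only the past (rel_pos <= 0) matters
    n = (-rel_pos).clamp(min = 0)                     # distance into the past
    max_exact = num_buckets // 2
    is_small = n < max_exact                          # nearby offsets get their own bucket
    # distant offsets share logarithmically spaced buckets, capped at the last bucket
    val_if_large = max_exact + (
        torch.log(n.float().clamp(min = 1) / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).long()
    val_if_large = val_if_large.clamp(max = num_buckets - 1)
    return torch.where(is_small, n, val_if_large)

heads, i, j = 8, 4, 4
bias_emb = nn.Embedding(32, heads)                    # one learned bias per (bucket, head)
rel_pos = torch.arange(j)[None, :] - torch.arange(i)[:, None]
bias = bias_emb(causal_relative_buckets(rel_pos))     # (i, j, heads)
bias = bias.permute(2, 0, 1)                          # (heads, i, j), ready to add to logits
print(bias.shape)                                     # torch.Size([8, 4, 4])

Because the bias depends only on the relative offset (and is shared by all positions at a given offset), it adds a learned per-head prior over attention distance without touching the queries and keys themselves.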
60 changes: 57 additions & 3 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -161,6 +161,44 @@ def __init__(
def forward(self, x):
return self.net(x.float())

# relative positional bias for causal transformer

class RelPosBias(nn.Module):
def __init__(
self,
heads = 8,
num_buckets = 32,
max_distance = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)

@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
n = -relative_position
n = torch.max(n, torch.zeros_like(n))

max_exact = num_buckets // 2
is_small = n < max_exact

val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
return torch.where(is_small, n, val_if_large)

def forward(self, i, j, *, device):
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')

# feedforward

class SwiGLU(nn.Module):
@@ -208,7 +246,7 @@ def __init__(
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(self, x, mask = None):
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device

x = self.norm(x)
@@ -225,6 +263,14 @@ def forward(self, x, mask = None):
q = q * self.scale

sim = einsum('b h i d, b j d -> b h i j', q, k)

# relative positional encoding (T5 style)

if exists(attn_bias):
sim = sim + attn_bias

# masking

max_neg_value = -torch.finfo(sim.dtype).max

if exists(mask):
@@ -237,10 +283,14 @@ def forward(self, x, mask = None):
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)

# attention

sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)

# aggregate values

out = einsum('b h i j, b j d -> b h i d', attn, v)

out = rearrange(out, 'b h n d -> b n (h d)')
@@ -260,7 +310,7 @@ def __init__(
ff_dropout = 0.
):
super().__init__()
# todo - bring in rotary embeddings or alibi
self.rel_pos_bias = RelPosBias(heads = heads)

self.layers = nn.ModuleList([])
for _ in range(depth):
@@ -276,8 +326,12 @@ def forward(
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
n, device = x.shape[1], x.device

attn_bias = self.rel_pos_bias(n, n + 1, device = device)

for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x

return self.norm(x)
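
A hedged sketch of how the new pieces fit together (toy code, not the library's Attention module): the transformer now builds the bias once per forward pass from RelPosBias and hands it to every attention layer, where it is added to the similarity logits before the causal mask and softmax. The import path matches the file changed above; everything else is illustrative.

import torch
from dalle2_pytorch.dalle2_pytorch import RelPosBias   # class added in this commit

def toy_causal_attention(q, k, v, attn_bias):
    # q, k, v: (batch, heads, seq, dim_head); attn_bias: (heads, i, j)
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * q.shape[-1] ** -0.5
    sim = sim + attn_bias                               # T5-style relative position bias
    i, j = sim.shape[-2:]
    causal_mask = torch.ones(i, j, dtype = torch.bool, device = sim.device).triu(j - i + 1)
    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

heads, n, dim_head = 8, 16, 64
rel_pos_bias = RelPosBias(heads = heads)
# the CausalTransformer above requests a (n, n + 1) bias, presumably to cover one extra
# key position in its attention (not shown in this diff); a square bias is enough here
attn_bias = rel_pos_bias(n, n, device = 'cpu')          # (heads, n, n), shared by every layer
q = k = v = torch.randn(2, heads, n, dim_head)
out = toy_causal_attention(q, k, v, attn_bias)
print(out.shape)                                        # torch.Size([2, 8, 16, 64])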
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.14',
version = '0.0.15',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
