Skip to content

Commit

Permalink
switch back to regular attention
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 16, 2022
1 parent 47571f2 commit def3a9e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-unet',
packages = find_packages(exclude=[]),
version = '0.0.22',
version = '0.1.0',
license='MIT',
description = 'X-Unet',
long_description_content_type = 'text/markdown',
Expand Down
10 changes: 3 additions & 7 deletions x_unet/x_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def cast_tuple(val, length = None):

return output

def l2norm(t):
    """Return *t* unit-normalized (L2 norm) along its last dimension."""
    unit = F.normalize(t, dim = -1)
    return unit

# helper classes

def Upsample(dim, dim_out):
Expand Down Expand Up @@ -189,7 +186,7 @@ def __init__(
scale = 8
):
super().__init__()
self.scale = scale
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = heads * dim_head
self.norm = LayerNorm(dim)
Expand All @@ -207,9 +204,8 @@ def forward(self, x):
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), (q, k, v))

q, k = map(l2norm, (q, k))

sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
Expand Down

0 comments on commit def3a9e

Please sign in to comment.