Skip to content

Commit

Permalink
remove einops exts for better pytorch 2.0 compile compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 20, 2023
1 parent 580274b commit 0069857
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
52 changes: 35 additions & 17 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat, reduce
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom

from kornia.filters import gaussian_blur2d
import kornia.augmentation as K
Expand Down Expand Up @@ -669,6 +667,23 @@ def p2_reweigh_loss(self, loss, times):
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)

# rearrange image to sequence

class RearrangeToSequence(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x):
x = rearrange(x, 'b c ... -> b ... c')
x, ps = pack([x], 'b * c')

x = self.fn(x)

x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b ... c -> b c ...')
return x

# diffusion prior

class LayerNorm(nn.Module):
Expand Down Expand Up @@ -867,7 +882,7 @@ def forward(self, x, mask = None, attn_bias = None):

# add null key / value for classifier free guidance in prior net

nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)

Expand Down Expand Up @@ -1629,14 +1644,10 @@ def __init__(
self.cross_attn = None

if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
)
self.cross_attn = CrossAttention(
dim = dim_out,
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
)

self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
Expand All @@ -1655,8 +1666,15 @@ def forward(self, x, time_emb = None, cond = None):

if exists(self.cross_attn):
assert exists(cond)

h = rearrange(h, 'b c ... -> b ... c')
h, ps = pack([h], 'b * c')

h = self.cross_attn(h, context = cond) + h

h, = unpack(h, ps, 'b * c')
h = rearrange(h, 'b ... c -> b c ...')

h = self.block2(h)
return h + self.res_conv(x)

Expand Down Expand Up @@ -1702,11 +1720,11 @@ def forward(self, x, context, mask = None):

q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

# add null key / value for classifier free guidance in prior net

nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
Expand Down Expand Up @@ -1759,7 +1777,7 @@ def forward(self, fmap):

fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
Expand Down Expand Up @@ -1993,7 +2011,7 @@ def __init__(

self_attn = cast_tuple(self_attn, num_stages)

create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))

# resnet block klass

Expand Down Expand Up @@ -3230,7 +3248,7 @@ def forward(
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device

check_shape(image, 'b c h w', c = self.channels)
assert image.shape[1] == self.channels
assert h >= target_image_size and w >= target_image_size

times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.12.4'
__version__ = '1.14.0'
5 changes: 2 additions & 3 deletions dalle2_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from torch.autograd import grad as torch_grad
import torchvision

from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# constants
Expand Down Expand Up @@ -408,7 +407,7 @@ def forward(self, x):
x = self.norm(x)

q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
'clip-anytorch>=2.5.2',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'einops>=0.6',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
Expand Down

0 comments on commit 0069857

Please sign in to comment.