complete funnel transformer like naive down and upsampling
lucidrains committed Nov 9, 2021
1 parent c8d1ede commit 87fb4a1
Showing 2 changed files with 43 additions and 4 deletions.
44 changes: 41 additions & 3 deletions hourglass_transformer_pytorch/hourglass_transformer_pytorch.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, reduce, repeat
 
 # helpers

@@ -11,6 +11,15 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def pad_to_multiple(tensor, multiple, dim = -1):
+    seq_len = tensor.shape[dim]
+    m = seq_len / multiple
+    if m.is_integer():
+        return tensor
+    remainder = math.ceil(m) * multiple - seq_len
+    pad_offset = (0,) * (-1 - dim) * 2
+    return F.pad(tensor, (*pad_offset, 0, remainder), value = 0)
+
 def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else ((val,) * depth)

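Note that the new padding helper relies on math.ceil; this hunk does not add import math, so it is presumably imported elsewhere in the file. A self-contained sketch of the helper with its imports spelled out, plus a small usage check (the example tensor shape is illustrative only):

import math
import torch
import torch.nn.functional as F

def pad_to_multiple(tensor, multiple, dim = -1):
    # right-pad `dim` with zeros until its length is a multiple of `multiple`
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor
    remainder = math.ceil(m) * multiple - seq_len
    # F.pad takes (left, right) pairs starting from the last dimension,
    # so skip the dimensions after `dim`, then pad `dim` on the right
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = 0)

x = torch.randn(1, 10, 512)                     # (batch, seq, dim)
print(pad_to_multiple(x, 4, dim = -2).shape)    # torch.Size([1, 12, 512])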
@@ -133,18 +142,20 @@ def __init__(
         dim,
         *,
         depth,
-        shorten_factor,
+        shorten_factor = 2,
         heads = 8,
         dim_head = 64,
         causal = False,
         norm_out = False
     ):
         super().__init__()
         assert len(depth) == 3, 'depth should be a tuple of length 3'
-        pre_layers_depth, valley_config, post_layers_depth = depth
+        pre_layers_depth, valley_depth, post_layers_depth = depth
 
         if isinstance(shorten_factor, tuple):
             shorten_factor, *rest_shorten_factor = shorten_factor
+        elif isinstance(valley_depth, int):
+            shorten_factor, rest_shorten_factor = shorten_factor, None
         else:
             shorten_factor, rest_shorten_factor = shorten_factor, shorten_factor

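With this change, depth is always a 3-tuple of (pre layers, valley, post layers); when the valley entry is itself a tuple another hourglass is built recursively, and shorten_factor can either be shared across levels or given per level as a tuple. A standalone restatement of just that branching, with made-up example configs showing what each case yields:

def unpack(shorten_factor, depth):
    # mirrors the branching in __init__ above, outside the class for illustration
    pre_layers_depth, valley_depth, post_layers_depth = depth
    if isinstance(shorten_factor, tuple):
        # one shorten factor per nesting level, e.g. (2, 4)
        shorten_factor, *rest_shorten_factor = shorten_factor
    elif isinstance(valley_depth, int):
        # valley is a plain Transformer, nothing to pass further down
        shorten_factor, rest_shorten_factor = shorten_factor, None
    else:
        # valley is another hourglass, reuse the same factor below
        shorten_factor, rest_shorten_factor = shorten_factor, shorten_factor
    return shorten_factor, rest_shorten_factor

print(unpack(2, (4, 2, 4)))               # (2, None)
print(unpack(2, (4, (3, 2, 3), 4)))       # (2, 2)
print(unpack((2, 4), (4, (3, 2, 3), 4)))  # (2, [4])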
@@ -155,12 +166,39 @@ def __init__(
             causal = causal
         )
 
+        self.causal = causal
+        self.shorten_factor = shorten_factor
+
         self.valley_transformer = get_hourglass_transformer(
             shorten_factor = rest_shorten_factor,
             depth = valley_depth,
             **transformer_kwargs
         )
 
         self.pre_transformer = Transformer(depth = pre_layers_depth, **transformer_kwargs)
         self.post_transformer = Transformer(depth = post_layers_depth, **transformer_kwargs)
         self.norm_out = nn.LayerNorm(dim) if norm_out else nn.Identity()
+
+    def forward(self, x):
+        shorten_factor, n = self.shorten_factor, x.shape[-2]
+        x = self.pre_transformer(x)
+
+        x_residual = x
+        x = pad_to_multiple(x, shorten_factor, dim = -2)
+
+        if self.causal:
+            shift = shorten_factor - 1
+            x = F.pad(x, (0, 0, shift, -shift), value = 0.)
+
+        x = reduce(x, 'b (n r) d -> b n d', 'mean', r = shorten_factor)
+
+        x = self.valley_transformer(x)
+
+        x = repeat(x, 'b n d -> b (n r) d', r = shorten_factor)
+
+        x = x[:, :n]
+        x = x + x_residual
+
+        x = self.post_transformer(x)
+        return self.norm_out(x)

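The new forward implements the naive down- and upsampling from the commit title: pad the sequence to a multiple of the shorten factor, shift right by shorten_factor - 1 in the causal case so a pooled token cannot see future positions, mean-pool groups of shorten_factor tokens, run the shortened valley transformer, upsample by plain repetition, trim back to the original length, and add the full-resolution residual. A standalone sketch of just that resampling path, with the valley transformer stubbed as an identity function (an assumption for illustration, not the repository's API):

import torch
import torch.nn.functional as F
from einops import reduce, repeat

def naive_hourglass_resample(x, shorten_factor = 2, causal = True, valley = lambda t: t):
    n = x.shape[-2]
    x_residual = x

    # pad the sequence length up to a multiple of shorten_factor (what pad_to_multiple does)
    x = F.pad(x, (0, 0, 0, -n % shorten_factor), value = 0.)

    if causal:
        # shift right so each pooled group only covers itself and earlier tokens
        shift = shorten_factor - 1
        x = F.pad(x, (0, 0, shift, -shift), value = 0.)

    x = reduce(x, 'b (n r) d -> b n d', 'mean', r = shorten_factor)   # downsample
    x = valley(x)                                                     # shortened transformer
    x = repeat(x, 'b n d -> b (n r) d', r = shorten_factor)           # naive upsample

    return x[:, :n] + x_residual                                      # trim padding, add residual

print(naive_hourglass_resample(torch.randn(1, 10, 512)).shape)  # torch.Size([1, 10, 512])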
3 changes: 2 additions & 1 deletion train.py
@@ -40,7 +40,8 @@ def decode_tokens(tokens):
     num_tokens = 256,
     dim = 512,
     max_seq_len = SEQ_LEN,
-    depth = 8,
+    depth = (4, 2, 4),
+    shorten_factor = 2,
     heads = 8
 )

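In train.py the flat depth = 8 becomes the hourglass configuration (4, 2, 4): four full-resolution layers, a two-layer valley running at half the sequence length (shorten_factor = 2), then four more full-resolution layers. Given the tuple handling in __init__ above, a deeper configuration would presumably nest the middle entry; the nested example below is hypothetical and not taken from the repository:

depth = (4, 2, 4)            # as in train.py: pre, valley, post layer counts
shorten_factor = 2           # valley runs at 1/2 resolution

# hypothetical deeper nesting (illustrative only):
depth = (4, (3, 2, 3), 4)    # valley is itself an hourglass
shorten_factor = (2, 4)      # 1/2 resolution, then 1/8 at the innermost valley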
