diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..3bfabfc
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,36 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/data/README.md b/data/README.md
new file mode 100644
index 0000000..ae48167
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1,3 @@
+# Data source
+
+The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/
\ No newline at end of file
diff --git a/data/enwik8.gz b/data/enwik8.gz
new file mode 100644
index 0000000..7a8ec66
Binary files /dev/null and b/data/enwik8.gz differ
diff --git a/hourglass_transformer_pytorch/__init__.py b/hourglass_transformer_pytorch/__init__.py
index e69de29..ff15ae0 100644
--- a/hourglass_transformer_pytorch/__init__.py
+++ b/hourglass_transformer_pytorch/__init__.py
@@ -0,0 +1 @@
+from hourglass_transformer_pytorch.hourglass_transformer_pytorch import HourglassTransformer
diff --git a/hourglass_transformer_pytorch/autoregressive_wrapper.py b/hourglass_transformer_pytorch/autoregressive_wrapper.py
new file mode 100644
index 0000000..07ed788
--- /dev/null
+++ b/hourglass_transformer_pytorch/autoregressive_wrapper.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
+# top k filtering
+
+def top_k(logits, thres = 0.9):
+    k = int((1 - thres) * logits.shape[-1])
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
+
+class AutoregressiveWrapper(nn.Module):
+    def __init__(self, net, pad_value = 0):
+        super().__init__()
+        self.pad_value = pad_value
+        self.net = net
+        self.max_seq_len = net.max_seq_len
+
+    @torch.no_grad()
+    @eval_decorator
+    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
+        b, t, device = *start_tokens.shape, start_tokens.device
+
+        out = start_tokens
+
+        for _ in range(seq_len):
+            x = out[:, -self.max_seq_len:]
+
+            logits = self.net(x, **kwargs)[:, -1, :]
+
+            filtered_logits = top_k(logits, thres = filter_thres)
+            probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+            sample = torch.multinomial(probs, 1)
+
+            out = torch.cat((out, sample), dim=-1)
+
+            if exists(eos_token):
+                is_eos_tokens = (out == eos_token)
+
+                if is_eos_tokens.any(dim = -1).all():
+                    # mask out everything after the eos tokens
+                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+                    out = out.masked_fill(mask, self.pad_value)
+                    break
+
+        out = out[:, t:]
+        return out
+
+    def forward(self, x, **kwargs):
+        x_inp, x_labels = x[:, :-1], x[:, 1:]
+        logits = self.net(x_inp, **kwargs)
+        return F.cross_entropy(logits.transpose(1, 2), x_labels, ignore_index = self.pad_value)
diff --git a/hourglass_transformer_pytorch/hourglass_transformer_pytorch.py b/hourglass_transformer_pytorch/hourglass_transformer_pytorch.py
index e69de29..4f6e69f 100644
--- a/hourglass_transformer_pytorch/hourglass_transformer_pytorch.py
+++ b/hourglass_transformer_pytorch/hourglass_transformer_pytorch.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from einops import rearrange
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+# classes
+
+# main class
+
+class HourglassTransformer(nn.Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        max_seq_len,
+        depth,
+        heads = 8,
+        dim_head = 64,
+        causal = True
+    ):
+        super().__init__()
+        self.max_seq_len = max_seq_len
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+        self.pos_emb = nn.Embedding(max_seq_len, dim)
+
+        self.to_logits = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_tokens)
+        )
+
+    def forward(self, x):
+        device = x.device
+        x = self.token_emb(x)
+        pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device))
+        x = x + rearrange(pos_emb, 'n d -> () n d')
+
+        return self.to_logits(x)
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..85f8ed8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,28 @@
+from setuptools import setup, find_packages
+
+setup(
+  name = 'hourglass-transformer-pytorch',
+  packages = find_packages(),
+  version = '0.0.1',
+  license='MIT',
+  description = 'Hourglass Transformer',
+  author = 'Phil Wang',
+  author_email = 'lucidrains@gmail.com',
+  url = 'https://github.com/lucidrains/hourglass-transformer-pytorch',
+  keywords = [
+    'artificial intelligence',
+    'attention mechanism',
+    'transformers'
+  ],
+  install_requires=[
+    'einops',
+    'torch>=1.6'
+  ],
+  classifiers=[
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    'License :: OSI Approved :: MIT License',
+    'Programming Language :: Python :: 3.6',
+  ],
+)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..790d2d7
--- /dev/null
+++ b/train.py
@@ -0,0 +1,109 @@
+from hourglass_transformer_pytorch import HourglassTransformer
+from hourglass_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SEQ_LEN = 512
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = HourglassTransformer(
+    num_tokens = 256,
+    dim = 512,
+    max_seq_len = SEQ_LEN,
+    depth = 8,
+    heads = 8,
+    causal = True
+)
+
+model = AutoregressiveWrapper(model)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    trX, vaX = np.split(X, [int(90e6)])
+    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader))
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader))
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print('%s \n\n %s' % (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
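
A minimal usage sketch of the two modules added above, mirroring the constructor arguments used in train.py. The random byte ids, CPU-only execution, prompt length, and sample count are illustrative assumptions, not taken from the patch.

import torch
from hourglass_transformer_pytorch import HourglassTransformer
from hourglass_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# same hyperparameters as train.py, kept on CPU for illustration
model = HourglassTransformer(
    num_tokens = 256,        # byte-level vocabulary
    dim = 512,
    max_seq_len = 512,
    depth = 8,
    heads = 8,
    causal = True
)

wrapped = AutoregressiveWrapper(model)

# random byte ids stand in for real data
x = torch.randint(0, 256, (2, 512))

# the wrapper shifts the sequence internally and returns a next-token cross-entropy loss
loss = wrapped(x)
loss.backward()

# sample 64 new tokens from a 16-token prompt
prompt = x[:, :16]
generated = wrapped.generate(prompt, 64)
print(generated.shape)       # torch.Size([2, 64]); generate() strips the prompt tokens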