This project is built upon an existing transformer implementation from Aleksa Gordic. The README from the original repository follows below the line.
This repo contains a PyTorch implementation of the original transformer paper (Vaswani et al.).
It's aimed at making it easy to start playing with and learning about transformers.
Important note: I'll be adding a Jupyter notebook soon as well!
- What are transformers?
- Understanding transformers
- Machine translation
- Setup
- Usage
- Hardware requirements
Transformers were originally proposed by Vaswani et al. in a seminal paper called Attention Is All You Need.
You've probably heard of transformers one way or another: GPT-3 and BERT, to name a few well-known ones 🦄. The main idea is that the authors showed you don't need recurrent or convolutional layers, and that a simple architecture coupled with attention is super powerful. It brought much better modeling of long-range dependencies, and the architecture itself is very parallelizable 💻💻💻, which leads to higher compute efficiency!
Here is what their beautifully simple architecture looks like:
This repo is meant to be a learning resource for understanding transformers, as the original transformer by itself is no longer state of the art (SOTA).
For that purpose the code is (hopefully) well commented, and I've included playground.py, where I've visualized a couple of concepts that are hard to explain in words but super simple once visualized. So here we go!
Can you parse this one at a glance?
Neither can I. Running the visualize_positional_encodings() function from playground.py, we get this:
Depending on the position of your source/target token, you "pick one row of this image" and add it to its embedding vector. That's it. The encodings could also be learned, but it's just fancier to do it like this, obviously! 🤓
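The paper defines the sinusoidal encodings as PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Here's a minimal plain-Python sketch of that table and of the "pick a row and add it" step. The function name and the zero placeholder embedding are hypothetical; this is not the playground.py code, and the repo works with PyTorch tensors rather than lists.

```python
import math


def sinusoidal_positional_encodings(max_len: int, d_model: int) -> list[list[float]]:
    """Return a max_len x d_model table; row `pos` is added to the embedding at position `pos`."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 form a sin/cos pair sharing the frequency 1 / 10000^(2k / d_model).
            k = i // 2
            angle = pos / (10000 ** (2 * k / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = sinusoidal_positional_encodings(max_len=50, d_model=512)

# "Pick one row of this image" for the token at position 3 and add it element-wise to its embedding.
embedding = [0.0] * 512  # hypothetical placeholder embedding, purely illustrative
encoded = [e + p for e, p in zip(embedding, pe[3])]
```

Low dimensions oscillate quickly with position and high dimensions slowly, which is what produces the striped pattern in the visualization.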
Similarly, can you parse this one in O(1)?
Nope? I thought so. Here it is visualized:
It's super easy to understand now. Was this part crucial for the transformer's success, though? I doubt it.
But it's cool and makes things more complicated. 🤓 (.set_sarcasm(True))
Note: the model dimension is basically the size of the embedding vector; the baseline transformer used 512 and the big one 1024.
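This assumes the formula above is the paper's warmup learning-rate schedule, lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), which depends on the model dimension just mentioned. The sketch below is a plain-Python illustration under that assumption; the function name is hypothetical and warmup_steps=4000 is the paper's value, not necessarily this repo's.

```python
def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Linear warmup for warmup_steps, then decay proportional to 1 / sqrt(step)."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The two branches of min() meet at step == warmup_steps, where the learning rate peaks.
# A bigger model (d_model=1024) gets a uniformly smaller learning rate.
for step in (1000, 4000, 16000):
    print(step, transformer_lr(step), transformer_lr(step, d_model=1024))
```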
The first time you hear of label smoothing it sounds tough, but it's not. Usually you set your target vocabulary distribution to a one-hot vector, meaning 1 position out of 30k (or whatever your vocab size is) gets a probability of 1.0 and everything else gets 0.
In label smoothing, instead of placing 1.0 on that particular position, you place, say, 0.9 and evenly distribute the rest of the "probability mass" over the other positions (that's visualized as a different shade of purple on the image above, for a fictional vocab of size 4, hence 4 columns).
Note: the pad token's distribution is set to all zeros, as we don't want our model to predict pad tokens!
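Here's a minimal plain-Python sketch of those smoothed targets. It assumes the leftover mass is split evenly over every position except the correct token and the pad token (hence vocab_size - 2), that the pad column is always 0, and that rows whose target is the pad token are all zeros. The helper name and the 0.1 smoothing value are illustrative assumptions.

```python
def smoothed_target_distribution(
    target_ids: list[int], vocab_size: int, pad_id: int, smoothing: float = 0.1
) -> list[list[float]]:
    """Build one smoothed probability row per target token (hypothetical helper, needs vocab_size > 2)."""
    confidence = 1.0 - smoothing
    # Leftover mass is shared by all positions except the correct token and the pad token.
    fill = smoothing / (vocab_size - 2)
    rows = []
    for tid in target_ids:
        if tid == pad_id:
            rows.append([0.0] * vocab_size)  # pad targets get an all-zero row
            continue
        row = [fill] * vocab_size
        row[pad_id] = 0.0  # never assign probability to the pad token
        row[tid] = confidence
        rows.append(row)
    return rows


# Fictional vocab of size 4 (like the figure), with index 0 as the pad token.
dist = smoothed_target_distribution([2, 3, 0], vocab_size=4, pad_id=0)
# dist[0] is [0.0, 0.05, 0.9, 0.05]; dist[2] is all zeros because its target is padding.
```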
Aside from this repo (well, duh), I highly recommend you go ahead and read this amazing blog by Jay Alammar!
The transformer was originally trained for NMT (neural machine translation) on the WMT-14 dataset for:
- English to German translation task (achieved 28.4 BLEU score)
- English to French translation task (achieved 41.8 BLEU score)
For now, I trained my models on the much smaller IWSLT dataset for the English-German language pair; I speak both languages, so it's easier to debug and play around.
I'll also train my models on WMT-14 soon; take a look at the todos section.
Anyway! Let's see what this repo can practically do for you! Well, it can translate!
Some short translations from my German to English IWSLT model:
Input: Ich bin ein guter Mensch, denke ich.
("gold": I am a good person I think)
Output: ['<s>', 'I', 'think', 'I', "'m", 'a', 'good', 'person', '.', '</s>']
or in human-readable format: I think I'm a good person.
Which is actually pretty good! Maybe even better IMO than Google Translate's "gold" translation.
There are of course failure cases like this:
Input: Hey Alter, wie geht es dir?
(How is it going dude?)
Output: ['<s>', 'Hey', ',', 'age', 'how', 'are', 'you', '?', '</s>']
or in human-readable format: Hey, age, how are you?
Which is actually also not completely bad! Because:
- First of all, the model was trained on IWSLT (TED-like conversations)
- "Alter" is a colloquial expression for old buddy/dude/mate, but its literal meaning is indeed age.
Similarly for the English to German model.
So we talked about what transformers are, and what they can do for you (among other things).
Let's get this thing running! Follow the next steps:
- git clone https://github.com/gordicaleksa/pytorch-original-transformer
- Open the Anaconda console and navigate into the project directory: cd path_to_repo
- Run conda env create from the project directory (this will create a brand new conda environment).
- Run activate pytorch-transformer (for running scripts from your console), or set the interpreter in your IDE.
That's it! It should work out of the box by executing the environment.yml file, which deals with dependencies.
It may take a while, as I'm automatically downloading spaCy's statistical models for English and German.
The PyTorch pip package comes bundled with some version of CUDA/cuDNN, but it is highly recommended that you install a system-wide CUDA beforehand, mostly because of the GPU drivers. I also recommend using the Miniconda installer as a way to get conda on your system. Follow through points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.
To run the training, start training_script.py. There are a couple of settings you will want to specify:
- --batch_size - it's important to set this to the maximum value that won't give you a CUDA out-of-memory error
- --dataset_name - pick between IWSLT and WMT14 (WMT14 is not advisable until I add multi-GPU support)
- --language_direction - pick between E2G and G2E
So an example run (from the console) would look like this:
python training_script.py --batch_size 1500 --dataset_name IWSLT --language_direction G2E
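Judging by the "1500 token batch" figures in the hardware section, --batch_size appears to count tokens rather than sentences. The sketch below is a hypothetical illustration of token-budget batching, not the repo's data pipeline: it counts padded tokens (sentences in batch × longest sentence) and starts a new batch when the budget would be exceeded. Sorting sentences by length first, which reduces padding, is left out.

```python
def batch_by_token_budget(lengths: list[int], max_tokens: int) -> list[list[int]]:
    """Group sentence indices so that len(batch) * longest_sentence stays within max_tokens.

    A single sentence longer than max_tokens still ends up alone in its own batch.
    """
    batches, current, longest = [], [], 0
    for idx, length in enumerate(lengths):
        new_longest = max(longest, length)
        # Padded cost if this sentence joined the current batch.
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, new_longest = [], length
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


batches = batch_by_token_budget([10, 12, 8, 30, 5], max_tokens=40)
# → [[0, 1, 2], [3], [4]]
```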
The code is well commented so you can (hopefully) understand how the training itself works.
The script will:
- Dump checkpoint *.pth models into models/checkpoints/
- Dump the final *.pth model into models/binaries/
- Download IWSLT/WMT-14 the first time you run it and place it under data/
- Dump tensorboard data into runs/; just run tensorboard --logdir=runs from your Anaconda console
- Periodically write some training metadata to the console
Note: data loading is slow in torchtext, so I've implemented a custom wrapper that adds caching mechanisms and makes things ~30x faster! (It'll be slow the first time you run stuff.)
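The wrapper itself isn't shown here. As a hypothetical illustration of the general idea (build the slow thing once, then reload it from disk), a pickle-based cache could look like this. The function name, the cache path, and what gets cached are all assumptions; the real wrapper may cache different objects in a different format.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def load_with_cache(cache_path: str, build_fn: Callable[[], Any]) -> Any:
    """Return the object cached at cache_path, or build it once with build_fn and cache it."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # fast path: every run after the first
    data = build_fn()  # slow path: e.g. tokenizing the whole dataset
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data


build_calls = []


def build_dataset():
    """Stand-in for the slow step (e.g. tokenizing every sentence pair)."""
    build_calls.append(1)
    return [(["ich", "bin", "ein", "guter", "mensch"], ["i", "am", "a", "good", "person"])]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache", "iwslt_g2e.pkl")  # hypothetical cache location
    first = load_with_cache(path, build_dataset)  # builds and writes the cache
    second = load_with_cache(path, build_dataset)  # reads the cache, no rebuild
```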
The second part is all about playing with the models and seeing how they translate!
To get some translations, start translation_script.py. There are a couple of settings you'll want to set:
- --source_sentence - depending on the model you specify, this should be either an English or a German sentence
- --model_name - one of the pretrained model names (iwslt_e2g or iwslt_g2e) or your own model (*)
- --dataset_name - keep this in sync with the model: IWSLT if the model was trained on IWSLT
- --language_direction - keep this in sync too: E2G if the model was trained to translate from English to German
(*) Note: after you train your model, it'll get dumped into models/binaries. Check what its name is and specify it via the --model_name parameter if you want to use it for translation. If you specify one of the pretrained models, it will automatically get downloaded the first time you run the translation script.
Here are the IWSLT pretrained model links as well: English to German and German to English.
That's it! You can also visualize attention; check out the attention visualization section for more info.
I tracked 3 curves while training:
- training loss (KL divergence, batchmean)
- validation loss (KL divergence, batchmean)
- BLEU-4
BLEU is an n-gram based metric for quantitatively evaluating the quality of machine translation models.
I used the BLEU-4 metric provided by the awesome nltk Python module.
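nltk exposes this metric as nltk.translate.bleu_score.corpus_bleu. To make "n-gram based" concrete, here's a from-scratch, unsmoothed corpus BLEU-4 sketch that assumes exactly one reference per sentence. It is meant to show the mechanics (clipped n-gram precisions for n = 1..4, geometric mean, brevity penalty), not to replace nltk, which also handles multiple references and smoothing. Scores come out in [0, 1]; multiply by 100 for the scale used in the table below.

```python
import math
from collections import Counter


def ngrams(tokens: list[str], n: int) -> Counter:
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu4(references: list[list[str]], hypotheses: list[list[str]]) -> float:
    """Unsmoothed corpus-level BLEU-4 with one reference per hypothesis."""
    log_precision_sum = 0.0
    for n in range(1, 5):
        matched, total = 0, 0
        for ref, hyp in zip(references, hypotheses):
            hyp_counts, ref_counts = ngrams(hyp, n), ngrams(ref, n)
            # Clipping: an n-gram is credited at most as often as it appears in the reference.
            matched += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            total += sum(hyp_counts.values())
        if matched == 0:
            return 0.0  # any zero precision makes the geometric mean zero
        log_precision_sum += 0.25 * math.log(matched / total)
    hyp_len = sum(len(h) for h in hypotheses)
    ref_len = sum(len(r) for r in references)
    # Brevity penalty: punish output that is shorter than the references overall.
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return bp * math.exp(log_precision_sum)


ref = "I think I 'm a good person .".split()
score_same = corpus_bleu4([ref], [ref])  # → 1.0
score_close = corpus_bleu4(
    [["the", "cat", "sat", "on", "the", "mat"]], [["the", "cat", "sat", "on", "a", "mat"]]
)  # ≈ 0.537, i.e. (5/6 · 3/5 · 2/4 · 1/3) ** 0.25
```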
Current results; models were trained for 20 epochs (DE stands for Deutsch, i.e. German in German 🤓):
Model | BLEU score | Dataset |
---|---|---|
Baseline transformer (EN-DE) | 27.8 | IWSLT val |
Baseline transformer (DE-EN) | 33.2 | IWSLT val |
Baseline transformer (EN-DE) | x | WMT-14 val |
Baseline transformer (DE-EN) | x | WMT-14 val |
I got these using greedy decoding, so they're a pessimistic estimate; I'll add beam decoding soon.
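Greedy decoding commits to the single best next token at every step and never revisits that choice, which is why it tends to underestimate what a model can do compared with beam search (which keeps several candidate prefixes alive). A minimal sketch follows; next_token_scores is a hypothetical stand-in for "run the decoder on the prefix and take the softmax", and the toy score table is made up.

```python
from typing import Callable


def greedy_decode(
    next_token_scores: Callable[[list[str]], dict[str, float]],
    bos: str = "<s>",
    eos: str = "</s>",
    max_len: int = 20,
) -> list[str]:
    """At every step, append the single highest-scoring next token until eos or max_len."""
    output = [bos]
    while len(output) < max_len:
        scores = next_token_scores(output)  # hypothetical model call on the current prefix
        best = max(scores, key=scores.get)
        output.append(best)
        if best == eos:
            break
    return output


# Toy stand-in for a trained decoder: scores depend only on the last token.
TOY_SCORES = {
    "<s>": {"I": 0.6, "You": 0.4},
    "I": {"think": 0.7, "am": 0.3},
    "think": {"</s>": 0.9, "so": 0.1},
}
result = greedy_decode(lambda prefix: TOY_SCORES[prefix[-1]])
# → ['<s>', 'I', 'think', '</s>']
```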
Important note: initialization matters a lot for the transformer! I initially thought that the Xavier initialization used by other implementations was just another arbitrary heuristic and that PyTorch's default init would do. I was wrong:
You can see 3 runs here: the 2 lower ones used PyTorch's default initialization (one used the mean reduction for the KL divergence loss and the better one used batchmean), whereas the upper one used Xavier uniform initialization!
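For reference, here is what the two KL divergence reductions and the Xavier uniform range compute, as a plain-Python sketch with hypothetical function names. With nn.KLDivLoss, "mean" divides the summed loss by every element (batch × vocab), while "batchmean" divides only by the batch size, which matches the mathematical definition of KL divergence; the two therefore differ by a factor of the vocab size. Xavier/Glorot uniform draws weights from U(-a, a) with a = gain · sqrt(6 / (fan_in + fan_out)). In PyTorch itself you would use nn.KLDivLoss(reduction="batchmean") and nn.init.xavier_uniform_.

```python
import math


def kl_divergence_losses(targets: list[list[float]], log_probs: list[list[float]]) -> tuple[float, float]:
    """Return the (mean, batchmean) reductions of KL(target || model), given model log-probabilities."""
    total = 0.0
    for t_row, lp_row in zip(targets, log_probs):
        for t, lp in zip(t_row, lp_row):
            if t > 0:  # terms with target probability 0 contribute 0
                total += t * (math.log(t) - lp)
    batch_size, vocab_size = len(targets), len(targets[0])
    mean = total / (batch_size * vocab_size)  # divides by every element
    batchmean = total / batch_size  # divides by the batch size only
    return mean, batchmean


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Half-width a of the U(-a, a) range used by Xavier/Glorot uniform initialization."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


targets = [[0.0, 0.05, 0.9, 0.05]]  # one label-smoothed target row, pad at index 0
log_probs = [[math.log(0.25)] * 4]  # a model that predicts uniformly over a vocab of 4
mean, batchmean = kl_divergence_losses(targets, log_probs)  # mean == batchmean / 4
bound = xavier_uniform_bound(fan_in=512, fan_out=512)  # sqrt(6 / 1024)
```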
Idea: you could also periodically dump translations for a reference batch of source sentences.
That would give you some qualitative insight into how the transformer is doing, although I didn't do that.
A similar thing is done when it's hard to evaluate your model quantitatively, as in the GAN and NST (neural style transfer) fields.
The above plot is a snippet from my Azure ML run, but when I run stuff locally I use TensorBoard.
Just run tensorboard --logdir=runs from your Anaconda console and you can track your metrics during training.
You can use translation_script.py and set --visualize_attention to True to additionally understand what your model was "paying attention to" in the source and target sentences.
Here are the attentions I get for the input sentence Ich bin ein guter Mensch, denke ich.
These belong to layer 6 of the encoder. You can see all 8 heads of the multi-head attention.
And this one belongs to the self-attention MHA (multi-head attention) module in decoder layer 6.
You can notice an interesting triangular pattern which comes from the fact that target tokens can't look ahead!
The 3rd type of MHA module is the source-attending one, and it looks similar to the plot you saw for the encoder.
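The triangular pattern in the decoder self-attention comes from a "subsequent" (causal) mask: target position i may only attend to positions j ≤ i. Here's a minimal plain-Python sketch of such a mask applied to one row of attention scores; the function names are hypothetical. Implementations typically set masked scores to -inf before the softmax, which is equivalent to zeroing their exponentials as done here.

```python
import math


def subsequent_mask(size: int) -> list[list[bool]]:
    """mask[i][j] is True when target position i may attend to position j (i.e. j <= i)."""
    return [[j <= i for j in range(size)] for i in range(size)]


def masked_softmax(scores: list[float], allowed: list[bool]) -> list[float]:
    """Softmax over one row of attention scores, giving zero weight to masked-out positions.

    Assumes at least one position is allowed (always true for a causal mask: j == i).
    """
    exps = [math.exp(s) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


mask = subsequent_mask(4)
weights = masked_softmax([0.0, 0.0, 0.0, 0.0], mask[1])
# → [0.5, 0.5, 0.0, 0.0]: position 1 can only see positions 0 and 1
```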
Feel free to play with it at your own pace!
Note: there are obviously some bias problems with this model, but I won't get into that analysis here.
You really need decent hardware if you wish to train the transformer on the WMT-14 dataset.
The authors took:
- 12h on 8 P100 GPUs to train the baseline model and 3.5 days to train the big one.
If my calculations are right, that amounts to ~19 epochs for the baseline (100k steps, each step had ~25,000 tokens, and WMT-14 has ~130M src/trg tokens) and 3x that for the big one (300k steps).
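Spelled out with the same numbers (~25,000 tokens per step, ~130M tokens per pass over WMT-14):

```latex
\text{epochs}_{\text{base}} \approx \frac{100{,}000 \times 25{,}000}{130 \times 10^{6}}
  = \frac{2.5 \times 10^{9}}{1.3 \times 10^{8}} \approx 19.2,
\qquad
\text{epochs}_{\text{big}} \approx \frac{300{,}000 \times 25{,}000}{130 \times 10^{6}}
  = \frac{7.5 \times 10^{9}}{1.3 \times 10^{8}} \approx 57.7
```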
On the other hand it's much more feasible to train the model on the IWSLT dataset. It took me:
- 13.2 min/epoch (1500 token batch) on my RTX 2080 machine (8 GBs of VRAM)
- ~34 min/epoch (1500 token batch) on Azure ML's K80s (24 GBs of VRAM)
I could have pushed K80s to 3500+ tokens/batch but had some CUDA out of memory problems.
Finally there are a couple more todos which I'll hopefully add really soon:
- Multi-GPU/multi-node training support (so that you can train a model on WMT-14 for 19 epochs)
- Beam decoding (turns out it's not that easy to implement this one!)
- BPE and shared source-target vocab (I'm using spaCy now)
The repo already has everything it needs; these are just bonus points. I've tested everything, from the environment setup to the automatic model download, etc.
I also made a video covering how I approached learning transformers; you can check it out on my YouTube channel:
I found these resources useful (while developing this one):
I found some inspiration for the model design in The Annotated Transformer, but I found it hard to understand, and it had some bugs. It was mainly written with researchers in mind. Hopefully this repo opens up the understanding of transformers to the common folk as well! 🤓
If you find this code useful, please cite the following:
@misc{Gordić2020PyTorchOriginalTransformer,
author = {Gordić, Aleksa},
title = {pytorch-original-transformer},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/gordicaleksa/pytorch-original-transformer}},
}