Commit

initial mixer commit
akolesnikoff committed May 5, 2021
1 parent b476038 commit fca1235
Showing 10 changed files with 233 additions and 36 deletions.
136 changes: 103 additions & 33 deletions README.md
@@ -3,29 +3,35 @@
recently been updated and the results have not yet been fully replicated. We
will update the table below soon with new results from the updated code and then
merge this branch into `master`.

# Vision Transformer and MLP-Mixer Architectures for Vision

In this repository we release models from the papers
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
and
[MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
that were pre-trained on the [ImageNet](http://www.image-net.org/) (`imagenet`)
and [ImageNet-21k](http://www.image-net.org/) (`imagenet21k`) datasets. We
provide the code for fine-tuning the released models in
[Jax](https://jax.readthedocs.io)/[Flax](http://flax.readthedocs.io).

# Vision Transformer
by Alexey Dosovitskiy\*†, Lucas Beyer\*, Alexander Kolesnikov\*, Dirk
Weissenborn\*, Xiaohua Zhai\*, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit and Neil Houlsby\*†.

(\*) equal technical contribution, (†) equal advising.
First we describe the [Vision Transformer (ViT)](#vision-transformer) models.
Feel free to [jump to the section describing the MLP-Mixer models](#mlp-mixer)
if that's what you came for.

Open source release prepared by Andreas Steiner.

Note: This repository was forked and modified from
[google-research/big_transfer](https://github.com/google-research/big_transfer).

## Introduction
## Vision Transformer

In this repository we release models from the paper [An Image is Worth 16x16
Words: Transformers for Image Recognition at
Scale](https://arxiv.org/abs/2010.11929) that were pre-trained on the
[ImageNet-21k](http://www.image-net.org/) (`imagenet21k`) dataset. We provide
the code for fine-tuning the released models in
[Jax](https://jax.readthedocs.io)/[Flax](http://flax.readthedocs.io).
by Alexey Dosovitskiy\*†, Lucas Beyer\*, Alexander Kolesnikov\*, Dirk
Weissenborn\*, Xiaohua Zhai\*, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit and Neil Houlsby\*†.

(\*) equal technical contribution, (†) equal advising.

![Figure 1 from paper](figure1.png)
![Figure 1 from paper](vit_figure.png)

Overview of the model: we split an image into fixed-size patches, linearly embed
each of them, add position embeddings, and feed the resulting sequence of
@@ -35,9 +41,9 @@
to the sequence.
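
To make the description above concrete, here is a minimal, self-contained
sketch of the patch-embedding step (illustration only, not the code used in
this repository), assuming 224x224 RGB inputs, 16x16 patches and hidden size
768:

```python
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


class PatchEmbed(nn.Module):
  hidden_dim: int = 768
  patch: int = 16

  @nn.compact
  def __call__(self, images):
    # A strided convolution is equivalent to cutting the image into patches
    # and applying a shared linear projection to each patch.
    x = nn.Conv(self.hidden_dim, (self.patch, self.patch),
                strides=(self.patch, self.patch))(images)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')  # (batch, 196, 768)
    # One learned position embedding per patch token.
    pos = self.param('pos_embedding', nn.initializers.normal(stddev=0.02),
                     (1, x.shape[1], x.shape[2]))
    return x + pos


tokens, _ = PatchEmbed().init_with_output(
    jax.random.PRNGKey(0), jnp.ones([1, 224, 224, 3]))  # tokens: (1, 196, 768)
```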

## Colab

Check out the Colab for loading the data, fine-tuning the model, evaluation,
and inference. The Colab loads the code from this repository and runs by
default on a TPU with 8 cores.
Check out the Colab for loading the data, fine-tuning the ViT model, evaluating
it, and running inference. The Colab loads the code from this repository and
runs by default on a TPU with 8 cores.

https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb

@@ -58,7 +64,7 @@
Then, install python dependencies by running:
pip install -r vit_jax/requirements.txt
```

## Available models
## Available ViT models

We provide models pre-trained on imagenet21k for the following architectures:
ViT-B/16, ViT-B/32, ViT-L/16 and ViT-L/32. We provide the same models
@@ -98,7 +104,7 @@
You can run fine-tuning of the downloaded model on your dataset of interest. All
frameworks share the command line interface

```
python3 -m vit_jax.train --name ViT-B_16-cifar10_`date +%F_%H%M%S` --model ViT-B_16 --logdir /tmp/vit_logs --dataset cifar10
python -m vit_jax.main --workdir=/tmp/vit --config=$(pwd)/vit_jax/configs/vit.py:b16,cifar10 --config.pretrained_dir="gs://vit_models/imagenet21k/"
```

Currently, the code will automatically download CIFAR-10 and CIFAR-100 datasets.
@@ -114,23 +120,23 @@
To see a detailed list of all available flags, run `python3 -m vit_jax.train

Notes about some flags:

- `--accum_steps=16` : This works well with ViT-B_16 on a machine that has 8
GPUs of type V100 with 16G memory each attached. If you have fewer
- `--config.accum_steps=16` : This works well with ViT-B_16 on a machine that
has 8 GPUs of type V100 with 16G memory each attached. If you have fewer
accelerators or accelerators with less memory, you can use the same
configuration but increase the `--accum_steps`. For a small model like
ViT-B_32 you can even use `--accum_steps=1`. For a large model like ViT-L_16
you need to go in the other direction (e.g. `--accum_steps=32`). Note that
the largest model ViT-H_14 also needs adaptation of the batch size
(`--accum_steps=2 --batch=16` should work on an 8x V100).
- `--batch=512` : Alternatively, you can decrease the batch size, but
configuration but increase the `--config.accum_steps`. For a small model
like ViT-B_32 you can even use `--config.accum_steps=1`. For a large model
like ViT-L_16 you need to go in the other direction (e.g.
`--config.accum_steps=32`). Note that the largest model ViT-H_14 also needs
adaptation of the batch size (`--config.accum_steps=2 --config.batch=16`
should work on an 8x V100). See the sketch after this list for how gradient
accumulation trades memory for compute.
- `--config.batch=512` : Alternatively, you can decrease the batch size, but
that usually involves some tuning of the learning rate parameters.
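
A minimal sketch of the idea behind the `accum_steps` flag (illustration only;
the repository's actual implementation may differ in its details). Each update
still uses `batch` examples, but the gradient is computed on `accum_steps`
smaller micro-batches and averaged, so peak device memory corresponds to
`batch / accum_steps` examples. Here `loss_fn`, `params` and `batch` are
placeholders:

```python
import jax
import jax.numpy as jnp


def accumulated_grad(loss_fn, params, batch, accum_steps):
  """Averages gradients over `accum_steps` equally sized micro-batches."""
  images, labels = batch
  micro_images = jnp.reshape(images, (accum_steps, -1) + images.shape[1:])
  micro_labels = jnp.reshape(labels, (accum_steps, -1) + labels.shape[1:])

  def step(acc, micro_batch):
    grads = jax.grad(loss_fn)(params, *micro_batch)
    return jax.tree_util.tree_map(jnp.add, acc, grads), None

  zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
  total, _ = jax.lax.scan(step, zeros, (micro_images, micro_labels))
  return jax.tree_util.tree_map(lambda g: g / accum_steps, total)
```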

## Expected results

In this table we closely follow experiments from the paper and report results
that were achieved by running this code on a Google Cloud machine with eight
V100 GPUs.
In this table we closely follow experiments from the ViT paper and report
results that were achieved by running the code on a Google Cloud machine with
eight V100 GPUs.

Note: Runs in table below before 2020-11-03 ([6fba202]) have
`config.transformer.dropout_rate=0.0`.
Expand Down Expand Up @@ -182,14 +188,78 @@ Some examples for CIFAR-10/100 datasets are presented in the table below.
| imagenet21k | ViT-B_16 | cifar100 | 500 / 50 | 0.8917 | 17m | [tensorboard.dev](https://tensorboard.dev/experiment/5hM4GrnAR0KEZg725Ewnqg/) |
| imagenet21k | ViT-B_16 | cifar100 | 1000 / 100 | 0.9115 | 39m | [tensorboard.dev](https://tensorboard.dev/experiment/QLQTaaIoT9uEcAjtA0eRwg/) |

## MLP-Mixer

by Ilya Tolstikhin\*, Neil Houlsby\*, Alexander Kolesnikov\*, Lucas Beyer\*,
Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Daniel Keysers, Jakob Uszkoreit,
Mario Lucic, Alexey Dosovitskiy.

(\*) equal contribution.

![Figure 1 from paper](mixer_figure.png)

MLP-Mixer (*Mixer* for short) consists of per-patch linear embeddings, Mixer
layers, and a classifier head. Mixer layers contain one token-mixing MLP and one
channel-mixing MLP, each consisting of two fully-connected layers and a GELU
nonlinearity. Other components include skip-connections, dropout, and a linear
classifier head.

For installation follow [the same steps](#installation) as above.

## Available Mixer models

We provide the Mixer-B/16 and Mixer-L/16 models pre-trained on the ImageNet and
ImageNet-21k datasets. Details can be found in Table 3 of the Mixer paper. All
the models can be found at:

https://console.cloud.google.com/storage/mixer_models/
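
The configuration values added in this commit (see `vit_jax/configs/models.py`
below) fully determine the model size. As a quick sanity check, the Mixer-B/16
parameter count at 224x224 resolution with a 1000-class head can be reproduced
by hand; it matches the `59_880_472` constant added to
`vit_jax/models_test.py`:

```python
# Back-of-the-envelope parameter count for Mixer-B/16 at 224x224, 1000 classes.
tokens, hidden, tok_mlp, ch_mlp, blocks, classes = 196, 768, 384, 3072, 12, 1000

def dense(n_in, n_out):  # weight matrix + bias
  return n_in * n_out + n_out

stem = dense(16 * 16 * 3, hidden)           # patch-embedding convolution
per_block = (2 * hidden + dense(tokens, tok_mlp) + dense(tok_mlp, tokens)    # LayerNorm + token-mixing MLP
             + 2 * hidden + dense(hidden, ch_mlp) + dense(ch_mlp, hidden))   # LayerNorm + channel-mixing MLP
head = 2 * hidden + dense(hidden, classes)  # pre-head LayerNorm + classifier
print(stem + blocks * per_block + head)     # 59880472
```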

## Colab

**Note**: We will soon extend the Colab with Mixer examples.

## Fine-tuning Mixer models

The following command will load the Mixer-B/16 model pre-trained on ImageNet-21k
and fine-tune it on CIFAR-10 at resolution 224:

```
python -m vit_jax.main --workdir=/tmp/mixer --config=$(pwd)/vit_jax/configs/mixer_base16_cifar10.py --config.pretrained_dir="gs://mixer_models/imagenet21k/"
```

Specify `gs://mixer_models/imagenet1k/` to fine-tune the models pre-trained on
ImageNet. Change the `config.model` in the `mixer_base16_cifar10.py` config file
to use the Mixer-L/16 model. More details (including how to fine-tune on other
datasets) can be found in the
[section describing fine-tuning for ViT](#how-to-fine-tune-vit).
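
For example, switching to Mixer-L/16 is a one-line change in that config file
(both model configs are registered in `vit_jax/configs/models.py`, shown
further below):

```python
# In vit_jax/configs/mixer_base16_cifar10.py, change
#   config.model = models.get_mixer_b16_config()
# to
config.model = models.get_mixer_l16_config()
```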

## Reproducing Mixer results on CIFAR-10

We ran the fine-tuning code on a Google Cloud machine with four V100 GPUs,
using the default adaptation parameters from this repository. Here are the results:

upstream | model | dataset | accuracy | wall_clock_time | link
:----------- | :--------- | :------ | -------: | :-------------- | :---
ImageNet | Mixer-B/16 | cifar10 | 96.72 | 3.0h | [tensorboard.dev](https://tensorboard.dev/experiment/j9zCYt9yQVm93nqnsDZayA/)
ImageNet | Mixer-L/16 | cifar10 | 96.59 | 3.0h | [tensorboard.dev](https://tensorboard.dev/experiment/Q4feeErzRGGop5XzAvYj2g/)
ImageNet-21k | Mixer-B/16 | cifar10 | 96.82 | 9.6h | [tensorboard.dev](https://tensorboard.dev/experiment/mvP4McV2SEGFeIww20ie5Q/)
ImageNet-21k | Mixer-L/16 | cifar10 | 98.34 | 10.0h | [tensorboard.dev](https://tensorboard.dev/experiment/dolAJyQYTYmudytjalF6Jg/)

## Bibtex

```
@article{dosovitskiy2020,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
journal={arXiv preprint arXiv:2010.11929},
year={2020}
journal={ICLR},
year={2021}
}
@article{tolstikhin2021,
title={MLP-Mixer: An all-MLP Architecture for Vision},
author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
journal={arXiv preprint arXiv:2105.01601},
year={2021}
}
```

Binary file added mixer_figure.png
File renamed without changes
29 changes: 29 additions & 0 deletions vit_jax/configs/mixer_base16_cifar10.py
@@ -0,0 +1,29 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_collections
from vit_jax.configs import common
from vit_jax.configs import models


def get_config():
"""Returns config for training Mixer-B/16 on cifar10."""
config = common.get_config()
config.model_type = 'Mixer'
config.model = models.get_mixer_b16_config()
config.dataset = 'cifar10'
config.total_steps = 10_000
config.pp = ml_collections.ConfigDict(
{'train': 'train[:98%]', 'test': 'test', 'resize': 256, 'crop': 224})
return config
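
The config above can also be loaded and tweaked programmatically, for example
for a quick smoke test (a small usage sketch; the fields referenced are the
ones defined in this file and in `vit_jax/configs/models.py`):

```python
from vit_jax.configs import mixer_base16_cifar10

config = mixer_base16_cifar10.get_config()
config.total_steps = 1_000   # e.g. shorten training for a smoke test
config.pp.crop = 224         # preprocessing options are a ConfigDict as well
print(config.model.name)     # 'Mixer-B_16'
```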
26 changes: 26 additions & 0 deletions vit_jax/configs/models.py
@@ -171,3 +171,29 @@ def get_h14_config():
config.classifier = 'token'
config.representation_size = None
return config


@_register
def get_mixer_b16_config():
"""Returns Mixer-B/16 configuration."""
config = ml_collections.ConfigDict()
config.name = 'Mixer-B_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_dim = 768
config.num_blocks = 12
config.tokens_mlp_dim = 384
config.channels_mlp_dim = 3072
return config


@_register
def get_mixer_l16_config():
"""Returns Mixer-L/16 configuration."""
config = ml_collections.ConfigDict()
config.name = 'Mixer-L_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_dim = 1024
config.num_blocks = 24
config.tokens_mlp_dim = 512
config.channels_mlp_dim = 4096
return config
2 changes: 2 additions & 0 deletions vit_jax/models.py
@@ -17,6 +17,7 @@
import flax.linen as nn
import jax.numpy as jnp

from vit_jax import models_mixer
from vit_jax import models_resnet

Array = Any
@@ -298,3 +299,4 @@ def __call__(self, inputs, *, train):
return x


MlpMixer = models_mixer.MlpMixer
68 changes: 68 additions & 0 deletions vit_jax/models_mixer.py
@@ -0,0 +1,68 @@
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import einops
import flax.linen as nn
import jax.numpy as jnp


class MlpBlock(nn.Module):
mlp_dim: int

@nn.compact
def __call__(self, x):
y = nn.Dense(self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int

@nn.compact
def __call__(self, x):
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
y = jnp.swapaxes(y, 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)


class MlpMixer(nn.Module):
"""Mixer architecture."""
patches: Any
num_classes: int
num_blocks: int
hidden_dim: int
tokens_mlp_dim: int
channels_mlp_dim: int

@nn.compact
def __call__(self, inputs, *, train):
del train
x = nn.Conv(self.hidden_dim, self.patches.size,
strides=self.patches.size, name='stem')(inputs)
x = einops.rearrange(x, 'n h w c -> n (h w) c')
for _ in range(self.num_blocks):
x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
x = nn.LayerNorm(name='pre_head_layer_norm')(x)
x = jnp.mean(x, axis=1)
return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
name='head')(x)
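
A minimal usage sketch for the module above, using the Mixer-B/16 field values
from `vit_jax/configs/models.py` (similar to what `vit_jax/models_test.py`
below exercises):

```python
import jax
import jax.numpy as jnp
import ml_collections

from vit_jax import models_mixer

model = models_mixer.MlpMixer(
    patches=ml_collections.ConfigDict({'size': (16, 16)}),
    num_classes=1000, num_blocks=12, hidden_dim=768,
    tokens_mlp_dim=384, channels_mlp_dim=3072)

images = jnp.ones([2, 224, 224, 3], jnp.float32)
variables = model.init(jax.random.PRNGKey(0), images, train=False)
logits = model.apply(variables, images, train=False)  # shape (2, 1000)
```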
4 changes: 3 additions & 1 deletion vit_jax/models_test.py
@@ -31,6 +31,8 @@
'ViT-L_32': 306_535_400,
'R50+ViT-L_32': 328_994_856,
'ViT-H_14': 632_045_800,
'Mixer-B_16': 59_880_472,
'Mixer-L_16': 208_196_168,
}


@@ -40,7 +42,7 @@ class ModelsTest(parameterized.TestCase):
def test_can_instantiate(self, name, size):
rng = jax.random.PRNGKey(0)
config = config_lib.MODEL_CONFIGS[name]
model_cls = models.VisionTransformer
model_cls = models.VisionTransformer if 'ViT' in name else models.MlpMixer
model = model_cls(num_classes=1_000, **config)
inputs = jnp.ones([2, 224, 224, 3], jnp.float32)
variables = model.init(rng, inputs, train=False)
3 changes: 2 additions & 1 deletion vit_jax/train.py
@@ -96,7 +96,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
logging.info(ds_test)

# Build VisionTransformer architecture
model_cls = models.VisionTransformer
model_cls = {'ViT': models.VisionTransformer,
'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

def init_model():
1 change: 0 additions & 1 deletion vit_jax/train_test.py
@@ -45,7 +45,6 @@ def test_train_and_evaluate(self):

test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
opt_pmap = train.train_and_evaluate(config, workdir)
opt_pmap = train.train_and_evaluate(config, workdir)
self.assertTrue(os.path.exists(f'{workdir}/model.npz'))


