…ciency into dev
fsschneider committed Jan 12, 2024
2 parents 8453784 + 91a6169 commit 56ecd9c
Showing 32 changed files with 1,516 additions and 231 deletions.
24 changes: 12 additions & 12 deletions GETTING_STARTED.md
@@ -156,29 +156,29 @@ To use the Docker container as an interactive virtual environment, you can run a

### Using Singularity/Apptainer instead of Docker

Since many compute clusters don't allow the usage of Docker due to security concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide an Apptainer recipe (located at `docker/Singularity.def`) that can be used to build an image by running
```bash
singularity build --fakeroot <singularity_image_name>.sif Singularity.def
```
Note that this can take several minutes. Then, to start a shell session with GPU support (by using the `--nv` flag), we can run
```bash
singularity shell --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \
    --nv <singularity_image_name>.sif
```
Note the `--bind` flag which, similarly to Docker, allows binding specific paths on the host system and the container, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html).
Also note that we generated `Singularity.def` automatically from the `Dockerfile` using [spython](https://github.com/singularityhub/singularity-cli), as follows:
```bash
pip3 install spython
cd algorithmic-efficiency/docker
python scripts/singularity_converter.py -i Dockerfile -o Singularity.def
```
Users who wish to customize their images are invited to check and modify the `Singularity.def` recipe and the `singularity_converter.py` script.
## Download the Data
2 changes: 1 addition & 1 deletion algorithmic_efficiency/__init__.py
@@ -1,3 +1,3 @@
"""Algorithmic Efficiency."""

__version__ = '0.0.1'
__version__ = '0.1.0'
106 changes: 88 additions & 18 deletions algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py
@@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int,
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
dropout_rate: float = 0.0

@nn.compact
@@ -47,6 +48,11 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
d = x.shape[2]
x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
x = nn.gelu(x)

if self.use_glu:
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y

x = nn.Dropout(rate=self.dropout_rate)(x, train)
x = nn.Dense(d, **inits)(x)
return x
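For intuition, when `use_glu` is enabled the GELU-activated hidden features are gated element-wise by a second linear projection of themselves before the output projection. A minimal functional sketch in plain JAX; the weight names `w_in`, `w_gate`, and `w_out` are hypothetical, and biases and dropout are omitted:

```python
import jax.numpy as jnp
from jax.nn import gelu


def glu_mlp_sketch(x, w_in, w_gate, w_out):
  """Sketch of MlpBlock with use_glu=True (no biases, no dropout).

  Shapes: x [batch, len, d], w_in [d, mlp_dim],
  w_gate [mlp_dim, mlp_dim], w_out [mlp_dim, d].
  """
  h = gelu(x @ w_in)     # feed-forward expansion + GELU, as in the base block
  h = h * (h @ w_gate)   # GLU gating: multiply by a second projection of h
  return h @ w_out       # project back to the model width d
```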
@@ -56,26 +62,53 @@ class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
use_post_layer_norm: bool = False
dropout_rate: float = 0.0

@nn.compact
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
y = nn.SelfAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y

y = nn.LayerNorm(name='LayerNorm_2')(x)
y = MlpBlock(
mlp_dim=self.mlp_dim,
use_glu=self.use_glu,
dropout_rate=self.dropout_rate,
name='MlpBlock_3')(y, train)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y
else:
y = x
y = nn.SelfAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)

y = x
y = MlpBlock(
mlp_dim=self.mlp_dim,
use_glu=self.use_glu,
dropout_rate=self.dropout_rate,
name='MlpBlock_3')(y, train)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)

return x


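The two branches above differ only in where `LayerNorm` sits relative to the residual connections. A condensed sketch of the two orderings, with `attention`, `mlp`, and `layer_norm` as hypothetical stand-in callables and dropout omitted:

```python
def encoder_block_sketch(x, attention, mlp, layer_norm,
                         use_post_layer_norm=False):
  """Residual/LayerNorm ordering used by Encoder1DBlock (dropout omitted)."""
  if not use_post_layer_norm:
    # Pre-LN: normalize the input of each sub-block; the residual stream is untouched.
    x = x + attention(layer_norm(x))
    x = x + mlp(layer_norm(x))
  else:
    # Post-LN (original Transformer): normalize after each residual addition.
    x = layer_norm(x + attention(x))
    x = layer_norm(x + mlp(x))
  return x
```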
@@ -85,6 +118,8 @@ class Encoder(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
dropout_rate: float = 0.0
use_glu: bool = False
use_post_layer_norm: bool = False

@nn.compact
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
@@ -94,9 +129,36 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
name=f'encoderblock_{lyr}',
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
dropout_rate=self.dropout_rate)
x = block(x, train)
if not self.use_post_layer_norm:
return nn.LayerNorm(name='encoder_layernorm')(x)
else:
return x


class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12

@nn.compact
def __call__(self, x):
n, _, d = x.shape
probe = self.param('probe',
nn.initializers.xavier_uniform(), (1, 1, d),
x.dtype)
probe = jnp.tile(probe, [n, 1, 1])

x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform())(probe, x)

y = nn.LayerNorm()(x)
x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
return x[:, 0]


class ViT(nn.Module):
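As a usage sketch, `MAPHead` learns a single probe token that attends over all patch embeddings, replacing mean pooling. The snippet below assumes the class above is importable and uses illustrative ViT-S/16-like sizes (width 384, 6 heads, MLP dim 1536, 196 patch tokens):

```python
import jax
import jax.numpy as jnp

tokens = jnp.ones((8, 196, 384))           # [batch, num_patches, width], dummy data
head = MAPHead(num_heads=6, mlp_dim=1536)
variables = head.init(jax.random.PRNGKey(0), tokens)  # creates the probe and attention params
pooled = head.apply(variables, tokens)                 # [batch, width] pooled representation
```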
@@ -112,6 +174,9 @@ class ViT(nn.Module):
dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
use_post_layer_norm: bool = False
use_map: bool = False

def get_posemb(self,
seqshape: tuple,
@@ -145,11 +210,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
depth=self.depth,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
dropout_rate=dropout_rate,
name='Transformer')(
x, train=not train)

x = jnp.mean(x, axis=1)
if self.use_map:
x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
else:
x = jnp.mean(x, axis=1)

if self.rep_size:
rep_size = self.width if self.rep_size is True else self.rep_size
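To see how the new flags come together, here is a hypothetical standalone instantiation mirroring `init_model_fn` below; `decode_variant('S/16')` is assumed to expand to the ViT-S/16 `width`/`depth`/`mlp_dim`/`num_heads` keyword arguments, and the model is assumed to return class logits directly:

```python
import jax
import jax.numpy as jnp

model = ViT(
    num_classes=1000,
    use_glu=True,                # gated MLP blocks
    use_post_layer_norm=False,   # keep the default pre-LN ordering
    use_map=True,                # attention pooling instead of mean pooling
    **decode_variant('S/16'))

images = jnp.ones((2, 224, 224, 3))  # dummy NHWC batch
variables = model.init(jax.random.PRNGKey(0), images, train=False)
logits = model.apply(variables, images, train=False)  # per-class outputs
```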
algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -37,6 +37,9 @@ def init_model_fn(
self._model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
use_map=self.use_map,
**decode_variant('S/16'))
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
@@ -83,3 +86,24 @@ def _eval_model_on_split(self,
rng,
data_dir,
global_step)


class ImagenetVitGluWorkload(ImagenetVitWorkload):

@property
def use_glu(self) -> bool:
return True


class ImagenetVitPostLNWorkload(ImagenetVitWorkload):

@property
def use_post_layer_norm(self) -> bool:
return True


class ImagenetVitMapWorkload(ImagenetVitWorkload):

@property
def use_map(self) -> bool:
return True
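For context, a sketch of the assumed pattern: the base `ImagenetVitWorkload` is assumed to expose each flag as a property defaulting to `False`, so a variant only overrides the one it needs, and `init_model_fn` above forwards all three into the model. Class and method names below are illustrative only:

```python
class BaseWorkloadSketch:
  """Assumed shape of the base workload's model-variant flags."""

  @property
  def use_glu(self) -> bool:
    return False

  @property
  def use_post_layer_norm(self) -> bool:
    return False

  @property
  def use_map(self) -> bool:
    return False

  def model_kwargs(self):
    # Mirrors the diff above: the flags are simply forwarded to the model.
    return dict(use_glu=self.use_glu,
                use_post_layer_norm=self.use_post_layer_norm,
                use_map=self.use_map)


class GluVariantSketch(BaseWorkloadSketch):

  @property
  def use_glu(self) -> bool:
    return True  # everything else stays at the base defaults
```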
