diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index d9f2a7051..b13f9f00c 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -156,29 +156,29 @@ To use the Docker container as an interactive virtual environment, you can run a ### Using Singularity/Apptainer instead of Docker -Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the here provided Dockerfile. - -To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli): +Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide an Apptainer recipe (located at `docker/Singularity.def`) that can be used to build an image by running ```bash -pip3 install spython -cd algorithmic-efficiency/docker -spython recipe Dockerfile &> Singularity.def +singularity build --fakeroot .sif Singularity.def ``` -Now we can build the Apptainer image by running - +Note that this can take several minutes. Then, to start a shell session with GPU support (by using the `--nv` flag), we can run ```bash -singularity build --fakeroot .sif Singularity.def +singularity shell --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \ + --nv .sif ``` -To start a shell session with GPU support (by using the `--nv` flag), we can run +Note the `--bind` flag which, similarly to Docker, allows to bind specific paths on the host system and the container, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). + +Also note that we generated `Singularity.def` automatically from the `Dockerfile` using [spython](https://github.com/singularityhub/singularity-cli), as follows: ```bash -singularity shell --nv .sif +pip3 install spython +cd algorithmic-efficiency/docker +python scripts/singularity_converter.py -i Dockerfile -o Singularity.def ``` -Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). +Users that wish to customize their images are invited to check and modify the `Singularity.def` recipe and the `singularity_converter.py` script. ## Download the Data diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py index af0a6b8fc..a0e473e1d 100644 --- a/algorithmic_efficiency/__init__.py +++ b/algorithmic_efficiency/__init__.py @@ -1,3 +1,3 @@ """Algorithmic Efficiency.""" -__version__ = '0.0.1' +__version__ = '0.1.0' diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index ab5d1839e..32e748ec7 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. + use_glu: bool = False dropout_rate: float = 0.0 @nn.compact @@ -47,6 +48,11 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: d = x.shape[2] x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) + + if self.use_glu: + y = nn.Dense(self.mlp_dim, **inits)(x) + x = x * y + x = nn.Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -56,26 +62,53 @@ class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False dropout_rate: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + if not self.use_post_layer_norm: + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + else: + y = x + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = x + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) + return x @@ -85,6 +118,8 @@ class Encoder(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 dropout_rate: float = 0.0 + use_glu: bool = False + use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: @@ -94,9 +129,36 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: name=f'encoderblock_{lyr}', mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=self.dropout_rate) x = block(x, train) - return nn.LayerNorm(name='encoder_layernorm')(x) + if not self.use_post_layer_norm: + return nn.LayerNorm(name='encoder_layernorm')(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + + @nn.compact + def __call__(self, x): + n, _, d = x.shape + probe = self.param('probe', + nn.initializers.xavier_uniform(), (1, 1, d), + x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] class ViT(nn.Module): @@ -112,6 +174,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True + use_glu: bool = False + use_post_layer_norm: bool = False + use_map: bool = False def get_posemb(self, seqshape: tuple, @@ -145,11 +210,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate, name='Transformer')( x, train=not train) - x = jnp.mean(x, axis=1) + if self.use_map: + x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + else: + x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 3f3af0564..4b12247c2 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -37,6 +37,9 @@ def init_model_fn( self._model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) @@ -83,3 +86,24 @@ def _eval_model_on_split(self, rng, data_dir, global_step) + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetVitMapWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 55a8e370d..469716d59 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -1,8 +1,8 @@ """PyTorch implementation of refactored and simplified ViT. Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit. -https://github.com/lucidrains/vit-pytorch. +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. """ import math @@ -14,9 +14,12 @@ from algorithmic_efficiency import init_utils from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ + MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: + """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device y, x = torch.meshgrid(torch.arange(h, device=device), @@ -39,18 +42,26 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + use_glu: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu self.dropout_rate = dropout_rate - self.net = nn.Sequential( - nn.Linear(self.width, self.mlp_dim), - nn.GELU(), - nn.Dropout(self.dropout_rate), - nn.Linear(self.mlp_dim, self.width)) + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + self.dropout = nn.Dropout(self.dropout_rate) + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.reset_parameters() def reset_parameters(self) -> None: @@ -61,7 +72,16 @@ def reset_parameters(self) -> None: module.bias.data.normal_(std=1e-6) def forward(self, x: spec.Tensor) -> spec.Tensor: - return self.net(x) + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = self.dropout(x) + x = self.linear2(x) + return x class SelfAttention(nn.Module): @@ -129,29 +149,50 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(self.width, self.mlp_dim, dropout_rate) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: - y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) - x = x + y + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = self.dropout(y) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = self.dropout(y) + x = x + y + else: + y = x + y = self.self_attention1(y) + y = self.dropout(y) + x = x + y + x = self.layer_norm0(x) + + y = x + y = self.mlp3(y) + y = self.dropout(y) + x = x + y + x = self.layer_norm2(x) return x @@ -163,6 +204,8 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() @@ -170,18 +213,61 @@ def __init__(self, self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, dropout_rate) - for _ in range(depth) + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) ]) - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + + if not self.use_post_layer_norm: + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None def forward(self, x: spec.Tensor) -> spec.Tensor: # Input Encoder. for block in self.net: x = block(x) - return self.encoder_norm(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=False) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x: spec.Tensor) -> spec.Tensor: + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x)[0] + y = self.layer_norm(x) + x = x + self.mlp(y) + return x[:, 0] class ViT(nn.Module): @@ -202,6 +288,9 @@ def __init__( rep_size: Union[int, bool] = True, dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() if dropout_rate is None: @@ -215,6 +304,9 @@ def __init__( self.num_heads = num_heads self.rep_size = rep_size self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map self.dtype = dtype if self.rep_size: @@ -234,10 +326,18 @@ def __init__( width=self.width, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + self.reset_parameters() def reset_parameters(self) -> None: @@ -270,7 +370,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.dropout(x) x = self.encoder(x) - x = torch.mean(x, dim=1) + + if self.use_map: + x = self.map(x) + else: + x = torch.mean(x, dim=1) if self.rep_size: x = torch.tanh(self.pre_logits(x)) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 08a62ede6..645b795ca 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -34,6 +34,9 @@ def init_model_fn( model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -77,3 +80,24 @@ def model_fn( logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetVitMapWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/workload.py index 61d3acfd3..ed0118ca0 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/workload.py @@ -60,6 +60,21 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 1 - 0.3481 # 0.6519 + @property + def use_post_layer_norm(self) -> bool: + """Whether to use layer normalization after the residual branch.""" + return False + + @property + def use_map(self) -> bool: + """Whether to use multihead attention pooling.""" + return False + + @property + def use_glu(self) -> bool: + """Whether to use GLU in the MLPBlock.""" + return False + @property def eval_batch_size(self) -> int: return 2048 diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py index 358415587..0e66d2ab8 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py @@ -15,7 +15,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout): +def _make_mlp(hidden_dims, dropout, activation_fn): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -24,7 +24,7 @@ def make_fn(inputs): for dim in hidden_dims: x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) - x = nn.relu(x) + x = activation_fn(x) x = dropout(x) return x @@ -42,6 +42,7 @@ class GNN(nn.Module): # If None, defaults to 0.1. dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 + activation_fn_name: str = 'relu' @nn.compact def __call__(self, graph, train): @@ -59,11 +60,24 @@ def __call__(self, graph, train): embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) graph = embedder(graph) + if self.activation_fn_name == 'relu': + activation_fn = nn.relu + elif self.activation_fn_name == 'gelu': + activation_fn = nn.gelu + elif self.activation_fn_name == 'silu': + activation_fn = nn.silu + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout)) + update_edge_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_node_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_global_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) graph = net(graph) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 009aab91a..9fc24552d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -25,7 +25,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate rng, params_rng, dropout_rng = jax.random.split(rng, 3) - self._model = models.GNN(self._num_outputs, dropout_rate=dropout_rate) + self._model = models.GNN( + self._num_outputs, + dropout_rate=dropout_rate, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), @@ -115,3 +121,34 @@ def _normalize_eval_metrics( del num_examples total_metrics = total_metrics.reduce() return {k: float(v) for k, v in total_metrics.compute().items()} + + +class OgbgGeluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'gelu' + + +class OgbgSiluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'silu' + + +class OgbgModelSizeWorkload(OgbgWorkload): + + @property + def hidden_dims(self) -> Tuple[int]: + return (256, 256) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 3 diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 1b392753b..d93013b87 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -1,5 +1,6 @@ # Ported to PyTorch from # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from functools import partial from typing import Callable, Optional, Tuple import jax.tree_util as tree @@ -10,14 +11,16 @@ from algorithmic_efficiency import init_utils -def _make_mlp(in_dim, hidden_dims, dropout_rate): +def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() - for dim in hidden_dims: - layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) - layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('relu', nn.ReLU()) - layers.add_module('dropout', nn.Dropout(dropout_rate)) + for i, dim in enumerate(hidden_dims): + layers.add_module(f'dense_{i}', + nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) + layers.add_module(f'activation_fn_{i}', activation_fn()) + layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + in_dim = dim return layers @@ -27,14 +30,18 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ - latent_dim: int = 256 - hidden_dims: Tuple[int] = (256,) - num_message_passing_steps: int = 5 def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1) -> None: + dropout_rate: Optional[float] = 0.1, + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5) -> None: super().__init__() + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs if dropout_rate is None: dropout_rate = 0.1 @@ -42,23 +49,44 @@ def __init__(self, self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) + if activation_fn_name == 'relu': + activation_fn = nn.ReLU + elif activation_fn_name == 'gelu': + activation_fn = partial(nn.GELU, approximate='tanh') + elif activation_fn_name == 'silu': + activation_fn = nn.SiLU + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + graph_network_layers = [] for st in range(self.num_message_passing_steps): - # Constants in in_dims are based on the requirements of the GraphNetwork. + # Constants in in_dims are based on forward call of GraphNetwork: + # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: - in_dim = self.latent_dim * 3 + self.num_outputs - last_in_dim = self.latent_dim * 2 + self.num_outputs + in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[ + -1] * 2 + self.num_outputs + last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: - in_dim = self.hidden_dims[-1] * 4 + in_dim_edge_fn = self.hidden_dims[-1] * 4 + in_dim_node_fn = self.hidden_dims[-1] * 4 last_in_dim = self.hidden_dims[-1] * 3 graph_network_layers.append( GraphNetwork( - update_edge_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate), - update_node_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate), + update_edge_fn=_make_mlp(in_dim_edge_fn, + self.hidden_dims, + dropout_rate, + activation_fn), + update_node_fn=_make_mlp(in_dim_node_fn, + self.hidden_dims, + dropout_rate, + activation_fn), update_global_fn=_make_mlp(last_in_dim, self.hidden_dims, - dropout_rate))) + dropout_rate, + activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) self.decoder = nn.Linear( @@ -147,7 +175,6 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape [num_edges, global_feat]. global_edge_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) - if self.update_edge_fn: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index a1fbf2e8a..beb518e0f 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -144,7 +144,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate torch.random.manual_seed(rng[0]) - model = GNN(num_outputs=self._num_outputs, dropout_rate=dropout_rate) + model = GNN( + num_outputs=self._num_outputs, + dropout_rate=dropout_rate, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -235,3 +241,34 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples return {k: float(v) for k, v in total_metrics.compute().items()} + + +class OgbgGeluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'gelu' + + +class OgbgSiluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'silu' + + +class OgbgModelSizeWorkload(OgbgWorkload): + + @property + def hidden_dims(self) -> Tuple[int]: + return (256, 256) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 3 diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index 7ca6ebc1e..a32f385cb 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -3,7 +3,7 @@ import abc import itertools import math -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import jax @@ -22,6 +22,23 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'mean_average_precision' + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'relu' + + @property + def hidden_dims(self) -> Tuple[int]: + return (256,) + + @property + def latent_dim(self) -> int: + return 256 + + @property + def num_message_passing_steps(self) -> int: + return 5 + def has_reached_validation_target(self, eval_result: float) -> bool: return eval_result[ 'validation/mean_average_precision'] > self.validation_target_value diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6d0b08cef..f4efa5b19 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -68,6 +68,18 @@ 'workload_path': 'imagenet_vit/imagenet', 'workload_class_name': 'ImagenetVitWorkload', }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTMapLNWorkload', + }, 'librispeech_conformer': { 'workload_path': 'librispeech_conformer/librispeech', 'workload_class_name': 'LibriSpeechConformerWorkload', @@ -96,6 +108,16 @@ 'ogbg': { 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' }, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload' + }, 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, 'wmt_post_ln': { 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' diff --git a/datasets/README.md b/datasets/README.md index 614344978..c68a5cc6b 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -1,74 +1,195 @@ -# Dataset Setup -TL;DR: -Use `dataset_setup.py` to download datasets. -Usage: +# MLCommons™ AlgoPerf: Dataset Setup + +## Table of Contents + +- [General Setup](#general-setup) + - [Set Data Directory (Docker Container)](#set-data-directory-docker-container) + - [Set Data Directory (on Host)](#set-data-directory-on-host) + - [Start tmux session (Recommended)](#start-tmux-session-recommended) + - [Clean up](#clean-up) +- [Individual Dataset Instructions](#individual-dataset-instructions) + - [OGBG](#ogbg) + - [WMT](#wmt) + - [FastMRI](#fastmri) + - [ImageNet](#imagenet) + - [Criteo1TB](#criteo1tb) + - [LibriSpeech](#librispeech) + - [Training SPM Tokenizer](#training-spm-tokenizer) + - [Preprocessing Script](#preprocessing-script) + +## General Setup + +This document provides instructions on downloading and preparing all datasets utilized in the AlgoPerf benchmark. You can prepare the individual datasets one-by-one as needed. If your setup, such as your cloud or cluster environment, already contains these datasets, you may skip the dataset setup for this particular data (and directly specify the dataset location in the `submission_runner.py`). Just verify that you are using the same dataset version (and possible preprocessing). + +*TL;DR to download and prepare a dataset, run `dataset_setup.py`:* + ```bash python3 datasets/dataset_setup.py \ --data_dir=~/data \ -- - -- + -- ``` -The complete benchmark uses 6 datasets: -- OGBG -- WMT -- FastMRI -- Imagenet -- Criteo 1TB -- Librispeech +The complete benchmark uses 6 different datasets: -Some dataset setups will require you to sign a third party agreement with the dataset owners in order to get the donwload URLs. +- [OGBG](#ogbg) +- [WMT](#wmt) +- [FastMRI](#fastmri) +- [Imagenet](#imagenet) +- [Criteo 1TB](#criteo1tb) +- [Librispeech](#librispeech) -# Per dataset instructions -## Environment +Some dataset setups will require you to sign a third-party agreement with the dataset owners in order to get the download URLs. -### Set data directory (Docker container) -If you are running the `dataset_setup.py` script from a Docker container, please +### Set Data Directory (Docker Container) + +If you are running the `dataset_setup.py` script from a Docker container, please make sure the data directory is mounted to a directory on your host with --v flag. If you are following instructions from the README you will have used +`-v` flag. If you are following instructions from the [Getting Started guide](/GETTING_STARTED.md) you will have used the `-v $HOME/data:/data` flag in the `docker run` command. This will mount -the `$HOME/data` directory to the `/data` directory in the container. -In this case set --data_dir to `/data`. +the `$HOME/data` directory to the `/data` directory in the container. +In this case set, `--data_dir` to `/data`. + ```bash DATA_DIR='/data' ``` -### Set data directory (on host) -Alternatively, if you are running the data download script directly on your host, feel free -to choose whatever directory you find suitable, further submission instructions -assume the data is stored in `~/data`. + +### Set Data Directory (on Host) + +Alternatively, if you are running the data download script directly on your host, feel free to choose whatever directory you find suitable, further submission instructions assume the data is stored in `~/data`. + ```bash DATA_DIR='~/data' ``` + #### Start tmux session (Recommended) -If running the dataset_setup.py on directly on host it is recommended to run -the dataset_setup.py script in a tmux session because some of the data downloads may -take several hours. To avoid your setup being interrupted start a tmux session: + +If running the `dataset_setup.py` on directly on host it is recommended to run +the `dataset_setup.py` script in a `tmux` session because some of the data downloads may take several hours. To avoid your setup being interrupted start a `tmux` session: + ```bash tmux new -s data_setup ``` +### Clean up -## Datasets +In order to avoid potential accidental deletion, this script does NOT +delete any intermediate temporary files (such as zip archives) without a user +confirmation. Deleting temp files is particularly important for Criteo 1TB, as +there can be multiple copies of the dataset on disk during preprocessing if +files are not cleaned up. + +By default, a user will be prompted before any files are deleted. If you do not want any temp files to be deleted, you can pass `--interactive_deletion=false` and then all files will be downloaded to the provided `--temp_dir`, and the user can manually delete these after downloading has finished. + +## Individual Dataset Instructions + +### OGBG -### OGBG From `algorithmic-efficiency` run: + ```bash python3 datasets/dataset_setup.py \ ---data_dir $DATA_DIR/ogbg \ +--data_dir $DATA_DIR \ --ogbg ``` -### WMT +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── ogbg +│ └── ogbg_molpcba +│ └── 0.1.3 +│ ├── dataset_info.json +│ ├── features.json +│ ├── metadata.json +│ ├── ogbg_molpcba-test.tfrecord-00000-of-00001 +│ ├── ogbg_molpcba-train.tfrecord-00000-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00001-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00002-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00003-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00004-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00005-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00006-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00007-of-00008 +│ └── ogbg_molpcba-validation.tfrecord-00000-of-00001 +``` + +In total, it should contain 13 files (via `find -type f | wc -l`) for a total of 777 MB (via `du -sch ogbg/`). +
+ +### WMT + From `algorithmic-efficiency` run: + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── wmt + ├── wmt14_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt14_translate-test.tfrecord-00000-of-00001 + │ ├── wmt14_translate-train.tfrecord-00000-of-00016 + │ ├── wmt14_translate-train.tfrecord-00001-of-00016 + │ ├── wmt14_translate-train.tfrecord-00002-of-00016 + │ ├── wmt14_translate-train.tfrecord-00003-of-00016 + │ ├── wmt14_translate-train.tfrecord-00004-of-00016 + │ ├── wmt14_translate-train.tfrecord-00005-of-00016 + │ ├── wmt14_translate-train.tfrecord-00006-of-00016 + │ ├── wmt14_translate-train.tfrecord-00007-of-00016 + │ ├── wmt14_translate-train.tfrecord-00008-of-00016 + │ ├── wmt14_translate-train.tfrecord-00009-of-00016 + │ ├── wmt14_translate-train.tfrecord-00010-of-00016 + │ ├── wmt14_translate-train.tfrecord-00011-of-00016 + │ ├── wmt14_translate-train.tfrecord-00012-of-00016 + │ ├── wmt14_translate-train.tfrecord-00013-of-00016 + │ ├── wmt14_translate-train.tfrecord-00014-of-00016 + │ ├── wmt14_translate-train.tfrecord-00015-of-00016 + │ └── wmt14_translate-validation.tfrecord-00000-of-00001 + ├── wmt17_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt17_translate-test.tfrecord-00000-of-00001 + │ ├── wmt17_translate-train.tfrecord-00000-of-00016 + │ ├── wmt17_translate-train.tfrecord-00001-of-00016 + │ ├── wmt17_translate-train.tfrecord-00002-of-00016 + │ ├── wmt17_translate-train.tfrecord-00003-of-00016 + │ ├── wmt17_translate-train.tfrecord-00004-of-00016 + │ ├── wmt17_translate-train.tfrecord-00005-of-00016 + │ ├── wmt17_translate-train.tfrecord-00006-of-00016 + │ ├── wmt17_translate-train.tfrecord-00007-of-00016 + │ ├── wmt17_translate-train.tfrecord-00008-of-00016 + │ ├── wmt17_translate-train.tfrecord-00009-of-00016 + │ ├── wmt17_translate-train.tfrecord-00010-of-00016 + │ ├── wmt17_translate-train.tfrecord-00011-of-00016 + │ ├── wmt17_translate-train.tfrecord-00012-of-00016 + │ ├── wmt17_translate-train.tfrecord-00013-of-00016 + │ ├── wmt17_translate-train.tfrecord-00014-of-00016 + │ ├── wmt17_translate-train.tfrecord-00015-of-00016 + │ └── wmt17_translate-validation.tfrecord-00000-of-00001 + └── wmt_sentencepiece_model +``` + +In total, it should contain 43 files (via `find -type f | wc -l`) for a total of 3.3 GB (via `du -sch wmt/`). +
+ +### FastMRI -## FastMRI -Fill out form on https://fastmri.med.nyu.edu/. After filling out the form +Fill out form on . After filling out the form you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". @@ -81,18 +202,37 @@ python3 datasets/dataset_setup.py \ --fastmri_knee_singlecoil_test_url '' ``` -## ImageNet -Register on https://image-net.org/ and follow directions to obtain the -URLS for the ILSVRC2012 train and validation images. +
+The final directory structure should look like this: -Imagenet dataset processsing is resource intensive. To avoid potential -ResourcExhausted errors increase the maximum number of open file descriptors: ```bash -ulimit -n 8192 +$DATA_DIR +├── fastmri +│ ├── knee_singlecoil_test +│ │ ├── file1000022.h5 +│ │ ├── [...] +│ │ └── file1002571.h5 +│ ├── knee_singlecoil_train +│ │ ├── file1000001.h5 +│ │ ├── [...] +│ │ └── file1002569.h5 +│ └── knee_singlecoil_val +│ ├── file1000000.h5 +│ ├── [...] +│ └── file1002570.h5 ``` -The imagenet data pipeline differs between the pytorch and jax workloads. -Therefore, you will have to specify the framework (pytorch or jax) through theframework flag. +In total, it should contain 1280 files (via `find -type f | wc -l`) for a total of 112 GB (via `du -sch fastmri/`). +
+ +### ImageNet + +Register on and follow directions to obtain the +URLS for the ILSVRC2012 train and validation images. +The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. + +The ImageNet data pipeline differs between the PyTorch and JAX workloads. +Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. ```bash python3 datasets/dataset_setup.py \ @@ -102,15 +242,108 @@ python3 datasets/dataset_setup.py \ --imagenet_train_url \ --imagenet_val_url \ --framework jax +``` +Imagenet dataset processsing is resource intensive. To avoid potential +ResourcExhausted errors increase the maximum number of open file descriptors: + +```bash +ulimit -n 8192 ``` -Note that some functions use subprocess.Popen(..., shell=True), which can be -dangerous if the user injects code into the --data_dir or --temp_dir flags. We -do some basic sanitization in main(), but submitters should not let untrusted +Note that some functions use `subprocess.Popen(..., shell=True)`, which can be +dangerous if the user injects code into the `--data_dir` or `--temp_dir` flags. We +do some basic sanitization in `main()`, but submitters should not let untrusted users run this script on their systems. -## Criteo1tb +
+The final directory structure should look like this for ImageNet2012 (PyTorch): + +```bash +$DATA_DIR +├── imagenet +│ ├── train +│ ├── n01440764 +│ ├── n01440764_10026.JPEG +│ ├── n01440764_10027.JPEG +│ ├── n01440764_10029.JPEG +│ ├── [...] +│ ├── [...] +│ └── val +│ ├── n01440764 +│ ├── ILSVRC2012_val_00000293.JPEG +│ ├── ILSVRC2012_val_00002138.JPEG +│ ├── [...] +│ ├── [...] +``` + +In total, it should contain 1,281,167 `train` files and 50,000 `val` (via `find -type f | wc -l`) for a total of 177 GB and 7.8 GB, respectively (via `du -sch train/` and `du -sch val/`). +
+ +
+The final directory structure should look like this for ImageNet2012 (JAX): + +```bash +$DATA_DIR +├──imagenet +│ ├── jax +│ │ ├── downloads +│ │ │ ├── extracted +│ │ │ └── manual_ +│ │ ├── imagenet2012 +│ │ │ └── 5.1.0 +│ │ │ ├── dataset_info.json +│ │ │ ├── features.json +│ │ │ ├── imagenet2012-train.tfrecord-00000-of-01024 +│ │ │ ├── imagenet2012-train.tfrecord-00001-of-01024 +│ │ │ ├── [...] +│ │ └── imagenet_v2 +│ │ └── matched-frequency +│ │ └── 3.0.0 +│ │ ├── dataset_info.json +│ │ ├── features.json +│ │ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ │ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ │ ├── [...] +``` + +In total, it should contain 1,111 files (via `find -type f | wc -l`) for a total of 145 GB (via `du -sch imagenet/jax`). +
+ +
+The final directory structure should look like this for ImageNet v2 (separate): + +```bash +$DATA_DIR +├── imagenet_v2 +│ └── matched-frequency +│ └── 3.0.0 +│ ├── dataset_info.json +│ ├── features.json +│ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ ├── imagenet_v2-test.tfrecord-00002-of-00016 +│ ├── imagenet_v2-test.tfrecord-00003-of-00016 +│ ├── imagenet_v2-test.tfrecord-00004-of-00016 +│ ├── imagenet_v2-test.tfrecord-00005-of-00016 +│ ├── imagenet_v2-test.tfrecord-00006-of-00016 +│ ├── imagenet_v2-test.tfrecord-00007-of-00016 +│ ├── imagenet_v2-test.tfrecord-00008-of-00016 +│ ├── imagenet_v2-test.tfrecord-00009-of-00016 +│ ├── imagenet_v2-test.tfrecord-00010-of-00016 +│ ├── imagenet_v2-test.tfrecord-00011-of-00016 +│ ├── imagenet_v2-test.tfrecord-00012-of-00016 +│ ├── imagenet_v2-test.tfrecord-00013-of-00016 +│ ├── imagenet_v2-test.tfrecord-00014-of-00016 +│ ├── imagenet_v2-test.tfrecord-00015-of-00016 +│ └── label.labels.txt +``` + +In total, it should contain 20 files (via `find -type f | wc -l`) for a total of 1.2 GB (via `du -sch imagenet_v2/`). +
+ +### Criteo1TB + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ @@ -118,19 +351,28 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` -### Clean up -In order to avoid potential accidental deletion, this script does NOT -delete any intermediate temporary files (such as zip archives) without a user -confirmation. Deleting temp files is particularly important for Criteo 1TB, as -there can be multiple copies of the dataset on disk during preprocessing if -files are not cleaned up. If you do not want any temp files to be deleted, you -can pass --interactive_deletion=false and then all files will be downloaded to -the provided --temp_dir, and the user can manually delete these after -downloading has finished. +Note, that this requries the [`pigz` library](https://zlib.net/pigz/) to be installed. +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── criteo1tb +│ ├── day_0_00 +│ ├── day_0_01 +│ ├── day_0_02 +│ ├── day_0_03 +│ ├── [...] +``` + +In total, it should contain 885 files (via `find -type f | wc -l`) for a total of 1.1 TB (via `du -sch criteo1tb/`). +
+ +### LibriSpeech -## Librispeech To download, train a tokenizer and preprocess the librispeech dataset: + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ @@ -138,26 +380,72 @@ python3 datasets/dataset_setup.py \ --librispeech ``` -### Notes on librispeech preprocessing +Note, that this requries the [`ffmpeg` toolbox](https://ffmpeg.org/) to be installed. + +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├──librispeech +│ ├── dev-clean.csv +│ ├── dev-other.csv +│ ├── spm_model.vocab +│ ├── test-clean.csv +│ ├── train-clean-100.csv +│ ├── train-clean-360.csv +│ ├── train-clean-500.csv +│ ├── dev-clean +│ │ ├── 1272-128104-0000_audio.npy +│ │ ├── 1272-128104-0000_targets.npy +│ │ ├── 1272-128104-0001_audio.npy +│ │ ├── 1272-128104-0001_targets.npy +│ │ ├── [...] +│ ├── dev-other +│ │ ├── 116-288045-0000_audio.npy +│ │ ├── 116-288045-0000_targets.npy +│ │ ├── [...] +│ ├── test-clean +│ │ ├── 1089-134686-0000_audio.npy +│ │ ├── 1089-134686-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-100 +│ │ ├── 103-1240-0000_audio.npy +│ │ ├── 103-1240-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-360 +│ │ ├── 100-121669-0000_audio.npy +│ │ ├── 100-121669-0000_targets.npy +│ │ ├── [...] +│ ├── train-other-500 +│ │ ├── 1006-135212-0000_audio.npy +│ │ ├── 1006-135212-0000_targets.npy +│ │ ├── [...] +``` + +In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 388 GB (via `du -sch librispeech/`). +
+ #### Training SPM Tokenizer - A simple sentence piece tokenizer is trained over librispeech training - data. This tokenizer is then used in later preprocessing step to tokenize transcripts. + +During the above commands, a simple sentence piece tokenizer is trained over librispeech training data. +This tokenizer is then used in later preprocessing step to tokenize transcripts. This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`: + ```bash python3 librispeech_tokenizer.py --train --data_dir=$DATA_DIR/librispeech ``` The trained tokenizer can be loaded back to do sanity check by tokenizing + de-tokenizing a constant string: + ```bash librispeech_tokenizer.py --data_dir=$DATA_DIR/librispeech ``` #### Preprocessing Script + The preprocessing script will generate `.npy` files for audio data, `features.csv` which has paths to saved audio `.npy`, and `trans.csv` which has paths to `features.csv` and transcription data. ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` - - - diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f9ee2f138..b22d352ce 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -26,17 +26,19 @@ Criteo 1TB download size: ~350GB Criteo 1TB final disk size: ~1TB -FastMRI download size: -FastMRI final disk size: -LibriSpeech download size: -LibriSpeech final disk size: -OGBG download size: -OGBG final disk size: -WMT download size: (1.58 GiB + ) = -WMT final disk size: +FastMRI download size: ~90GB +FastMRI final disk size: ~110GB +ImageNet download size: ~150GB +ImageNet final disk size: ~150GB +LibriSpeech download size: ~60GB +LibriSpeech final disk size: ~350GB +OGBG download size: ~37MB +OGBG final disk size: ~800MB +WMT download size: ~3GB +WMT final disk size: ~3GB _______________________ -Total download size: -Total disk size: +Total download size: ~650GB +Total disk size: ~1.1TB Some datasets require signing a form before downloading: @@ -49,8 +51,8 @@ Register on https://image-net.org/ and run this script with the links to the ILSVRC2012 train and validation images. -Note for tfds ImageNet, you may have to increase the max number of files allowed -open at once using `ulimit -n 8192`. +Note for tfds ImageNet, you may have to increase the max number of files +allowed open at once using `ulimit -n 8192`. Example command: @@ -77,7 +79,6 @@ import functools import os -import resource import shutil import subprocess import tarfile @@ -324,7 +325,7 @@ def download_criteo1tb(data_dir, unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') unzipped_paths.append(unzipped_path) split_path = os.path.join(criteo_dir, f'day_{day}_') - split_cmd = ('split -a 3 -d -l 5000000 --additional-suffix=.csv ' + split_cmd = ('split -a 2 -d -l 5000000 ' f'"{unzipped_path}" "{split_path}"') logging.info(f'Running Criteo 1TB split command:\n{split_cmd}') batch_processes.append(subprocess.Popen(split_cmd, shell=True)) @@ -334,6 +335,7 @@ def download_criteo1tb(data_dir, def download_cifar(data_dir, framework): + data_dir = os.path.join(data_dir, 'cifar10') if framework == 'jax': tfds.builder('cifar10:3.0.2', data_dir=data_dir).download_and_prepare() elif framework == 'pytorch': @@ -397,34 +399,39 @@ def extract(source, dest, mode='r:xz'): tar.close() -def setup_fastmri(data_dir, src_data_dir): - - train_tar_file_path = os.path.join(src_data_dir, FASTMRI_TRAIN_TAR_FILENAME) - val_tar_file_path = os.path.join(src_data_dir, FASTMRI_VAL_TAR_FILENAME) - test_tar_file_path = os.path.join(src_data_dir, FASTMRI_TEST_TAR_FILENAME) - - # Make train, val and test subdirectories - fastmri_data_dir = os.path.join(data_dir, 'fastmri') - train_data_dir = os.path.join(fastmri_data_dir, 'train') - os.makedirs(train_data_dir, exist_ok=True) - val_data_dir = os.path.join(fastmri_data_dir, 'val') - os.makedirs(val_data_dir, exist_ok=True) - test_data_dir = os.path.join(fastmri_data_dir, 'test') - os.makedirs(test_data_dir, exist_ok=True) +def setup_fastmri(data_dir): + train_tar_file_path = os.path.join(data_dir, FASTMRI_TRAIN_TAR_FILENAME) + val_tar_file_path = os.path.join(data_dir, FASTMRI_VAL_TAR_FILENAME) + test_tar_file_path = os.path.join(data_dir, FASTMRI_TEST_TAR_FILENAME) # Unzip tar file into subdirectories - logging.info('Unzipping {} to {}'.format(train_tar_file_path, train_data_dir)) - extract(train_tar_file_path, train_data_dir) - logging.info('Unzipping {} to {}'.format(val_tar_file_path, val_data_dir)) - extract(val_tar_file_path, val_data_dir) - logging.info('Unzipping {} to {}'.format(test_tar_file_path, test_data_dir)) - extract(test_tar_file_path, test_data_dir) - logging.info('Set up fastMRI dataset complete') + logging.info('Unzipping {} to {}'.format(train_tar_file_path, data_dir)) + extract(train_tar_file_path, data_dir) + logging.info('Unzipping {} to {}'.format(val_tar_file_path, data_dir)) + extract(val_tar_file_path, data_dir) + logging.info('Unzipping {} to {}'.format(test_tar_file_path, data_dir)) + extract(test_tar_file_path, data_dir) logging.info('Extraction completed!') + # Rename folders to match what the workload expects + os.rename( + os.path.join(data_dir, "singlecoil_train"), + os.path.join(data_dir, "knee_singlecoil_train"), + ) + os.rename( + os.path.join(data_dir, "singlecoil_val"), + os.path.join(data_dir, "knee_singlecoil_val"), + ) + os.rename( + os.path.join(data_dir, "singlecoil_test"), + os.path.join(data_dir, "knee_singlecoil_test"), + ) + logging.info("Set up fastMRI dataset complete") + def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): """Downloads and returns the download dir.""" + data_dir = os.path.join(data_dir, 'imagenet') imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) @@ -456,6 +463,7 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): def setup_imagenet(data_dir, framework=None): + data_dir = os.path.join(data_dir, 'imagenet') if framework == 'jax': setup_imagenet_jax(data_dir) @@ -580,19 +588,21 @@ def download_imagenet_v2(data_dir): data_dir=data_dir).download_and_prepare() -def download_librispeech(dataset_dir, tmp_dir): +def download_librispeech(data_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech') extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') - final_data_dir = os.path.join(dataset_dir, 'librispeech') + final_data_dir = os.path.join(data_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) _maybe_mkdir(final_data_dir) for split in ['dev', 'test']: for version in ['clean', 'other']: + if split == 'test' and version == 'other': + continue wget_cmd = ( f'wget --directory-prefix={tmp_librispeech_dir} ' f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz') @@ -617,10 +627,13 @@ def download_librispeech(dataset_dir, tmp_dir): f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True).communicate() - tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab') + tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab') if not os.path.exists(tokenizer_vocab_path): - librispeech_tokenizer.run(train=True, data_dir=extracted_data_dir) + librispeech_tokenizer.run( + train=True, + input_dir=extracted_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path) librispeech_preprocess.run( input_dir=extracted_data_dir, @@ -629,6 +642,7 @@ def download_librispeech(dataset_dir, tmp_dir): def download_mnist(data_dir): + data_dir = os.path.join(data_dir, 'MNIST') # Capitalization to match PyTorch tfds.builder('mnist', data_dir=data_dir).download_and_prepare() @@ -714,9 +728,8 @@ def main(_): raise ValueError( 'Please specify either jax or pytorch framework through framework ' 'flag.') - imagenet_data_dir = os.path.join(data_dir, 'imagenet') - download_imagenet(imagenet_data_dir, imagenet_train_url, imagenet_val_url) - setup_imagenet(imagenet_data_dir, framework=FLAGS.framework) + download_imagenet(data_dir, imagenet_train_url, imagenet_val_url) + setup_imagenet(data_dir, framework=FLAGS.framework) if FLAGS.all or FLAGS.librispeech: logging.info('Downloading Librispeech...') diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index acdaa8e98..a8c5cae1d 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -31,8 +31,7 @@ 'train-clean-100': 28539, 'train-clean-360': 104014, 'train-other-500': 148688, - 'test-clean': 2620, - 'test-other': 2939, + 'test-clean': 2620, # 'test-other': 2939, 'dev-clean': 2703, 'dev-other': 2864, } @@ -153,8 +152,7 @@ def run(input_dir, output_dir, tokenizer_vocab_path): 'train-other-500', 'dev-clean', 'dev-other', - 'test-clean', - 'test-other', + 'test-clean', # 'test-other', ] for subset in subset_list: logging.info('Processing split = %s...', subset) diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py index e701d59d4..2f559752a 100644 --- a/datasets/librispeech_tokenizer.py +++ b/datasets/librispeech_tokenizer.py @@ -108,17 +108,16 @@ def load_tokenizer(model_filepath): return sp_tokenizer -def run(train, data_dir): - logging.info('Data dir: %s', data_dir) - vocab_path = os.path.join(data_dir, 'spm_model.vocab') - logging.info('vocab_path = ', vocab_path) +def run(train, input_dir, tokenizer_vocab_path): + logging.info('Data dir: %s', input_dir) + logging.info('vocab_path = %s', tokenizer_vocab_path) if train: logging.info('Training...') splits = ['train-clean-100'] - train_tokenizer(data_dir, splits, model_path=vocab_path) + train_tokenizer(input_dir, splits, model_path=tokenizer_vocab_path) else: - tokenizer = load_tokenizer(vocab_path) + tokenizer = load_tokenizer(tokenizer_vocab_path) test_input = 'OPEN SOURCE ROCKS' tokens = tokenizer.tokenize(test_input) detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8') diff --git a/docker/Singularity.def b/docker/Singularity.def new file mode 100644 index 000000000..5f5c31d60 --- /dev/null +++ b/docker/Singularity.def @@ -0,0 +1,76 @@ +Bootstrap: docker +From: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +Stage: spython-base + +%post +# Dockerfile for AlgoPerf environment. +# To build Docker image with only Jax GPU installed: +# docker build -t --build-arg framework=jax +# To build Docker image with Pytorch GPU installed: +# docker build -t --build-arg framework=pytorch + +# To build Docker image + +# Installing machine packages +echo "Setting up machine" +apt-get update +apt-get install -y curl tar +DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +apt-get install libtcmalloc-minimal4 +apt-get install unzip +apt-get install pigz +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 + +# Install GCP tools +echo "Setting up gsutil" +curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-413.0.0-linux-x86_64.tar.gz +tar -xf google-cloud-cli-413.0.0-linux-x86_64.tar.gz +yes | ./google-cloud-sdk/install.sh + +# Directory setup for input and output +echo "Setting up directories for data and experiment_runs" +mkdir -p data/ +mkdir -p experiment_runs/ + +# Install Algorithmic efficiency repo +echo "Setting up algorithmic_efficiency repo" +branch="main" +framework="both" +git_url=https://github.com/mlcommons/algorithmic-efficiency.git +git clone $git_url && cd /algorithmic-efficiency +cd /algorithmic-efficiency && git checkout $branch + +cd /algorithmic-efficiency && pip install -e '.[full]' + +if [ "$framework" = "jax" ] ; then \ +echo "Installing Jax GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ +&& pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +elif [ "$framework" = "pytorch" ] ; then \ +echo "Installing Pytorch GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_cpu]' \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +elif [ "$framework" = "both" ] ; then \ +echo "Installing Jax GPU and Pytorch GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +else \ +echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ +&& exit 1 ; \ +fi + +cd /algorithmic-efficiency && pip install -e '.[wandb]' + +cd /algorithmic-efficiency && git fetch origin +cd /algorithmic-efficiency && git pull + +# Todo: remove this, this is temporary for developing +chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh + +%runscript +exec bash /algorithmic-efficiency/docker/scripts/startup.sh "$@" +%startscript +exec bash /algorithmic-efficiency/docker/scripts/startup.sh "$@" \ No newline at end of file diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py new file mode 100644 index 000000000..48c521009 --- /dev/null +++ b/docker/scripts/singularity_converter.py @@ -0,0 +1,48 @@ +""" +This script is a modification of the +``spython recipe Dockerfile &> Singularity.def`` command, implemented here: +github.com/singularityhub/singularity-cli/blob/master/spython/client/recipe.py + +It converts the Docker recipe to Singularity, but suppressing any %files +command. Usage example: + +python singularity_converter.py -i Dockerfile -o Singularity.def +""" + +import argparse + +from spython.main.parse.parsers import get_parser +from spython.main.parse.writers import get_writer + +# globals +ENTRY_POINT = "/bin/bash" # seems to be a good default +FORCE = False # seems to be a good default +# +parser = argparse.ArgumentParser(description="Custom Singularity converter") +parser.add_argument( + "-i", "--input", type=str, help="Docker input path", default="Dockerfile") +parser.add_argument( + "-o", + "--output", + type=str, + help="Singularity output path", + default="Singularity.def", +) +args = parser.parse_args() +INPUT_DOCKERFILE_PATH = args.input +OUTPUT_SINGULARITY_PATH = args.output + +# create Docker parser and Singularity writer +parser = get_parser("docker") +writer = get_writer("singularity") + +# parse Dockerfile into Singularity and suppress %files commands +recipeParser = parser(INPUT_DOCKERFILE_PATH) +recipeWriter = writer(recipeParser.recipe) +(key,) = recipeParser.recipe.keys() +recipeWriter.recipe[key].files = [] + +# convert to string and save to output file +result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) +with open(OUTPUT_SINGULARITY_PATH, "w") as f: + f.write(result) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index be14ab498..30cb6b36b 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -113,13 +113,14 @@ done VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ - "imagenet_resnet_large_bn_init" "imagenet_vit" "fastmri" "ogbg" \ + "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ + "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ + "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \ "wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \ "librispeech_deepspeech" "librispeech_conformer" "mnist" \ - "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ - "fastmri_layernorm") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") # Set data and experiment paths ROOT_DATA_BUCKET="gs://mlcommons-data" diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md index 8276887da..100555964 100644 --- a/reference_algorithms/prize_qualification_baselines/README.md +++ b/reference_algorithms/prize_qualification_baselines/README.md @@ -50,8 +50,8 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc The prize qualification baseline submissionss for jax are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py` Example command: @@ -62,7 +62,7 @@ python3 submission_runner.py \ --experiment_dir= \ --experiment_name= \ --workload= \ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \ --tuning_ruleset=self ``` @@ -70,8 +70,8 @@ python3 submission_runner.py \ The prize qualification baseline submissionss for PyTorch are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py` Example command: @@ -82,6 +82,6 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir= \ --experiment_name=t \ --workload=\ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \ --tuning_ruleset=self ``` diff --git a/setup.cfg b/setup.cfg index a00da91fc..20139d4c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,7 +67,6 @@ full = %(ogbg)s %(librispeech_conformer)s %(wmt)s - pydub==0.25.1 # All workloads plus development dependencies full_dev = @@ -98,6 +97,7 @@ ogbg = librispeech_conformer = sentencepiece==0.1.99 tensorflow-text==2.12.1 + pydub==0.25.1 wmt = sentencepiece==0.1.99 diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 1022b5b54..bf7d6dfa5 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -16,7 +16,7 @@ def key_transform(k): if 'Conv' in k[0]: - k = ('embedding', *k[1:]) + k = ('conv_patch_extract', *k[1:]) elif k[0] == 'Linear_0': k = ('pre_logits', *k[1:]) elif k[0] == 'Linear_1': diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py new file mode 100644 index 000000000..444f1230a --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py new file mode 100644 index 000000000..8bf0bef7e --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index f091d3d4f..56316ba12 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -12,30 +12,50 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgWorkload as PytWorkload + OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs = JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + if 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + layer_index = int(i.split('_')[1]) + if graph_network: + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') @@ -47,22 +67,9 @@ def key_transform(k): def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +77,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), diff --git a/tests/modeldiffs/ogbg_gelu/__init__.py b/tests/modeldiffs/ogbg_gelu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py new file mode 100644 index 000000000..b58bcd461 --- /dev/null +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -0,0 +1,120 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgGeluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgGeluWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff + +# Todo: refactor tests to use workload properties in cleaner way +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs = JaxWorkload().num_message_passing_steps + + +def key_transform(k): + new_key = [] + bn = False + ln = False + graph_network = False + graph_index = 0 + seq_index = 0 + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) + continue + if 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) + continue + if 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + i = 'LayerNorm_' + str(count) + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + out = {} + for k in sd: + out[k] = sd[k] + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_model_size/__init__.py b/tests/modeldiffs/ogbg_model_size/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py new file mode 100644 index 000000000..62443bbb5 --- /dev/null +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -0,0 +1,120 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgModelSizeWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgModelSizeWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff + +# Todo: refactor tests to use workload properties in cleaner way +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs = JaxWorkload().num_message_passing_steps + + +def key_transform(k): + new_key = [] + bn = False + ln = False + graph_network = False + graph_index = 0 + seq_index = 0 + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) + continue + if 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) + continue + if 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + i = 'LayerNorm_' + str(count) + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + out = {} + for k in sd: + out[k] = sd[k] + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_silu/__init__.py b/tests/modeldiffs/ogbg_silu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py new file mode 100644 index 000000000..2922b7046 --- /dev/null +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -0,0 +1,120 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgSiluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgSiluWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff + +# Todo: refactor tests to use workload properties in cleaner way +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs = JaxWorkload().num_message_passing_steps + + +def key_transform(k): + new_key = [] + bn = False + ln = False + graph_network = False + graph_index = 0 + seq_index = 0 + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) + continue + if 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) + continue + if 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + i = 'LayerNorm_' + str(count) + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + out = {} + for k in sd: + out[k] = sd[k] + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None)