From cce02d849b808674dc92485bc666ace5a9c69e31 Mon Sep 17 00:00:00 2001 From: andres-fr Date: Thu, 26 Oct 2023 04:55:04 +0200 Subject: [PATCH 01/86] Fixed Issue #552 (Singularity build crashes) by updating Singularity.def and its creation process. Updated respective section in README --- README.md | 24 ++++---- docker/Singularity.def | 76 +++++++++++++++++++++++++ docker/scripts/singularity_converter.py | 45 +++++++++++++++ 3 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 docker/Singularity.def create mode 100644 docker/scripts/singularity_converter.py diff --git a/README.md b/README.md index 1be096c2e..70450722f 100644 --- a/README.md +++ b/README.md @@ -135,23 +135,25 @@ To use the Docker container as an interactive virtual environment, you can run a To run a submission end-to-end in a containerized environment see [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container). ### Using Singularity/Apptainer instead of Docker -Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the here provided Dockerfile. - -To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli): +Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide an Apptainer recipe that can be used to build an image by running ```bash -pip3 install spython -cd algorithmic-efficiency/docker -spython recipe Dockerfile &> Singularity.def +singularity build --fakeroot .sif Singularity.def ``` -Now we can build the Apptainer image by running +Then, to start a shell session with GPU support (by using the `--nv` flag), we can run ```bash -singularity build --fakeroot .sif Singularity.def +singularity shell --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \ + --nv .sif ``` -To start a shell session with GPU support (by using the `--nv` flag), we can run + +Note the `--bind` flag which, similarly to Docker, allows to bind specific paths on the host system and the container, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). + +Also note that `Singularity.def` was automatically generated from the `Dockerfile` using [spython](https://github.com/singularityhub/singularity-cli), as follows: ```bash -singularity shell --nv .sif +pip3 install spython +cd algorithmic-efficiency/docker +python scripts/singularity_converter.py -i Dockerfile -o Singularity.def ``` -Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). +Users that wish to customize their images are invited to check the respective files. # Getting Started For instructions on developing and scoring your own algorithm in the benchmark see [Getting Started Document](./getting_started.md). 
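The README changes above show the conversion, build, and shell commands separately. The sketch below strings them together end-to-end and also launches the image's `%runscript` (defined in the `docker/Singularity.def` recipe that follows), which execs `docker/scripts/startup.sh` inside the container. This is a minimal sketch, not part of the patch itself: the image name `algoperf.sif` is hypothetical, the paths assume the commands are run from the repository root rather than from `docker/`, and no arguments are passed to `startup.sh`.

```bash
# Regenerate the Apptainer definition file from the Dockerfile
# using the converter script added in this patch.
pip3 install spython
python docker/scripts/singularity_converter.py -i docker/Dockerfile -o docker/Singularity.def

# Build the image (hypothetical name) without root privileges.
singularity build --fakeroot algoperf.sif docker/Singularity.def

# Launch the %runscript (which execs docker/scripts/startup.sh) with GPU support
# and the same host-to-container bind mounts as the interactive shell example.
singularity run --nv \
  --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \
  algoperf.sif
```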
diff --git a/docker/Singularity.def b/docker/Singularity.def new file mode 100644 index 000000000..5f5c31d60 --- /dev/null +++ b/docker/Singularity.def @@ -0,0 +1,76 @@ +Bootstrap: docker +From: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +Stage: spython-base + +%post +# Dockerfile for AlgoPerf environment. +# To build Docker image with only Jax GPU installed: +# docker build -t --build-arg framework=jax +# To build Docker image with Pytorch GPU installed: +# docker build -t --build-arg framework=pytorch + +# To build Docker image + +# Installing machine packages +echo "Setting up machine" +apt-get update +apt-get install -y curl tar +DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +apt-get install libtcmalloc-minimal4 +apt-get install unzip +apt-get install pigz +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 + +# Install GCP tools +echo "Setting up gsutil" +curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-413.0.0-linux-x86_64.tar.gz +tar -xf google-cloud-cli-413.0.0-linux-x86_64.tar.gz +yes | ./google-cloud-sdk/install.sh + +# Directory setup for input and output +echo "Setting up directories for data and experiment_runs" +mkdir -p data/ +mkdir -p experiment_runs/ + +# Install Algorithmic efficiency repo +echo "Setting up algorithmic_efficiency repo" +branch="main" +framework="both" +git_url=https://github.com/mlcommons/algorithmic-efficiency.git +git clone $git_url && cd /algorithmic-efficiency +cd /algorithmic-efficiency && git checkout $branch + +cd /algorithmic-efficiency && pip install -e '.[full]' + +if [ "$framework" = "jax" ] ; then \ +echo "Installing Jax GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ +&& pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +elif [ "$framework" = "pytorch" ] ; then \ +echo "Installing Pytorch GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_cpu]' \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +elif [ "$framework" = "both" ] ; then \ +echo "Installing Jax GPU and Pytorch GPU" \ +&& cd /algorithmic-efficiency \ +&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +else \ +echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." 
>&2 \ +&& exit 1 ; \ +fi + +cd /algorithmic-efficiency && pip install -e '.[wandb]' + +cd /algorithmic-efficiency && git fetch origin +cd /algorithmic-efficiency && git pull + +# Todo: remove this, this is temporary for developing +chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh + +%runscript +exec bash /algorithmic-efficiency/docker/scripts/startup.sh "$@" +%startscript +exec bash /algorithmic-efficiency/docker/scripts/startup.sh "$@" \ No newline at end of file diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py new file mode 100644 index 000000000..5834d10db --- /dev/null +++ b/docker/scripts/singularity_converter.py @@ -0,0 +1,45 @@ +""" +This script is a modification of the +``spython recipe Dockerfile &> Singularity.def`` command, implemented here: +github.com/singularityhub/singularity-cli/blob/master/spython/client/recipe.py + +It converts the Docker recipy to Singularity, but suppressing any %files +command. Usage example: + +python singularity_converter.py -i Dockerfile -o Singularity.def +""" + + +import argparse +# +import spython +from spython.main.parse.parsers import get_parser +from spython.main.parse.writers import get_writer + +# globals +ENTRY_POINT = "/bin/bash" # seems to be a good default +FORCE = False # seems to be a good default +# +parser = argparse.ArgumentParser(description="Custom Singularity converter") +parser.add_argument('-i', '--input', type=str, + help="Docker input path", default="Dockerfile") +parser.add_argument('-o', '--output', type=str, + help="Singularity output path", default="Singularity.def") +args = parser.parse_args() +INPUT_DOCKERFILE_PATH = args.input +OUTPUT_SINGULARITY_PATH = args.output + +# create Docker parser and Singularity writer +parser = get_parser("docker") +writer = get_writer("singularity") + +# parse Dockerfile into Singularity and suppress %files commands +recipeParser = parser(INPUT_DOCKERFILE_PATH) +recipeWriter = writer(recipeParser.recipe) +key, = recipeParser.recipe.keys() +recipeWriter.recipe[key].files = [] + +# convert to string and save to output file +result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) +with open(OUTPUT_SINGULARITY_PATH, "w") as f: + f.write(result) From 95f252a5216d8e8472befcb97fac08325e5420d8 Mon Sep 17 00:00:00 2001 From: andres-fr Date: Thu, 2 Nov 2023 11:57:17 +0100 Subject: [PATCH 02/86] replaced double quotes with single quotes in script --- docker/scripts/singularity_converter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py index 5834d10db..fa40744ae 100644 --- a/docker/scripts/singularity_converter.py +++ b/docker/scripts/singularity_converter.py @@ -17,21 +17,21 @@ from spython.main.parse.writers import get_writer # globals -ENTRY_POINT = "/bin/bash" # seems to be a good default +ENTRY_POINT = '/bin/bash' # seems to be a good default FORCE = False # seems to be a good default # -parser = argparse.ArgumentParser(description="Custom Singularity converter") +parser = argparse.ArgumentParser(description='Custom Singularity converter') parser.add_argument('-i', '--input', type=str, - help="Docker input path", default="Dockerfile") + help='Docker input path', default='Dockerfile') parser.add_argument('-o', '--output', type=str, - help="Singularity output path", default="Singularity.def") + help='Singularity output path', default='Singularity.def') args = parser.parse_args() INPUT_DOCKERFILE_PATH = args.input 
OUTPUT_SINGULARITY_PATH = args.output # create Docker parser and Singularity writer -parser = get_parser("docker") -writer = get_writer("singularity") +parser = get_parser('docker') +writer = get_writer('singularity') # parse Dockerfile into Singularity and suppress %files commands recipeParser = parser(INPUT_DOCKERFILE_PATH) @@ -41,5 +41,5 @@ # convert to string and save to output file result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) -with open(OUTPUT_SINGULARITY_PATH, "w") as f: +with open(OUTPUT_SINGULARITY_PATH, 'w') as f: f.write(result) From ca5d8e37a29d16d55fa709565622c66f8e67f6d8 Mon Sep 17 00:00:00 2001 From: andres-fr Date: Mon, 13 Nov 2023 15:29:17 +0100 Subject: [PATCH 03/86] solved isort linting issue in singularity_converter.py --- docker/scripts/singularity_converter.py | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py index fa40744ae..b101f7c2c 100644 --- a/docker/scripts/singularity_converter.py +++ b/docker/scripts/singularity_converter.py @@ -11,35 +11,40 @@ import argparse -# -import spython + from spython.main.parse.parsers import get_parser from spython.main.parse.writers import get_writer # globals -ENTRY_POINT = '/bin/bash' # seems to be a good default +ENTRY_POINT = "/bin/bash" # seems to be a good default FORCE = False # seems to be a good default # -parser = argparse.ArgumentParser(description='Custom Singularity converter') -parser.add_argument('-i', '--input', type=str, - help='Docker input path', default='Dockerfile') -parser.add_argument('-o', '--output', type=str, - help='Singularity output path', default='Singularity.def') +parser = argparse.ArgumentParser(description="Custom Singularity converter") +parser.add_argument( + "-i", "--input", type=str, help="Docker input path", default="Dockerfile" +) +parser.add_argument( + "-o", + "--output", + type=str, + help="Singularity output path", + default="Singularity.def", +) args = parser.parse_args() INPUT_DOCKERFILE_PATH = args.input OUTPUT_SINGULARITY_PATH = args.output # create Docker parser and Singularity writer -parser = get_parser('docker') -writer = get_writer('singularity') +parser = get_parser("docker") +writer = get_writer("singularity") # parse Dockerfile into Singularity and suppress %files commands recipeParser = parser(INPUT_DOCKERFILE_PATH) recipeWriter = writer(recipeParser.recipe) -key, = recipeParser.recipe.keys() +(key,) = recipeParser.recipe.keys() recipeWriter.recipe[key].files = [] # convert to string and save to output file result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) -with open(OUTPUT_SINGULARITY_PATH, 'w') as f: +with open(OUTPUT_SINGULARITY_PATH, "w") as f: f.write(result) From 41dce0b3f1be78dc71d075f9b66b70113122206a Mon Sep 17 00:00:00 2001 From: andres-fr Date: Tue, 14 Nov 2023 04:37:20 +0100 Subject: [PATCH 04/86] fix typo: recpiy -> recipe --- docker/scripts/singularity_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py index b101f7c2c..ce7566e1f 100644 --- a/docker/scripts/singularity_converter.py +++ b/docker/scripts/singularity_converter.py @@ -3,7 +3,7 @@ ``spython recipe Dockerfile &> Singularity.def`` command, implemented here: github.com/singularityhub/singularity-cli/blob/master/spython/client/recipe.py -It converts the Docker recipy to Singularity, but suppressing any %files +It converts the Docker 
recipe to Singularity, but suppressing any %files command. Usage example: python singularity_converter.py -i Dockerfile -o Singularity.def From 05cd0b426c3517649bcb09af04e7d9ef01459b01 Mon Sep 17 00:00:00 2001 From: andres-fr Date: Mon, 27 Nov 2023 16:32:38 +0100 Subject: [PATCH 05/86] fixed yapf formatting for singularity_converter.py --- docker/scripts/singularity_converter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py index ce7566e1f..48c521009 100644 --- a/docker/scripts/singularity_converter.py +++ b/docker/scripts/singularity_converter.py @@ -9,7 +9,6 @@ python singularity_converter.py -i Dockerfile -o Singularity.def """ - import argparse from spython.main.parse.parsers import get_parser @@ -21,8 +20,7 @@ # parser = argparse.ArgumentParser(description="Custom Singularity converter") parser.add_argument( - "-i", "--input", type=str, help="Docker input path", default="Dockerfile" -) + "-i", "--input", type=str, help="Docker input path", default="Dockerfile") parser.add_argument( "-o", "--output", @@ -47,4 +45,4 @@ # convert to string and save to output file result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) with open(OUTPUT_SINGULARITY_PATH, "w") as f: - f.write(result) + f.write(result) From 711f6308fd02e23fae0d27514e08b06bb1fd5059 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 11:33:40 -0500 Subject: [PATCH 06/86] minor --- .../imagenet_jax/randaugment.py | 2 +- .../imagenet_vit/imagenet_jax/models.py | 103 +++++++++-- .../imagenet_vit/imagenet_jax/workload.py | 28 ++- .../imagenet_vit/imagenet_pytorch/models.py | 121 ++++++++++--- .../imagenet_vit/imagenet_pytorch/workload.py | 28 ++- .../workloads/imagenet_vit/workload.py | 15 ++ tests/modeldiffs/imagenet_vit/compare.py | 67 ++++++- tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ++++++++++++++++++ .../imagenet_vit/compare_post_ln.py | 163 ++++++++++++++++++ 9 files changed, 644 insertions(+), 46 deletions(-) create mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py create mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..8fa1c0789 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +# from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index ab5d1839e..4a97ee661 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
+ use_glu: bool = False dropout_rate: float = 0.0 @nn.compact @@ -47,6 +48,13 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: d = x.shape[2] x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) + + if self.use_glu: + y = nn.Dense( + self.mlp_dim, + **inits)(x) + x = x * y + x = nn.Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -56,26 +64,47 @@ class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False dropout_rate: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + if not self.use_post_layer_norm: + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + else: + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + x) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(x, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) + return x @@ -85,6 +114,8 @@ class Encoder(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
num_heads: int = 12 dropout_rate: float = 0.0 + use_glu: bool = False + use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: @@ -94,9 +125,35 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: name=f'encoderblock_{lyr}', mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=self.dropout_rate) x = block(x, train) - return nn.LayerNorm(name='encoder_layernorm')(x) + if not self.use_post_layer_norm: + return nn.LayerNorm(name='encoder_layernorm')(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + @nn.compact + def __call__(self, x): + n, _, d = x.shape + probe = self.param('probe', + nn.initializers.xavier_uniform(), + (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] class ViT(nn.Module): @@ -112,6 +169,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, def get_posemb(self, seqshape: tuple, @@ -145,11 +205,18 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate, name='Transformer')( x, train=not train) - x = jnp.mean(x, axis=1) + if self.use_map: + x = MAPHead(num_heads=self.num_heads, + mlp_dim=self.mlp_dim + )(x) + else: + x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 3f3af0564..22fcde66a 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,16 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) @@ -83,3 +88,24 @@ def _eval_model_on_split(self, rng, data_dir, global_step) + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py 
b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 55a8e370d..053b0ec76 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -39,18 +39,26 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + use_glu: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu self.dropout_rate = dropout_rate - self.net = nn.Sequential( - nn.Linear(self.width, self.mlp_dim), - nn.GELU(), - nn.Dropout(self.dropout_rate), - nn.Linear(self.mlp_dim, self.width)) + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + self.dropout = nn.Dropout(self.dropout_rate) + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.reset_parameters() def reset_parameters(self) -> None: @@ -61,7 +69,16 @@ def reset_parameters(self) -> None: module.bias.data.normal_(std=1e-6) def forward(self, x: spec.Tensor) -> spec.Tensor: - return self.net(x) + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = self.dropout(x) + x = self.linear2(x) + return x class SelfAttention(nn.Module): @@ -129,29 +146,44 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(self.width, self.mlp_dim, dropout_rate) + self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: - y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) - x = x + y + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = self.dropout(y) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = self.dropout(y) + x = x + y + else: + y = self.self_attention1(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm0(x) + + y = self.mlp3(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm2(x) return x @@ -163,6 +195,8 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() @@ -170,18 +204,53 @@ def __init__(self, self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, dropout_rate) + Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) for _ in range(depth) ]) - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + + if not self.use_post_layer_norm: + 
self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None def forward(self, x: spec.Tensor) -> spec.Tensor: # Input Encoder. for block in self.net: x = block(x) - return self.encoder_norm(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) + self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x): + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x) + y = self.layer_nrom(x) + x = x + self.mlp(y) + return x[:, 0] class ViT(nn.Module): @@ -202,6 +271,9 @@ def __init__( rep_size: Union[int, bool] = True, dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() if dropout_rate is None: @@ -215,6 +287,9 @@ def __init__( self.num_heads = num_heads self.rep_size = rep_size self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map self.dtype = dtype if self.rep_size: @@ -234,6 +309,8 @@ def __init__( width=self.width, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate) if self.num_classes: @@ -270,7 +347,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.dropout(x) x = self.encoder(x) - x = torch.mean(x, dim=1) + + if self.use_map: + pass + else: + x = torch.mean(x, dim=1) if self.rep_size: x = torch.tanh(self.pre_logits(x)) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 08a62ede6..9e8af3a68 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,12 +28,17 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -77,3 +82,24 @@ def model_fn( logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class 
ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/workload.py index 61d3acfd3..ed0118ca0 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/workload.py @@ -60,6 +60,21 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 1 - 0.3481 # 0.6519 + @property + def use_post_layer_norm(self) -> bool: + """Whether to use layer normalization after the residual branch.""" + return False + + @property + def use_map(self) -> bool: + """Whether to use multihead attention pooling.""" + return False + + @property + def use_glu(self) -> bool: + """Whether to use GLU in the MLPBlock.""" + return False + @property def eval_batch_size(self) -> int: return 2048 diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 1022b5b54..3e8b9dcb1 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -3,20 +3,75 @@ # Disable GPU access for both jax and pytorch. os.environ['CUDA_VISIBLE_DEVICES'] = '' -import jax -import torch - from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from tests.modeldiffs.diff import out_diff +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): if 
'Conv' in k[0]: - k = ('embedding', *k[1:]) + k = ('conv_patch_extract', *k[1:]) elif k[0] == 'Linear_0': k = ('pre_logits', *k[1:]) elif k[0] == 'Linear_1': @@ -35,6 +90,8 @@ def key_transform(k): continue if 'CustomBatchNorm' in i: continue + if 'GLU' in i: + pass if 'Linear' in i: if attention: i = { diff --git a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py new file mode 100644 index 000000000..a6f01f971 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_glu.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = 
i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py new file mode 100644 index 000000000..e27d77482 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_post_ln.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + 
new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From c8a9e728486ea6faffd11ac8df4e90876377b0a7 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:20:40 -0500 Subject: [PATCH 07/86] Clean up model diff --- .../imagenet_jax/randaugment.py | 1 + .../imagenet_vit/imagenet_jax/models.py | 6 +- .../imagenet_vit/imagenet_jax/workload.py | 4 +- .../imagenet_vit/imagenet_pytorch/models.py | 51 ++++-- .../imagenet_vit/imagenet_pytorch/workload.py | 6 +- tests/modeldiffs/imagenet_vit/compare.py | 66 +------ tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ------------------ .../imagenet_vit/compare_post_ln.py | 163 ------------------ tests/modeldiffs/imagenet_vit/glu_compare.py | 52 ++++++ .../imagenet_vit/post_ln_compare.py | 52 ++++++ 10 files changed, 154 insertions(+), 410 deletions(-) delete mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py delete mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py create mode 100644 tests/modeldiffs/imagenet_vit/glu_compare.py create mode 100644 tests/modeldiffs/imagenet_vit/post_ln_compare.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 8fa1c0789..caa77ae35 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + # from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 4a97ee661..c88132621 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -88,19 +88,21 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y else: + y = x y = nn.SelfAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1')( - x) + y) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) + y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(x, train) + name='MlpBlock_3')(y, train) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = 
nn.LayerNorm(name='LayerNorm_2')(x) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 22fcde66a..1acd58bcd 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,8 +32,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, @@ -41,7 +40,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 053b0ec76..469716d59 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -1,8 +1,8 @@ """PyTorch implementation of refactored and simplified ViT. Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit. -https://github.com/lucidrains/vit-pytorch. +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. """ import math @@ -14,9 +14,12 @@ from algorithmic_efficiency import init_utils from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ + MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) 
-> spec.Tensor: + """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device y, x = torch.meshgrid(torch.arange(h, device=device), @@ -161,7 +164,11 @@ def __init__(self, self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: @@ -175,12 +182,14 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.dropout(y) x = x + y else: - y = self.self_attention1(x) + y = x + y = self.self_attention1(y) y = self.dropout(y) x = x + y x = self.layer_norm0(x) - y = self.mlp3(x) + y = x + y = self.mlp3(y) y = self.dropout(y) x = x + y x = self.layer_norm2(x) @@ -208,8 +217,12 @@ def __init__(self, self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) - for _ in range(depth) + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -230,7 +243,10 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class MAPHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): super().__init__() self.width = width self.mlp_dim = mlp_dim @@ -239,16 +255,17 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 1 self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) - self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) - self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=False) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x): + def forward(self, x: spec.Tensor) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x) - y = self.layer_nrom(x) + x = self.mha(probe, x)[0] + y = self.layer_norm(x) x = x + self.mlp(y) return x[:, 0] @@ -315,6 +332,12 @@ def __init__( if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + self.reset_parameters() def reset_parameters(self) -> None: @@ -349,7 +372,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.encoder(x) if self.use_map: - pass + x = self.map(x) else: x = torch.mean(x, dim=1) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9e8af3a68..013bc643f 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,8 +28,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - 
aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( @@ -38,7 +37,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -98,7 +96,7 @@ def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetViTMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 3e8b9dcb1..39f2651a0 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,72 +1,18 @@ import os +from tests.modeldiffs.diff import out_diff + # Disable GPU access for both jax and pytorch. os.environ['CUDA_VISIBLE_DEVICES'] = '' +import jax +import torch + from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): @@ -90,8 +36,6 @@ def key_transform(k): continue if 'CustomBatchNorm' in i: continue - if 'GLU' in i: - pass if 'Linear' in i: if attention: i = { diff --git 
a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py deleted file mode 100644 index a6f01f971..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_glu.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. -os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitGluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and 
i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py deleted file mode 100644 index e27d77482..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_post_ln.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. -os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetViTPostLNWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - 
model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py new file mode 100644 index 000000000..444f1230a --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. 
+ image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py new file mode 100644 index 000000000..8bf0bef7e --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From ecf8220edf11ecde32511f4dbe97888307b2cf86 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:33:38 -0500 Subject: [PATCH 08/86] Add docker image --- .../imagenet_resnet/imagenet_jax/randaugment.py | 3 +-- algorithmic_efficiency/workloads/workloads.py | 12 ++++++++++++ docker/scripts/startup.sh | 3 ++- tests/modeldiffs/imagenet_vit/compare.py | 3 +-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index caa77ae35..5f92b1482 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,8 +8,7 @@ import math import tensorflow as tf - -# from tensorflow_addons import image as contrib_image +from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. 
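The refactored `glu_compare.py` and `post_ln_compare.py` scripts above import `out_diff` from `tests.modeldiffs.diff` instead of carrying their own copies of the weight-transfer and comparison logic. That shared module is not shown in this patch; the sketch below reconstructs what it presumably provides, based on the deleted per-workload versions. The helper name `torch2jax` and the default `init_kwargs` are assumptions; only the `out_diff` name is confirmed by the new imports.

```python
# Sketch of the shared helpers in tests/modeldiffs/diff.py, reconstructed from
# the deleted per-workload copies above. Names and defaults are assumptions.
import jax
import numpy as np
import torch
from flax import jax_utils

from tests.modeldiffs.torch2jax_utils import Torch2Jax
from tests.modeldiffs.torch2jax_utils import value_transform


def torch2jax(jax_workload,
              pytorch_workload,
              key_transform=None,
              sd_transform=None,
              init_kwargs=None):
  if init_kwargs is None:
    init_kwargs = dict(dropout_rate=0.0, aux_dropout_rate=0.0)
  # Initialize both models and unreplicate the Jax parameters.
  jax_params, model_state = jax_workload.init_model_fn(
      jax.random.PRNGKey(0), **init_kwargs)
  pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs)
  jax_params = jax_utils.unreplicate(jax_params).unfreeze()
  if model_state is not None:
    model_state = jax_utils.unreplicate(model_state)
  if isinstance(pytorch_model,
                (torch.nn.DataParallel,
                 torch.nn.parallel.DistributedDataParallel)):
    pytorch_model = pytorch_model.module
  # Copy the PyTorch weights into the Jax parameter tree.
  t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params)
  if key_transform is not None:
    t2j.key_transform(key_transform)
  if sd_transform is not None:
    t2j.sd_transform(sd_transform)
  t2j.value_transform(value_transform)
  t2j.diff()
  t2j.update_jax_model()
  return jax_params, model_state, pytorch_model


def out_diff(jax_workload,
             pytorch_workload,
             jax_model_kwargs,
             pytorch_model_kwargs,
             key_transform=None,
             sd_transform=None,
             out_transform=None):
  jax_params, model_state, pytorch_model = torch2jax(
      jax_workload, pytorch_workload, key_transform, sd_transform)
  out_p, _ = pytorch_workload.model_fn(params=pytorch_model,
                                       **pytorch_model_kwargs)
  out_j, _ = jax_workload.model_fn(params=jax_params,
                                   model_state=model_state,
                                   **jax_model_kwargs)
  if out_transform is not None:
    out_p = out_transform(out_p)
    out_j = out_transform(out_j)
  # Report the max and min absolute deviation between the two forward passes;
  # a small maximum indicates that the implementations match.
  print(np.abs(out_p.detach().numpy() - np.array(out_j)).max())
  print(np.abs(out_p.detach().numpy() - np.array(out_j)).min())
```

With such a shared helper in place, each per-workload comparison script only has to define its `key_transform`/`sd_transform` and the model kwargs, as the new `glu_compare.py` and `post_ln_compare.py` files do.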
diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6cc53b7dd..bf444ea36 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -56,6 +56,18 @@ 'workload_path': 'imagenet_vit/imagenet', 'workload_class_name': 'ImagenetVitWorkload', }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTMapLNWorkload', + }, 'librispeech_conformer': { 'workload_path': 'librispeech_conformer/librispeech', 'workload_class_name': 'LibriSpeechConformerWorkload', diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 3f7458e4b..3b366b71c 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -113,7 +113,8 @@ done VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ - "imagenet_resnet_large_bn_init" "imagenet_vit" "fastmri" "ogbg" \ + "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ + "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 39f2651a0..bf7d6dfa5 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,7 +1,5 @@ import os -from tests.modeldiffs.diff import out_diff - # Disable GPU access for both jax and pytorch. 
os.environ['CUDA_VISIBLE_DEVICES'] = '' @@ -13,6 +11,7 @@ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff def key_transform(k): From 290807795fc8a1cf392dd7e94823569d5b651e40 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:40:36 -0500 Subject: [PATCH 09/86] Lint fix --- .../imagenet_vit/imagenet_jax/models.py | 91 ++++++++++--------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index c88132621..32e748ec7 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -50,9 +50,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = nn.gelu(x) if self.use_glu: - y = nn.Dense( - self.mlp_dim, - **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = nn.Dropout(rate=self.dropout_rate)(x, train) @@ -71,41 +69,45 @@ class Encoder1DBlock(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y else: - y = x - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) - - y = x - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + y = x + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = x + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) return x @@ -141,12 +143,13 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim 
num_heads: int = 12 + @nn.compact def __call__(self, x): n, _, d = x.shape probe = self.param('probe', - nn.initializers.xavier_uniform(), - (1, 1, d), x.dtype) + nn.initializers.xavier_uniform(), (1, 1, d), + x.dtype) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( @@ -171,9 +174,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, + use_glu: bool = False + use_post_layer_norm: bool = False + use_map: bool = False def get_posemb(self, seqshape: tuple, @@ -214,9 +217,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: x, train=not train) if self.use_map: - x = MAPHead(num_heads=self.num_heads, - mlp_dim=self.mlp_dim - )(x) + x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) else: x = jnp.mean(x, axis=1) From efa1120c49eaf7333cafcf2ee6a21b27d9044789 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 12:50:03 +0100 Subject: [PATCH 10/86] Update version to match tag --- algorithmic_efficiency/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py index af0a6b8fc..a0e473e1d 100644 --- a/algorithmic_efficiency/__init__.py +++ b/algorithmic_efficiency/__init__.py @@ -1,3 +1,3 @@ """Algorithmic Efficiency.""" -__version__ = '0.0.1' +__version__ = '0.1.0' From 386fabb704ad806e71ca2d40b3527fbf4880b291 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 12:50:37 +0100 Subject: [PATCH 11/86] Restructure and fix ogbg data_dir --- datasets/README.md | 155 +++++++++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 62 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 614344978..c3feb5fed 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -1,74 +1,111 @@ -# Dataset Setup -TL;DR: -Use `dataset_setup.py` to download datasets. -Usage: +# MLCommons™ AlgoPerf: Dataset Setup + +## Table of Contents + +- [General Setup](#general-setup) + - [Set Data Directory (Docker Container)](#set-data-directory-docker-container) + - [Set Data Directory (on Host)](#set-data-directory-on-host) + - [Start tmux session (Recommended)](#start-tmux-session-recommended) + - [Clean up](#clean-up) +- [Individual Dataset Instructions](#individual-dataset-instructions) + - [OGBG](#ogbg) + - [WMT](#wmt) + - [FastMRI](#fastmri) + - [ImageNet](#imagenet) + - [Criteo1TB](#criteo1tb) + - [LibriSpeech](#librispeech) + - [Training SPM Tokenizer](#training-spm-tokenizer) + - [Preprocessing Script](#preprocessing-script) + +## General Setup + +This document provides instructions on downloading and preparing all datasets utilized in the AlgoPerf benchmark. You can prepare the individual datasets one-by-one as needed. If your setup, such as your cloud or cluster environment, already contains these datasets, you may skip the dataset setup for this particular data (and directly specify the dataset location in the `submission_runner.py`). Just verify that you are using the same dataset version (and possible preprocessing). 
+ +*TL;DR to download and prepare a dataset, run `dataset_setup.py`:* + ```bash python3 datasets/dataset_setup.py \ --data_dir=~/data \ -- - -- + -- ``` -The complete benchmark uses 6 datasets: -- OGBG -- WMT -- FastMRI -- Imagenet -- Criteo 1TB -- Librispeech +The complete benchmark uses 6 different datasets: + +- [OGBG](#ogbg) +- [WMT](#wmt) +- [FastMRI](#fastmri) +- [Imagenet](#imagenet) +- [Criteo 1TB](#criteo1tb) +- [Librispeech](#librispeech) -Some dataset setups will require you to sign a third party agreement with the dataset owners in order to get the donwload URLs. +Some dataset setups will require you to sign a third-party agreement with the dataset owners in order to get the download URLs. -# Per dataset instructions -## Environment +### Set Data Directory (Docker Container) -### Set data directory (Docker container) -If you are running the `dataset_setup.py` script from a Docker container, please +If you are running the `dataset_setup.py` script from a Docker container, please make sure the data directory is mounted to a directory on your host with --v flag. If you are following instructions from the README you will have used +`-v` flag. If you are following instructions from the [Getting Started guide](/GETTING_STARTED.md) you will have used the `-v $HOME/data:/data` flag in the `docker run` command. This will mount -the `$HOME/data` directory to the `/data` directory in the container. -In this case set --data_dir to `/data`. +the `$HOME/data` directory to the `/data` directory in the container. +In this case set, `--data_dir` to `/data`. + ```bash DATA_DIR='/data' ``` -### Set data directory (on host) -Alternatively, if you are running the data download script directly on your host, feel free -to choose whatever directory you find suitable, further submission instructions -assume the data is stored in `~/data`. + +### Set Data Directory (on Host) + +Alternatively, if you are running the data download script directly on your host, feel free to choose whatever directory you find suitable, further submission instructions assume the data is stored in `~/data`. + ```bash DATA_DIR='~/data' ``` + #### Start tmux session (Recommended) -If running the dataset_setup.py on directly on host it is recommended to run -the dataset_setup.py script in a tmux session because some of the data downloads may -take several hours. To avoid your setup being interrupted start a tmux session: + +If running the `dataset_setup.py` on directly on host it is recommended to run +the `dataset_setup.py` script in a `tmux` session because some of the data downloads may take several hours. To avoid your setup being interrupted start a `tmux` session: + ```bash tmux new -s data_setup ``` +### Clean up + +In order to avoid potential accidental deletion, this script does NOT +delete any intermediate temporary files (such as zip archives) without a user +confirmation. Deleting temp files is particularly important for Criteo 1TB, as +there can be multiple copies of the dataset on disk during preprocessing if +files are not cleaned up. + +By default, a user will be prompted before any files are deleted. If you do not want any temp files to be deleted, you can pass `--interactive_deletion=false` and then all files will be downloaded to the provided `--temp_dir`, and the user can manually delete these after downloading has finished. 
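To make the deletion behaviour described above concrete, the pattern can be sketched as a small prompt-gated helper. The function name and structure below are purely illustrative and are not taken from `dataset_setup.py` itself.

```python
# Illustrative sketch of the prompt-gated cleanup described above; the helper
# name `maybe_delete` is hypothetical, not the actual implementation.
import os


def maybe_delete(temp_files, interactive_deletion=True):
  if not interactive_deletion:
    # Keep everything; the user cleans up --temp_dir manually later.
    return
  for path in temp_files:
    answer = input(f'Delete temporary file {path}? [y/N] ')
    if answer.strip().lower() == 'y':
      os.remove(path)
```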
-## Datasets +## Individual Dataset Instructions + +### OGBG -### OGBG From `algorithmic-efficiency` run: + ```bash python3 datasets/dataset_setup.py \ ---data_dir $DATA_DIR/ogbg \ +--data_dir $DATA_DIR \ --ogbg ``` -### WMT +### WMT + From `algorithmic-efficiency` run: + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` +### FastMRI -## FastMRI -Fill out form on https://fastmri.med.nyu.edu/. After filling out the form +Fill out form on . After filling out the form you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". @@ -81,18 +118,14 @@ python3 datasets/dataset_setup.py \ --fastmri_knee_singlecoil_test_url '' ``` -## ImageNet -Register on https://image-net.org/ and follow directions to obtain the -URLS for the ILSVRC2012 train and validation images. +### ImageNet -Imagenet dataset processsing is resource intensive. To avoid potential -ResourcExhausted errors increase the maximum number of open file descriptors: -```bash -ulimit -n 8192 -``` +Register on and follow directions to obtain the +URLS for the ILSVRC2012 train and validation images. +The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. -The imagenet data pipeline differs between the pytorch and jax workloads. -Therefore, you will have to specify the framework (pytorch or jax) through theframework flag. +The ImageNet data pipeline differs between the PyTorch and JAX workloads. +Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. ```bash python3 datasets/dataset_setup.py \ @@ -102,15 +135,22 @@ python3 datasets/dataset_setup.py \ --imagenet_train_url \ --imagenet_val_url \ --framework jax +``` +Imagenet dataset processsing is resource intensive. To avoid potential +ResourcExhausted errors increase the maximum number of open file descriptors: + +```bash +ulimit -n 8192 ``` -Note that some functions use subprocess.Popen(..., shell=True), which can be -dangerous if the user injects code into the --data_dir or --temp_dir flags. We -do some basic sanitization in main(), but submitters should not let untrusted +Note that some functions use `subprocess.Popen(..., shell=True)`, which can be +dangerous if the user injects code into the `--data_dir` or `--temp_dir` flags. We +do some basic sanitization in `main()`, but submitters should not let untrusted users run this script on their systems. -## Criteo1tb +### Criteo1TB + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ @@ -118,19 +158,10 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` -### Clean up -In order to avoid potential accidental deletion, this script does NOT -delete any intermediate temporary files (such as zip archives) without a user -confirmation. Deleting temp files is particularly important for Criteo 1TB, as -there can be multiple copies of the dataset on disk during preprocessing if -files are not cleaned up. If you do not want any temp files to be deleted, you -can pass --interactive_deletion=false and then all files will be downloaded to -the provided --temp_dir, and the user can manually delete these after -downloading has finished. 
+### LibriSpeech - -## Librispeech To download, train a tokenizer and preprocess the librispeech dataset: + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ @@ -138,26 +169,26 @@ python3 datasets/dataset_setup.py \ --librispeech ``` -### Notes on librispeech preprocessing #### Training SPM Tokenizer + A simple sentence piece tokenizer is trained over librispeech training data. This tokenizer is then used in later preprocessing step to tokenize transcripts. This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`: + ```bash python3 librispeech_tokenizer.py --train --data_dir=$DATA_DIR/librispeech ``` The trained tokenizer can be loaded back to do sanity check by tokenizing + de-tokenizing a constant string: + ```bash librispeech_tokenizer.py --data_dir=$DATA_DIR/librispeech ``` #### Preprocessing Script + The preprocessing script will generate `.npy` files for audio data, `features.csv` which has paths to saved audio `.npy`, and `trans.csv` which has paths to `features.csv` and transcription data. ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` - - - From 23320d2d1c5c140e93c1dfa59d28a073a80a3f4c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 12:58:13 +0100 Subject: [PATCH 12/86] Standardize how subfolders for datasets are implemented --- datasets/dataset_setup.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f9ee2f138..f765e4a1a 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -334,6 +334,7 @@ def download_criteo1tb(data_dir, def download_cifar(data_dir, framework): + data_dir = os.path.join(data_dir, 'cifar10') if framework == 'jax': tfds.builder('cifar10:3.0.2', data_dir=data_dir).download_and_prepare() elif framework == 'pytorch': @@ -398,18 +399,18 @@ def extract(source, dest, mode='r:xz'): def setup_fastmri(data_dir, src_data_dir): + data_dir = os.path.join(data_dir, 'fastmri') train_tar_file_path = os.path.join(src_data_dir, FASTMRI_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(src_data_dir, FASTMRI_VAL_TAR_FILENAME) test_tar_file_path = os.path.join(src_data_dir, FASTMRI_TEST_TAR_FILENAME) # Make train, val and test subdirectories - fastmri_data_dir = os.path.join(data_dir, 'fastmri') - train_data_dir = os.path.join(fastmri_data_dir, 'train') + train_data_dir = os.path.join(data_dir, 'train') os.makedirs(train_data_dir, exist_ok=True) - val_data_dir = os.path.join(fastmri_data_dir, 'val') + val_data_dir = os.path.join(data_dir, 'val') os.makedirs(val_data_dir, exist_ok=True) - test_data_dir = os.path.join(fastmri_data_dir, 'test') + test_data_dir = os.path.join(data_dir, 'test') os.makedirs(test_data_dir, exist_ok=True) # Unzip tar file into subdirectories @@ -425,6 +426,7 @@ def setup_fastmri(data_dir, src_data_dir): def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): """Downloads and returns the download dir.""" + data_dir = os.path.join(data_dir, 'imagenet') imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) @@ -456,6 +458,7 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): def setup_imagenet(data_dir, framework=None): + data_dir = os.path.join(data_dir, 'imagenet') if framework == 'jax': setup_imagenet_jax(data_dir) @@ -629,6 +632,7 @@ def 
download_librispeech(dataset_dir, tmp_dir): def download_mnist(data_dir): + data_dir = os.path.join(data_dir, 'MNIST') # Capitalization to match PyTorch tfds.builder('mnist', data_dir=data_dir).download_and_prepare() @@ -714,9 +718,8 @@ def main(_): raise ValueError( 'Please specify either jax or pytorch framework through framework ' 'flag.') - imagenet_data_dir = os.path.join(data_dir, 'imagenet') - download_imagenet(imagenet_data_dir, imagenet_train_url, imagenet_val_url) - setup_imagenet(imagenet_data_dir, framework=FLAGS.framework) + download_imagenet(data_dir, imagenet_train_url, imagenet_val_url) + setup_imagenet(data_dir, framework=FLAGS.framework) if FLAGS.all or FLAGS.librispeech: logging.info('Downloading Librispeech...') From c46c548976c061aa032b742d320303c3dc24c235 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 13:53:50 +0100 Subject: [PATCH 13/86] Add resulting directory structures and file numbers/sizes --- datasets/README.md | 196 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/datasets/README.md b/datasets/README.md index c3feb5fed..4f7b6b880 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -93,6 +93,32 @@ python3 datasets/dataset_setup.py \ --ogbg ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── ogbg +│ └── ogbg_molpcba +│ └── 0.1.3 +│ ├── dataset_info.json +│ ├── features.json +│ ├── metadata.json +│ ├── ogbg_molpcba-test.tfrecord-00000-of-00001 +│ ├── ogbg_molpcba-train.tfrecord-00000-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00001-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00002-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00003-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00004-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00005-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00006-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00007-of-00008 +│ └── ogbg_molpcba-validation.tfrecord-00000-of-00001 +``` + +In total, it should contain 13 files (via `find -type f | wc -l`) for a total of 777 MB (via `du -sch ogbg/`). +
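As a quick, optional sanity check that the prepared OGBG data is readable, the TFDS builder can be opened directly. The path below assumes `$DATA_DIR=~/data`, and the version string is taken from the listing above.

```python
# Optional sanity check for the prepared OGBG dataset.
import os
import tensorflow_datasets as tfds

data_dir = os.path.expanduser('~/data/ogbg')  # i.e. $DATA_DIR/ogbg
builder = tfds.builder('ogbg_molpcba:0.1.3', data_dir=data_dir)
print(builder.info.splits)  # Expect train, validation and test splits.
```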
+ ### WMT From `algorithmic-efficiency` run: @@ -103,6 +129,64 @@ python3 datasets/dataset_setup.py \ --wmt ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── wmt + ├── wmt14_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt14_translate-test.tfrecord-00000-of-00001 + │ ├── wmt14_translate-train.tfrecord-00000-of-00016 + │ ├── wmt14_translate-train.tfrecord-00001-of-00016 + │ ├── wmt14_translate-train.tfrecord-00002-of-00016 + │ ├── wmt14_translate-train.tfrecord-00003-of-00016 + │ ├── wmt14_translate-train.tfrecord-00004-of-00016 + │ ├── wmt14_translate-train.tfrecord-00005-of-00016 + │ ├── wmt14_translate-train.tfrecord-00006-of-00016 + │ ├── wmt14_translate-train.tfrecord-00007-of-00016 + │ ├── wmt14_translate-train.tfrecord-00008-of-00016 + │ ├── wmt14_translate-train.tfrecord-00009-of-00016 + │ ├── wmt14_translate-train.tfrecord-00010-of-00016 + │ ├── wmt14_translate-train.tfrecord-00011-of-00016 + │ ├── wmt14_translate-train.tfrecord-00012-of-00016 + │ ├── wmt14_translate-train.tfrecord-00013-of-00016 + │ ├── wmt14_translate-train.tfrecord-00014-of-00016 + │ ├── wmt14_translate-train.tfrecord-00015-of-00016 + │ └── wmt14_translate-validation.tfrecord-00000-of-00001 + ├── wmt17_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt17_translate-test.tfrecord-00000-of-00001 + │ ├── wmt17_translate-train.tfrecord-00000-of-00016 + │ ├── wmt17_translate-train.tfrecord-00001-of-00016 + │ ├── wmt17_translate-train.tfrecord-00002-of-00016 + │ ├── wmt17_translate-train.tfrecord-00003-of-00016 + │ ├── wmt17_translate-train.tfrecord-00004-of-00016 + │ ├── wmt17_translate-train.tfrecord-00005-of-00016 + │ ├── wmt17_translate-train.tfrecord-00006-of-00016 + │ ├── wmt17_translate-train.tfrecord-00007-of-00016 + │ ├── wmt17_translate-train.tfrecord-00008-of-00016 + │ ├── wmt17_translate-train.tfrecord-00009-of-00016 + │ ├── wmt17_translate-train.tfrecord-00010-of-00016 + │ ├── wmt17_translate-train.tfrecord-00011-of-00016 + │ ├── wmt17_translate-train.tfrecord-00012-of-00016 + │ ├── wmt17_translate-train.tfrecord-00013-of-00016 + │ ├── wmt17_translate-train.tfrecord-00014-of-00016 + │ ├── wmt17_translate-train.tfrecord-00015-of-00016 + │ └── wmt17_translate-validation.tfrecord-00000-of-00001 + └── wmt_sentencepiece_model +``` + +In total, it should contain 43 files (via `find -type f | wc -l`) for a total of 3.3 GB (via `du -sch wmt/`). +
+ ### FastMRI Fill out form on . After filling out the form @@ -118,6 +202,29 @@ python3 datasets/dataset_setup.py \ --fastmri_knee_singlecoil_test_url '' ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── fastmri +│ ├── knee_singlecoil_test +│ │ ├── file1000022.h5 +│ │ ├── [...] +│ │ └── file1002571.h5 +│ ├── knee_singlecoil_train +│ │ ├── file1000001.h5 +│ │ ├── [...] +│ │ └── file1002569.h5 +│ └── knee_singlecoil_val +│ ├── file1000000.h5 +│ ├── [...] +│ └── file1002570.h5 +``` + +In total, it should contain 1280 files (via `find -type f | wc -l`) for a total of 112 GB (via `du -sch fastmri/`). +
+ ### ImageNet Register on and follow directions to obtain the @@ -149,6 +256,73 @@ dangerous if the user injects code into the `--data_dir` or `--temp_dir` flags. do some basic sanitization in `main()`, but submitters should not let untrusted users run this script on their systems. +
+The final directory structure should look like this for ImageNet2012 (PyTorch): + +```bash +$DATA_DIR +├── imagenet +│ ├── train +│ ├── n01440764 +│ ├── n01440764_10026.JPEG +│ ├── n01440764_10027.JPEG +│ ├── n01440764_10029.JPEG +│ ├── [...] +│ ├── [...] +│ └── val +│ ├── n01440764 +│ ├── ILSVRC2012_val_00000293.JPEG +│ ├── ILSVRC2012_val_00002138.JPEG +│ ├── [...] +│ ├── [...] +``` + +In total, it should contain 1,281,167 `train` files and 50,000 `val` (via `find -type f | wc -l`) for a total of 177 GB and 7.8 GB, respectively (via `du -sch train/` and `du -sch val/`). +
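Since the PyTorch layout above follows the usual class-per-subdirectory convention, one optional way to confirm the extraction is to point `torchvision`'s `ImageFolder` at the `train` folder. The path assumes `$DATA_DIR=~/data` and the layout shown above; scanning all 1.28M files can take a few minutes.

```python
# Optional check of the extracted PyTorch ImageNet layout.
import os
from torchvision.datasets import ImageFolder

train_dir = os.path.expanduser('~/data/imagenet/train')  # $DATA_DIR/imagenet/train
train_set = ImageFolder(train_dir)
print(f'{len(train_set)} images in {len(train_set.classes)} classes')
```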
+ +**TODO** +
+The final directory structure should look like this for ImageNet2012 (JAX): + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch imagenet/`). +
+ +
+The final directory structure should look like this for ImageNet v2: + +```bash +$DATA_DIR +├── imagenet_v2 +│ └── matched-frequency +│ └── 3.0.0 +│ ├── dataset_info.json +│ ├── features.json +│ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ ├── imagenet_v2-test.tfrecord-00002-of-00016 +│ ├── imagenet_v2-test.tfrecord-00003-of-00016 +│ ├── imagenet_v2-test.tfrecord-00004-of-00016 +│ ├── imagenet_v2-test.tfrecord-00005-of-00016 +│ ├── imagenet_v2-test.tfrecord-00006-of-00016 +│ ├── imagenet_v2-test.tfrecord-00007-of-00016 +│ ├── imagenet_v2-test.tfrecord-00008-of-00016 +│ ├── imagenet_v2-test.tfrecord-00009-of-00016 +│ ├── imagenet_v2-test.tfrecord-00010-of-00016 +│ ├── imagenet_v2-test.tfrecord-00011-of-00016 +│ ├── imagenet_v2-test.tfrecord-00012-of-00016 +│ ├── imagenet_v2-test.tfrecord-00013-of-00016 +│ ├── imagenet_v2-test.tfrecord-00014-of-00016 +│ ├── imagenet_v2-test.tfrecord-00015-of-00016 +│ └── label.labels.txt +``` + +In total, it should contain 20 files (via `find -type f | wc -l`) for a total of 1.2 GB (via `du -sch imagenet_v2/`). +
+ ### Criteo1TB ```bash @@ -158,6 +332,17 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` +**TODO** +
+The final directory structure should look like this: + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch criteo1tb/`). +
+ ### LibriSpeech To download, train a tokenizer and preprocess the librispeech dataset: @@ -169,6 +354,17 @@ python3 datasets/dataset_setup.py \ --librispeech ``` +**TODO** +
+The final directory structure should look like this: + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch librispeech/`). +
+ #### Training SPM Tokenizer A simple sentence piece tokenizer is trained over librispeech training From 801b9f1f3f525566515961d358ef5616d01abd2a Mon Sep 17 00:00:00 2001 From: andres-fr Date: Wed, 13 Dec 2023 16:33:26 +0100 Subject: [PATCH 14/86] updated Singularity docs in GETTING_STARTED.md with PR 553 --- GETTING_STARTED.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index d9f2a7051..b13f9f00c 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -156,29 +156,29 @@ To use the Docker container as an interactive virtual environment, you can run a ### Using Singularity/Apptainer instead of Docker -Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the here provided Dockerfile. - -To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli): +Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide an Apptainer recipe (located at `docker/Singularity.def`) that can be used to build an image by running ```bash -pip3 install spython -cd algorithmic-efficiency/docker -spython recipe Dockerfile &> Singularity.def +singularity build --fakeroot .sif Singularity.def ``` -Now we can build the Apptainer image by running - +Note that this can take several minutes. Then, to start a shell session with GPU support (by using the `--nv` flag), we can run ```bash -singularity build --fakeroot .sif Singularity.def +singularity shell --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \ + --nv .sif ``` -To start a shell session with GPU support (by using the `--nv` flag), we can run +Note the `--bind` flag which, similarly to Docker, allows to bind specific paths on the host system and the container, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). + +Also note that we generated `Singularity.def` automatically from the `Dockerfile` using [spython](https://github.com/singularityhub/singularity-cli), as follows: ```bash -singularity shell --nv .sif +pip3 install spython +cd algorithmic-efficiency/docker +python scripts/singularity_converter.py -i Dockerfile -o Singularity.def ``` -Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). +Users that wish to customize their images are invited to check and modify the `Singularity.def` recipe and the `singularity_converter.py` script. 
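The `singularity_converter.py` script referenced above is not part of this hunk. A minimal converter that wraps the `spython recipe` command could look roughly like the sketch below; the `-i`/`-o` flags mirror the documented invocation, and any extra post-processing of the generated recipe is an assumption.

```python
# Minimal sketch of a Dockerfile-to-Apptainer converter wrapping spython's
# `recipe` command; the real script may post-process the output further.
import argparse
import subprocess


def convert(dockerfile_path, output_path):
  # `spython recipe <Dockerfile>` prints a Singularity definition to stdout.
  result = subprocess.run(['spython', 'recipe', dockerfile_path],
                          check=True, capture_output=True, text=True)
  with open(output_path, 'w') as f:
    f.write(result.stdout)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('-i', '--input', default='Dockerfile')
  parser.add_argument('-o', '--output', default='Singularity.def')
  args = parser.parse_args()
  convert(args.input, args.output)
```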
## Download the Data From 176460c09d469439d994c585f2c8ec733651ca54 Mon Sep 17 00:00:00 2001 From: andres-fr Date: Thu, 14 Dec 2023 00:47:06 +0100 Subject: [PATCH 15/86] added PyTorch install TLDR to readme, which was removed by accident --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index dba91eefc..3e888a0a9 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,14 @@ pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax pip3 install -e '.[full]' ``` +*TL;DR to install the PyTorch version for GPU run:* + +```bash +pip3 install -e '.[jax_cpu]' +pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html' +pip3 install -e '.[full]' +``` + ## Getting Started For detailed instructions on developing and scoring your own algorithm in the benchmark see the [Getting Started](/GETTING_STARTED.md) document. From d3bd4eabe83b598ac45fa8bcc473745ebbd4c917 Mon Sep 17 00:00:00 2001 From: Frank Date: Fri, 15 Dec 2023 23:25:41 +0100 Subject: [PATCH 16/86] Highlight next deadline (#603) --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 3e888a0a9..941344903 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ > *AlgoPerf* is a suite of benchmarks and competitions to measure neural network training speedups due to algorithmic improvements in both training algorithms and models. This is the repository for the *AlgoPerf: Training Algorithms benchmark* and its associated competition. It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/). This repository holds the [**competition rules**](/COMPETITION_RULES.md), the [**technical documentation**](/DOCUMENTATION.md) of the benchmark, [**getting started guides**](/GETTING_STARTED.md), and the benchmark code. For a detailed description of the benchmark design, see our [**paper**](https://arxiv.org/abs/2306.07179). +--- + +> [!IMPORTANT] +> Upcoming Deadline: +> Registration deadline to express non-binding intent to submit: **January 28th, 2024** + ## Table of Contents - [Installation](#installation) From 1fdd724cb91e9355336452266ffd3e9619b1d840 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:09 +0100 Subject: [PATCH 17/86] Add missing directory structures --- datasets/README.md | 68 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 7 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 4f7b6b880..ce2a6390e 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -280,19 +280,38 @@ $DATA_DIR In total, it should contain 1,281,167 `train` files and 50,000 `val` (via `find -type f | wc -l`) for a total of 177 GB and 7.8 GB, respectively (via `du -sch train/` and `du -sch val/`). -**TODO**
The final directory structure should look like this for ImageNet2012 (JAX): ```bash $DATA_DIR +├──imagenet +│ ├── jax +│ │ ├── downloads +│ │ │ ├── extracted +│ │ │ └── manual_ +│ │ ├── imagenet2012 +│ │ │ └── 5.1.0 +│ │ │ ├── dataset_info.json +│ │ │ ├── features.json +│ │ │ ├── imagenet2012-train.tfrecord-00000-of-01024 +│ │ │ ├── imagenet2012-train.tfrecord-00001-of-01024 +│ │ │ ├── [...] +│ │ └── imagenet_v2 +│ │ └── matched-frequency +│ │ └── 3.0.0 +│ │ ├── dataset_info.json +│ │ ├── features.json +│ │ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ │ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ │ ├── [...] ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch imagenet/`). +In total, it should contain 1,111 files (via `find -type f | wc -l`) for a total of 145 GB (via `du -sch imagenet/jax`).
-The final directory structure should look like this for ImageNet v2: +The final directory structure should look like this for ImageNet v2 (separate): ```bash $DATA_DIR @@ -332,15 +351,20 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` -**TODO**
The final directory structure should look like this: ```bash $DATA_DIR +├── criteo1tb +│ ├── day_0_000.csv +│ ├── day_0_001.csv +│ ├── day_0_002.csv +│ ├── day_0_003.csv +│ ├── [...] ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch criteo1tb/`). +In total, it should contain 885 files (via `find -type f | wc -l`) for a total of 1.1 TB (via `du -sch criteo1tb/`).
### LibriSpeech @@ -354,15 +378,45 @@ python3 datasets/dataset_setup.py \ --librispeech ``` -**TODO**
The final directory structure should look like this: ```bash $DATA_DIR +├──librispeech +│ ├── dev-clean +│ │ ├── 1272-128104-0000_audio.npy +│ │ ├── 1272-128104-0000_targets.npy +│ │ ├── [...] +│ ├── dev-clean.csv +│ ├── dev-other +│ │ ├── 116-288045-0000_audio.npy +│ │ ├── 116-288045-0000_targets.npy +│ │ ├── [...] +│ ├── dev-other.csv +│ ├── spm_model.vocab +│ ├── test-clean +│ │ ├── 1089-134686-0000_audio.npy +│ │ ├── 1089-134686-0000_targets.npy +│ │ ├── [...] +│ ├── test-clean.csv +│ ├── train-clean-100 +│ │ ├── 103-1240-0000_audio.npy +│ │ ├── 103-1240-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-100.csv +│ ├── train-clean-360 +│ │ ├── 100-121669-0000_audio.npy +│ │ ├── 100-121669-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-360.csv +│ │ ├── 985-126228-0050_audio.npy +│ │ └── 985-126228-0050_targets.npy +│ │ ├── [...] +│ └── train-other-500.csv ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch librispeech/`). +In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`).
#### Training SPM Tokenizer From b3b0785458ddb0ed38c25945a744530930637f06 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:21 +0100 Subject: [PATCH 18/86] Add download and disk sizes --- datasets/dataset_setup.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f765e4a1a..9140ed18a 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -26,17 +26,19 @@ Criteo 1TB download size: ~350GB Criteo 1TB final disk size: ~1TB -FastMRI download size: -FastMRI final disk size: -LibriSpeech download size: -LibriSpeech final disk size: -OGBG download size: -OGBG final disk size: -WMT download size: (1.58 GiB + ) = -WMT final disk size: +FastMRI download size: ~90GB +FastMRI final disk size: ~110GB +ImageNet download size: ~150GB +ImageNet final disk size: ~150GB +LibriSpeech download size: ~60GB +LibriSpeech final disk size: ~350GB +OGBG download size: ~37MB +OGBG final disk size: ~800MB +WMT download size: ~3GB +WMT final disk size: ~3GB _______________________ -Total download size: -Total disk size: +Total download size: ~650GB +Total disk size: ~1.1TB Some datasets require signing a form before downloading: @@ -49,8 +51,8 @@ Register on https://image-net.org/ and run this script with the links to the ILSVRC2012 train and validation images. -Note for tfds ImageNet, you may have to increase the max number of files allowed -open at once using `ulimit -n 8192`. +Note for tfds ImageNet, you may have to increase the max number of files +allowed open at once using `ulimit -n 8192`. Example command: From e1fd0f1e22c72f8d099cb52f3647ce04a9606f89 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:35 +0100 Subject: [PATCH 19/86] Remove unused import --- datasets/dataset_setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 9140ed18a..2ddbf4438 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -79,7 +79,6 @@ import functools import os -import resource import shutil import subprocess import tarfile From efa518572531ecbe958158794b256063f6557939 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 12:46:32 +0100 Subject: [PATCH 20/86] fix fastmri dir structure and simplify --- datasets/dataset_setup.py | 46 +++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2ddbf4438..7638cd1b6 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -399,31 +399,35 @@ def extract(source, dest, mode='r:xz'): tar.close() -def setup_fastmri(data_dir, src_data_dir): - data_dir = os.path.join(data_dir, 'fastmri') - - train_tar_file_path = os.path.join(src_data_dir, FASTMRI_TRAIN_TAR_FILENAME) - val_tar_file_path = os.path.join(src_data_dir, FASTMRI_VAL_TAR_FILENAME) - test_tar_file_path = os.path.join(src_data_dir, FASTMRI_TEST_TAR_FILENAME) - - # Make train, val and test subdirectories - train_data_dir = os.path.join(data_dir, 'train') - os.makedirs(train_data_dir, exist_ok=True) - val_data_dir = os.path.join(data_dir, 'val') - os.makedirs(val_data_dir, exist_ok=True) - test_data_dir = os.path.join(data_dir, 'test') - os.makedirs(test_data_dir, exist_ok=True) +def setup_fastmri(data_dir): + train_tar_file_path = os.path.join(data_dir, FASTMRI_TRAIN_TAR_FILENAME) + val_tar_file_path = os.path.join(data_dir, 
FASTMRI_VAL_TAR_FILENAME) + test_tar_file_path = os.path.join(data_dir, FASTMRI_TEST_TAR_FILENAME) # Unzip tar file into subdirectories - logging.info('Unzipping {} to {}'.format(train_tar_file_path, train_data_dir)) - extract(train_tar_file_path, train_data_dir) - logging.info('Unzipping {} to {}'.format(val_tar_file_path, val_data_dir)) - extract(val_tar_file_path, val_data_dir) - logging.info('Unzipping {} to {}'.format(test_tar_file_path, test_data_dir)) - extract(test_tar_file_path, test_data_dir) - logging.info('Set up fastMRI dataset complete') + logging.info('Unzipping {} to {}'.format(train_tar_file_path, data_dir)) + extract(train_tar_file_path, data_dir) + logging.info('Unzipping {} to {}'.format(val_tar_file_path, data_dir)) + extract(val_tar_file_path, data_dir) + logging.info('Unzipping {} to {}'.format(test_tar_file_path, data_dir)) + extract(test_tar_file_path, data_dir) logging.info('Extraction completed!') + # Rename folders to match what the workload expects + os.rename( + os.path.join(data_dir, "singlecoil_train"), + os.path.join(data_dir, "knee_singlecoil_train"), + ) + os.rename( + os.path.join(data_dir, "singlecoil_val"), + os.path.join(data_dir, "knee_singlecoil_val"), + ) + os.rename( + os.path.join(data_dir, "singlecoil_test"), + os.path.join(data_dir, "knee_singlecoil_test"), + ) + logging.info("Set up fastMRI dataset complete") + def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): """Downloads and returns the download dir.""" From c4d473393f08ff5cbd878989f7cbdb3537af5352 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 12:50:47 +0100 Subject: [PATCH 21/86] Move pydub to librispeech dependency --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a00da91fc..20139d4c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,7 +67,6 @@ full = %(ogbg)s %(librispeech_conformer)s %(wmt)s - pydub==0.25.1 # All workloads plus development dependencies full_dev = @@ -98,6 +97,7 @@ ogbg = librispeech_conformer = sentencepiece==0.1.99 tensorflow-text==2.12.1 + pydub==0.25.1 wmt = sentencepiece==0.1.99 From 51e83409237168910661d875ea97b7812482ce79 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 12:50:58 +0100 Subject: [PATCH 22/86] Note the requirement of pigz and ffmpeg --- datasets/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datasets/README.md b/datasets/README.md index ce2a6390e..685f56eed 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -351,6 +351,8 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` +Note, that this requries the [`pigz` library](https://zlib.net/pigz/) to be installed. +
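Since several setup steps shell out to external tools (pigz here for Criteo, and ffmpeg for LibriSpeech as noted further below), an optional pre-flight check can catch a missing dependency before a long download starts. The tool list below is only inferred from these notes.

```python
# Optional pre-flight check for external tools used by some dataset steps.
import shutil

for tool in ('pigz', 'ffmpeg'):
  if shutil.which(tool) is None:
    print(f'WARNING: {tool} not found on PATH; the corresponding setup step '
          'may fail.')
```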
The final directory structure should look like this: @@ -378,6 +380,8 @@ python3 datasets/dataset_setup.py \ --librispeech ``` +Note, that this requries the [`ffmpeg` toolbox](https://ffmpeg.org/) to be installed. +
The final directory structure should look like this: From aad8ec4b795e41f40216c0aca1ea7cb4c0815eb8 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 21:33:57 +0000 Subject: [PATCH 23/86] fix criteo datasetting split --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f9ee2f138..755b4a93e 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -324,7 +324,7 @@ def download_criteo1tb(data_dir, unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') unzipped_paths.append(unzipped_path) split_path = os.path.join(criteo_dir, f'day_{day}_') - split_cmd = ('split -a 3 -d -l 5000000 --additional-suffix=.csv ' + split_cmd = ('split -a 2 -d -l 5000000 ' f'"{unzipped_path}" "{split_path}"') logging.info(f'Running Criteo 1TB split command:\n{split_cmd}') batch_processes.append(subprocess.Popen(split_cmd, shell=True)) From 1d6330c792fec0501bc62c24a23917cb553a78d1 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 11:36:53 +0100 Subject: [PATCH 24/86] Specified the librispeech structure --- datasets/README.md | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 685f56eed..5f2ce7504 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -388,36 +388,39 @@ Note, that this requries the [`ffmpeg` toolbox](https://ffmpeg.org/) to be insta ```bash $DATA_DIR ├──librispeech +│ ├── dev-clean.csv +│ ├── dev-other.csv +│ ├── spm_model.vocab +│ ├── test-clean.csv +│ ├── train-clean-100.csv +│ ├── train-clean-360.csv +│ ├── train-clean-500.csv │ ├── dev-clean │ │ ├── 1272-128104-0000_audio.npy │ │ ├── 1272-128104-0000_targets.npy +│ │ ├── 1272-128104-0001_audio.npy +│ │ ├── 1272-128104-0001_targets.npy │ │ ├── [...] -│ ├── dev-clean.csv │ ├── dev-other │ │ ├── 116-288045-0000_audio.npy │ │ ├── 116-288045-0000_targets.npy │ │ ├── [...] -│ ├── dev-other.csv -│ ├── spm_model.vocab │ ├── test-clean │ │ ├── 1089-134686-0000_audio.npy │ │ ├── 1089-134686-0000_targets.npy │ │ ├── [...] -│ ├── test-clean.csv │ ├── train-clean-100 │ │ ├── 103-1240-0000_audio.npy │ │ ├── 103-1240-0000_targets.npy │ │ ├── [...] -│ ├── train-clean-100.csv │ ├── train-clean-360 │ │ ├── 100-121669-0000_audio.npy │ │ ├── 100-121669-0000_targets.npy │ │ ├── [...] -│ ├── train-clean-360.csv -│ │ ├── 985-126228-0050_audio.npy -│ │ └── 985-126228-0050_targets.npy +│ ├── train-other-500 +│ │ ├── 1006-135212-0000_audio.npy +│ │ ├── 1006-135212-0000_targets.npy │ │ ├── [...] -│ └── train-other-500.csv ``` In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`). 
From 0d911e0aeaabe558fb0df5d9eb2cf47b311d62a0 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 11:41:17 +0100 Subject: [PATCH 25/86] Do not process `test-other` split --- datasets/dataset_setup.py | 2 ++ datasets/librispeech_preprocess.py | 6 ++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 7638cd1b6..f52a9808e 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -601,6 +601,8 @@ def download_librispeech(dataset_dir, tmp_dir): for split in ['dev', 'test']: for version in ['clean', 'other']: + if split == 'test' and version == 'other': + continue wget_cmd = ( f'wget --directory-prefix={tmp_librispeech_dir} ' f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz') diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index acdaa8e98..a8c5cae1d 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -31,8 +31,7 @@ 'train-clean-100': 28539, 'train-clean-360': 104014, 'train-other-500': 148688, - 'test-clean': 2620, - 'test-other': 2939, + 'test-clean': 2620, # 'test-other': 2939, 'dev-clean': 2703, 'dev-other': 2864, } @@ -153,8 +152,7 @@ def run(input_dir, output_dir, tokenizer_vocab_path): 'train-other-500', 'dev-clean', 'dev-other', - 'test-clean', - 'test-other', + 'test-clean', # 'test-other', ] for subset in subset_list: logging.info('Processing split = %s...', subset) From 6cf89e4e598f156da46c3cb401119582e8cb7d13 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 11:55:38 +0100 Subject: [PATCH 26/86] Store tokenizer in the right directory --- datasets/dataset_setup.py | 11 +++++++---- datasets/librispeech_tokenizer.py | 11 +++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f52a9808e..ab9f31db5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -588,13 +588,13 @@ def download_imagenet_v2(data_dir): data_dir=data_dir).download_and_prepare() -def download_librispeech(dataset_dir, tmp_dir): +def download_librispeech(data_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. 
tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech') extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') - final_data_dir = os.path.join(dataset_dir, 'librispeech') + final_data_dir = os.path.join(data_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) _maybe_mkdir(final_data_dir) @@ -627,10 +627,13 @@ def download_librispeech(dataset_dir, tmp_dir): f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True).communicate() - tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab') + tokenizer_vocab_path = os.path.join(data_dir, 'spm_model.vocab') if not os.path.exists(tokenizer_vocab_path): - librispeech_tokenizer.run(train=True, data_dir=extracted_data_dir) + librispeech_tokenizer.run( + train=True, + input_dir=extracted_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path) librispeech_preprocess.run( input_dir=extracted_data_dir, diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py index e701d59d4..2f559752a 100644 --- a/datasets/librispeech_tokenizer.py +++ b/datasets/librispeech_tokenizer.py @@ -108,17 +108,16 @@ def load_tokenizer(model_filepath): return sp_tokenizer -def run(train, data_dir): - logging.info('Data dir: %s', data_dir) - vocab_path = os.path.join(data_dir, 'spm_model.vocab') - logging.info('vocab_path = ', vocab_path) +def run(train, input_dir, tokenizer_vocab_path): + logging.info('Data dir: %s', input_dir) + logging.info('vocab_path = %s', tokenizer_vocab_path) if train: logging.info('Training...') splits = ['train-clean-100'] - train_tokenizer(data_dir, splits, model_path=vocab_path) + train_tokenizer(input_dir, splits, model_path=tokenizer_vocab_path) else: - tokenizer = load_tokenizer(vocab_path) + tokenizer = load_tokenizer(tokenizer_vocab_path) test_input = 'OPEN SOURCE ROCKS' tokens = tokenizer.tokenize(test_input) detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8') From 25739febbabba6e203243bb3678fd0d870aba4ab Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 12:24:32 +0100 Subject: [PATCH 27/86] fix tokenizer folder --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index ab9f31db5..ad373dd43 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -627,7 +627,7 @@ def download_librispeech(data_dir, tmp_dir): f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True).communicate() - tokenizer_vocab_path = os.path.join(data_dir, 'spm_model.vocab') + tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab') if not os.path.exists(tokenizer_vocab_path): librispeech_tokenizer.run( From 87e26720da18980227d30f21f280dd4c8e1de357 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 12:56:35 +0100 Subject: [PATCH 28/86] Fix the final directory structure of Criteo --- datasets/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 5f2ce7504..37480f4f8 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -359,10 +359,10 @@ Note, that this requries the [`pigz` library](https://zlib.net/pigz/) to be inst ```bash $DATA_DIR ├── criteo1tb -│ ├── day_0_000.csv -│ ├── day_0_001.csv -│ ├── day_0_002.csv -│ ├── day_0_003.csv +│ ├── day_0_00 +│ ├── day_0_01 +│ ├── day_0_02 +│ ├── day_0_03 │ ├── [...] 
``` @@ -428,8 +428,8 @@ In total, it should contain 543,323 files (via `find -type f | wc -l`) for a tot #### Training SPM Tokenizer - A simple sentence piece tokenizer is trained over librispeech training - data. This tokenizer is then used in later preprocessing step to tokenize transcripts. +During the above commands, a simple sentence piece tokenizer is trained over librispeech training data. +This tokenizer is then used in later preprocessing step to tokenize transcripts. This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`: ```bash From 7fb8124b4486274d81d500d6e2f64905034f1e2f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 12:55:01 +0000 Subject: [PATCH 29/86] add model and workload variant code for ogbg --- .../workloads/ogbg/ogbg_jax/models.py | 24 +++++++++--- .../workloads/ogbg/ogbg_jax/workload.py | 38 ++++++++++++++++++- .../workloads/ogbg/ogbg_pytorch/models.py | 26 +++++++++---- .../workloads/ogbg/ogbg_pytorch/workload.py | 36 +++++++++++++++++- .../workloads/ogbg/workload.py | 19 +++++++++- 5 files changed, 128 insertions(+), 15 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py index 358415587..0e66d2ab8 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py @@ -15,7 +15,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout): +def _make_mlp(hidden_dims, dropout, activation_fn): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -24,7 +24,7 @@ def make_fn(inputs): for dim in hidden_dims: x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) - x = nn.relu(x) + x = activation_fn(x) x = dropout(x) return x @@ -42,6 +42,7 @@ class GNN(nn.Module): # If None, defaults to 0.1. 
dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 + activation_fn_name: str = 'relu' @nn.compact def __call__(self, graph, train): @@ -59,11 +60,24 @@ def __call__(self, graph, train): embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) graph = embedder(graph) + if self.activation_fn_name == 'relu': + activation_fn = nn.relu + elif self.activation_fn_name == 'gelu': + activation_fn = nn.gelu + elif self.activation_fn_name == 'silu': + activation_fn = nn.silu + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout)) + update_edge_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_node_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_global_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) graph = net(graph) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 009aab91a..809148631 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -25,7 +25,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate rng, params_rng, dropout_rng = jax.random.split(rng, 3) - self._model = models.GNN(self._num_outputs, dropout_rate=dropout_rate) + self._model = models.GNN( + self._num_outputs, + dropout_rate=dropout_rate, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), @@ -115,3 +121,33 @@ def _normalize_eval_metrics( del num_examples total_metrics = total_metrics.reduce() return {k: float(v) for k, v in total_metrics.compute().items()} + + +class OgbgGeluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'gelu' + + +class OgbgSiluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" + return 'silu' + +class OgbgModelSizeWorkload(OgbgWorkload): + + @property + def hidden_dims(self) -> Tuple[int]: + return (256, 256) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 5 \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 1b392753b..04c503179 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -10,7 +10,7 @@ from algorithmic_efficiency import init_utils -def _make_mlp(in_dim, hidden_dims, dropout_rate): +def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() for dim in hidden_dims: @@ -33,7 +33,8 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1) -> None: + dropout_rate: Optional[float] = 0.1, + activation_fn_name: str = 'relu') -> None: super().__init__() self.num_outputs = num_outputs if dropout_rate is None: @@ -42,6 +43,16 @@ def __init__(self, self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) + if activation_fn_name == 'relu': + activation_fn = nn.ReLU + elif activation_fn_name == 'gelu': + activation_fn = nn.GeLU + elif activation_fn_name == 'silu': + activation_fn = nn.Silu + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + graph_network_layers = [] for st in range(self.num_message_passing_steps): # Constants in in_dims are based on the requirements of the GraphNetwork. 
@@ -54,11 +65,12 @@ def __init__(self, graph_network_layers.append( GraphNetwork( - update_edge_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate), - update_node_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate), - update_global_fn=_make_mlp(last_in_dim, - self.hidden_dims, - dropout_rate))) + update_edge_fn=_make_mlp( + in_dim, self.hidden_dims, dropout_rate, activation_fn), + update_node_fn=_make_mlp( + in_dim, self.hidden_dims, dropout_rate, activation_fn), + update_global_fn=_make_mlp( + last_in_dim, self.hidden_dims, dropout_rate, activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) self.decoder = nn.Linear( diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index a1fbf2e8a..b2224bdec 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -144,7 +144,11 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate torch.random.manual_seed(rng[0]) - model = GNN(num_outputs=self._num_outputs, dropout_rate=dropout_rate) + model = GNN(num_outputs=self._num_outputs, + dropout_rate=dropout_rate, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -235,3 +239,33 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples return {k: float(v) for k, v in total_metrics.compute().items()} + + +class OgbgGeluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'gelu' + + +class OgbgSiluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'silu' + +class OgbgModelSizeWorkload(OgbgWorkload): + + @property + def hidden_dims(self) -> Tuple[int]: + return (256, 256) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 5 diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index 7ca6ebc1e..8f3e8c122 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -3,7 +3,7 @@ import abc import itertools import math -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import jax @@ -22,6 +22,23 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'mean_average_precision' + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" + return 'relu' + + @property + def hidden_dims(self) -> Tuple[int]: + return (256,) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 5 + def has_reached_validation_target(self, eval_result: float) -> bool: return eval_result[ 'validation/mean_average_precision'] > self.validation_target_value From d6048500a701b27ddaf3d98b5a257f1b552e2326 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 13:02:40 +0000 Subject: [PATCH 30/86] add ogbg workload variant definitions to registry --- .../workloads/ogbg/ogbg_jax/workload.py | 3 ++- algorithmic_efficiency/workloads/workloads.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 809148631..e77194643 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -150,4 +150,5 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 \ No newline at end of file + return 5 + \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6d0b08cef..09ddfabfd 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -96,6 +96,15 @@ 'ogbg': { 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' }, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgModelSizeWorkload' + }, 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, 'wmt_post_ln': { 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' From fdcb5aa098cbc1d0e1584c28946360506125d027 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 13:04:36 +0000 Subject: [PATCH 31/86] add ogbg variants to docker startup.sh --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index be14ab498..53ba3f6ba 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -119,7 +119,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ - "fastmri_layernorm") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") # Set data and experiment paths ROOT_DATA_BUCKET="gs://mlcommons-data" From 94a942050108af7c5dabb515c1b4ad5a65c4a8a2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 14:35:25 +0000 Subject: [PATCH 32/86] activation fn --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 04c503179..4a7d96c13 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -1,5 +1,6 @@ # Ported to 
PyTorch from # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from functools import partial from typing import Callable, Optional, Tuple import jax.tree_util as tree @@ -46,7 +47,7 @@ def __init__(self, if activation_fn_name == 'relu': activation_fn = nn.ReLU elif activation_fn_name == 'gelu': - activation_fn = nn.GeLU + activation_fn = partial(nn.GeLU, approximate='tanh') elif activation_fn_name == 'silu': activation_fn = nn.Silu else: From 2c8a3e1cf8b897a6ee66638fd826139f523fa4f5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 14:43:36 +0000 Subject: [PATCH 33/86] add tests --- tests/modeldiffs/ogbg_gelu/__init__.py | 0 tests/modeldiffs/ogbg_gelu/compare.py | 113 +++++++++++++++++++ tests/modeldiffs/ogbg_model_size/__init__.py | 0 tests/modeldiffs/ogbg_model_size/compare.py | 113 +++++++++++++++++++ tests/modeldiffs/ogbg_silu/__init__.py | 0 tests/modeldiffs/ogbg_silu/compare.py | 113 +++++++++++++++++++ 6 files changed, 339 insertions(+) create mode 100644 tests/modeldiffs/ogbg_gelu/__init__.py create mode 100644 tests/modeldiffs/ogbg_gelu/compare.py create mode 100644 tests/modeldiffs/ogbg_model_size/__init__.py create mode 100644 tests/modeldiffs/ogbg_model_size/compare.py create mode 100644 tests/modeldiffs/ogbg_silu/__init__.py create mode 100644 tests/modeldiffs/ogbg_silu/compare.py diff --git a/tests/modeldiffs/ogbg_gelu/__init__.py b/tests/modeldiffs/ogbg_gelu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py new file mode 100644 index 000000000..f6175e99d --- /dev/null +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -0,0 +1,113 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgGeluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgGeluWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_model_size/__init__.py b/tests/modeldiffs/ogbg_model_size/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py new file mode 100644 index 000000000..3818598ed --- /dev/null +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -0,0 +1,113 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgModelSizeWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgModelSizeWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_silu/__init__.py b/tests/modeldiffs/ogbg_silu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py new file mode 100644 index 000000000..420ee9020 --- /dev/null +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -0,0 +1,113 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgSiluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgSiluWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. 
+ graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) From 609545f65246b99a290eb7622440082dcdfe0420 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:06:38 +0000 Subject: [PATCH 34/86] pytorch model ogbg fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 4a7d96c13..f616dac6e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -28,15 +28,18 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ - latent_dim: int = 256 - hidden_dims: Tuple[int] = (256,) - num_message_passing_steps: int = 5 def __init__(self, num_outputs: int = 128, dropout_rate: Optional[float] = 0.1, - activation_fn_name: str = 'relu') -> None: + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5) -> None: super().__init__() + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs if dropout_rate is None: dropout_rate = 0.1 From 54d6ea6939c6cf0e13b87f096cd72c6229343196 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:20:53 +0000 Subject: [PATCH 35/86] ogbg fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index f616dac6e..0ae2c901a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -17,7 +17,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): for dim in hidden_dims: layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('relu', nn.ReLU()) + layers.add_module('activation_fn', activation_fn) layers.add_module('dropout', nn.Dropout(dropout_rate)) return layers From 20b790ebfaaffffef4a4ab4bb6348c07822c105d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:39:49 +0000 Subject: [PATCH 36/86] fix ogbg --- algorithmic_efficiency/workloads/ogbg/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index 8f3e8c122..ade91b35d 100644 --- 
a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -33,7 +33,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From 362819ea5583abec690524e606f63712e633093d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:42:05 +0000 Subject: [PATCH 37/86] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 0ae2c901a..978b62428 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -17,7 +17,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): for dim in hidden_dims: layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('activation_fn', activation_fn) + layers.add_module('activation_fn', activation_fn()) layers.add_module('dropout', nn.Dropout(dropout_rate)) return layers From 4d63b9697e9b67f27604dcf6e8e0b11530f40a09 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:58:05 +0000 Subject: [PATCH 38/86] fix ogbg variant --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 2 +- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index e77194643..65121ac7b 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -150,5 +150,5 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 + return 3 \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index b2224bdec..c6e57b0f2 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -268,4 +268,4 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 + return 3 From 58b0edf24829ac2c8e98a2139704b27b5402d01f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 16:14:10 +0000 Subject: [PATCH 39/86] ogbg debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index c6e57b0f2..102ef7b7c 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -264,7 +264,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From e0d7dbfdf28895eb1115a6d02fd97f0833f3703a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 16:22:32 +0000 Subject: [PATCH 40/86] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 102ef7b7c..c6e57b0f2 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -264,7 +264,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 256 + return 128 @property def num_message_passing_steps(self) -> int: From 710db241e6b751164c4eede35cc8eee947734673 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 17:52:30 +0000 Subject: [PATCH 41/86] debugging --- .../workloads/ogbg/ogbg_jax/workload.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 65121ac7b..cb5dda800 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -43,6 +43,14 @@ def init_model_fn( receivers=jnp.asarray([0])) params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) params = params['params'] + tabulate_fn = nn.tabulate( + self._model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, 'force_jupyter': False, 'width': 240 + }, + ) + print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(params), None From c2e1b8b6d959b7020b28cc99671fc948e48a21a0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 17:53:44 +0000 Subject: [PATCH 42/86] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 1 + algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 1 + 2 files changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index cb5dda800..e4ee57fb7 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple from flax import jax_utils +import flax.linen as nn import jax import jax.numpy as jnp import jraph diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index c6e57b0f2..aa0e7ae5e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -157,6 +157,7 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) + print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 84e4606b9e2cf22c20b3a19233d306075de14e42 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 18:05:59 +0000 Subject: [PATCH 43/86] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index aa0e7ae5e..9c852b38d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -265,7 +265,7 @@ def 
hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From 57f1d7cb75478fbb32e9bbc178ad3c473ec27f2c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 18:16:58 +0000 Subject: [PATCH 44/86] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 9c852b38d..aa0e7ae5e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -265,7 +265,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 256 + return 128 @property def num_message_passing_steps(self) -> int: From 503e4f07c37b9c0981481547d14a62c22a05aa2a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 17:56:35 +0000 Subject: [PATCH 45/86] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 978b62428..6c104d59e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -163,7 +163,6 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape [num_edges, global_feat]. global_edge_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) - if self.update_edge_fn: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], @@ -180,6 +179,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape [num_nodes, global_feat]. 
global_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + print('SHAPES') + print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From d6716df3bff341dd0c6c23cd60b4a0264e837613 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 17:59:38 +0000 Subject: [PATCH 46/86] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 6c104d59e..343fe9265 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -181,6 +181,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) print('SHAPES') print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) + print(senders.shape) + print(receivers.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From 997e0e9984789f155427e64ffa7ac0edd32b99ac Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 21:42:15 +0000 Subject: [PATCH 47/86] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 343fe9265..326ba3c06 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -183,6 +183,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) print(senders.shape) print(receivers.shape) + print(sum_n_node) + print(edges.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From f9cabc5e205ba208adb7d675d75b39c15449895d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:14:37 +0000 Subject: [PATCH 48/86] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 326ba3c06..7ec2f142a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -59,20 +59,23 @@ def __init__(self, graph_network_layers = [] for st in range(self.num_message_passing_steps): - # Constants in in_dims are based on the requirements of the GraphNetwork. + # Constants in in_dims are based on forward call of GraphNetwork: + # specifically update_edge_fn update_node_fn and update_global_fn. 
if st == 0: - in_dim = self.latent_dim * 3 + self.num_outputs + in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outs last_in_dim = self.latent_dim * 2 + self.num_outputs else: - in_dim = self.hidden_dims[-1] * 4 + in_dim_edge_fn = self.hidden_dims[-1] * 4 + in_dim_node_fn = self.hidden_dims[-1] * 4 last_in_dim = self.hidden_dims[-1] * 3 graph_network_layers.append( GraphNetwork( update_edge_fn=_make_mlp( - in_dim, self.hidden_dims, dropout_rate, activation_fn), + in_dim_edge_fn, self.hidden_dims, dropout_rate, activation_fn), update_node_fn=_make_mlp( - in_dim, self.hidden_dims, dropout_rate, activation_fn), + in_dim_node_fn, self.hidden_dims, dropout_rate, activation_fn), update_global_fn=_make_mlp( last_in_dim, self.hidden_dims, dropout_rate, activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) From 176f0517dbe8e991f797541d69a3e27da38e4f9d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:15:55 +0000 Subject: [PATCH 49/86] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 7ec2f142a..5d8aab46d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -63,7 +63,7 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outs + in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs last_in_dim = self.latent_dim * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 From 0e454ab2083a5c21081fe53756836f06f657bed4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:45:20 +0000 Subject: [PATCH 50/86] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 5d8aab46d..9a3b4190c 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -64,7 +64,7 @@ def __init__(self, if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs - last_in_dim = self.latent_dim * 2 + self.num_outputs + last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 in_dim_node_fn = self.hidden_dims[-1] * 4 From ea17ae6b46b739788c1facd72eec448f73710d4c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:48:24 +0000 Subject: [PATCH 51/86] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 9a3b4190c..52cb8e053 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -182,12 +182,6 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape 
[num_nodes, global_feat]. global_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) - print('SHAPES') - print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) - print(senders.shape) - print(receivers.shape) - print(sum_n_node) - print(edges.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From f6e1cb7989a68502d19a6cb192e907e5d069c901 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 00:07:51 +0000 Subject: [PATCH 52/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index f091d3d4f..1c552899b 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -20,6 +20,7 @@ def key_transform(k): new_key = [] bn = False ln = False + print("Key transform input ", k) for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i @@ -42,6 +43,7 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) From 72573f4eea6a81062dd975b8eb83bb440979d7ce Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 3 Jan 2024 20:34:31 -0500 Subject: [PATCH 53/86] Fix names --- .../workloads/imagenet_vit/imagenet_jax/workload.py | 4 ++-- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 1acd58bcd..4b12247c2 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -95,14 +95,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 013bc643f..645b795ca 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -89,14 +89,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: From c02493b6c9eab977c445336f64afa72337f512a7 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:44:01 +0000 Subject: [PATCH 54/86] test fix --- tests/modeldiffs/ogbg/compare.py | 36 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 1c552899b..11badf91c 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -12,31 +12,43 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgWorkload 
as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgWorkload as PytWorkload + OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff +MLP_HIDDEN_DIMS = len(PyTorchWorkload.hidden_dims) def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" print("Key transform input ", k) + graph_network_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = i.split('_')[1] continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_network_index = i.split('_')[1] continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = i.split('_')[1] + if graph_network: + count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif 'LayerNorm' in i: + layer_index = i.split('_')[1] + count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') From 8cf4714adcf1f7b51ebf1ba768ce3297482c160e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:50:19 +0000 Subject: [PATCH 55/86] fix --- tests/modeldiffs/ogbg/compare.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 11badf91c..37d7094d2 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -15,9 +15,10 @@ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff -MLP_HIDDEN_DIMS = len(PyTorchWorkload.hidden_dims) -def key_transform(k): +hidden_dims = JaxWorkload().hidden_dims + +def key_transform(k, hidden_dims): new_key = [] bn = False ln = False @@ -39,7 +40,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = i.split('_')[1] if graph_network: - count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -47,7 +48,7 @@ def key_transform(k): i = 'edge_embedding' elif 'LayerNorm' in i: layer_index = i.split('_')[1] - count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From d984e19f0c94a20a1e81117113b18863e14855c0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:52:01 +0000 Subject: [PATCH 56/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 37d7094d2..5a07b8d91 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -85,7 +85,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - 
pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), From d7543f084ffc21e8a509c8e9e57da2570cf0d116 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:53:14 +0000 Subject: [PATCH 57/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 5a07b8d91..f3356028e 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -18,7 +18,7 @@ hidden_dims = JaxWorkload().hidden_dims -def key_transform(k, hidden_dims): +def key_transform(k): new_key = [] bn = False ln = False From d0407027c654bf4ec3daa9360ec3ac8cdfa5f304 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:55:29 +0000 Subject: [PATCH 58/86] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index f3356028e..36500c88b 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -25,7 +25,7 @@ def key_transform(k): graph_network = False "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" print("Key transform input ", k) - graph_network_index = 0 + graph_index = 0 seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i @@ -35,7 +35,7 @@ def key_transform(k): seq_index = i.split('_')[1] continue elif 'GraphNetwork' in i: - graph_network_index = i.split('_')[1] + graph_index = i.split('_')[1] continue elif 'Linear' in i: layer_index = i.split('_')[1] From cd15acf23fe544c047a4cd4c73a0cdf203a363d4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:56:32 +0000 Subject: [PATCH 59/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 36500c88b..b760196d1 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -16,7 +16,7 @@ from tests.modeldiffs.diff import out_diff -hidden_dims = JaxWorkload().hidden_dims +hidden_dims = len(JaxWorkload().hidden_dims) def key_transform(k): new_key = [] From e4f5e0849f6cea1e917a84780046b72e0a0dbe8c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:58:21 +0000 Subject: [PATCH 60/86] fix --- tests/modeldiffs/ogbg/compare.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index b760196d1..3fa4132bf 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -32,13 +32,13 @@ def key_transform(k): ln = ln or 'LayerNorm' in i graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: - seq_index = i.split('_')[1] + seq_index = int(i.split('_')[1]) continue elif 'GraphNetwork' in i: - graph_index = i.split('_')[1] + graph_index = int(i.split('_')[1]) continue elif 'Linear' in i: - layer_index = i.split('_')[1] + layer_index = int(i.split('_')[1]) if graph_network: count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) From 9376504f58ac7afcd8e482580e7231dfe776e3ff Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:59:30 +0000 Subject: [PATCH 61/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py 
b/tests/modeldiffs/ogbg/compare.py index 3fa4132bf..e7adf25de 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -47,7 +47,7 @@ def key_transform(k): elif layer_index == 1: i = 'edge_embedding' elif 'LayerNorm' in i: - layer_index = i.split('_')[1] + layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: From e2d88101711c42b785e21acba155c682689bfd99 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:24:10 +0000 Subject: [PATCH 62/86] debugging --- tests/modeldiffs/torch2jax_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index d9264b400..c1d4ad48a 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,6 +77,10 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): + print('pytorch sd') + print(pytorch_sd.keys()) + print('jax sd') + print(jax_sd.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 6de5f2f1205651c2b64345e133a30de50e374a3e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:26:37 +0000 Subject: [PATCH 63/86] fix --- tests/modeldiffs/torch2jax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index c1d4ad48a..333cda758 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -78,9 +78,9 @@ def key_transform(self, k_transform_fn): def value_transform(self, v_transform_fn): print('pytorch sd') - print(pytorch_sd.keys()) + print(self.pytorch_sd.keys()) print('jax sd') - print(jax_sd.key()) + print(self.jax_sd.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 0bb7a6decba1f3d2df0dc3379dbdf64fda49e147 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:28:17 +0000 Subject: [PATCH 64/86] fix --- tests/modeldiffs/torch2jax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 333cda758..a1d6503dc 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -80,7 +80,7 @@ def value_transform(self, v_transform_fn): print('pytorch sd') print(self.pytorch_sd.keys()) print('jax sd') - print(self.jax_sd.key()) + print(self.flattened_jax_model.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 4821cdfa8afc4fd7d991724a06d5460eafd239f3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:31:34 +0000 Subject: [PATCH 65/86] fix --- tests/modeldiffs/torch2jax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index a1d6503dc..560a071d6 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -80,7 +80,7 @@ def value_transform(self, v_transform_fn): print('pytorch sd') print(self.pytorch_sd.keys()) print('jax sd') - print(self.flattened_jax_model.key()) + print(self.flattened_jax_model.keys()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in 
self.pytorch_sd From 5f19a5afd0649513dd347d5140b422b4daa943ce Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:49:39 +0000 Subject: [PATCH 66/86] fix --- tests/modeldiffs/ogbg/compare.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index e7adf25de..c9d58f658 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -17,6 +17,7 @@ hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps def key_transform(k): new_key = [] @@ -46,6 +47,9 @@ def key_transform(k): i = 'node_embedding' elif layer_index == 1: i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index @@ -64,20 +68,6 @@ def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - return out From 36b65a34cb6338bc09344db724ca2174d8d22ef1 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:51:30 +0000 Subject: [PATCH 67/86] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index c9d58f658..95f3e7df4 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -49,7 +49,7 @@ def key_transform(k): i = 'edge_embedding' elif layer_index == 2: count = num_graphs * 3 * hidden_dims - i = + i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index From 50d479e0902ad1be2208a5415316a289d749dd41 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:55:45 +0000 Subject: [PATCH 68/86] fix --- tests/modeldiffs/ogbg/compare.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 95f3e7df4..7537362ff 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -66,8 +66,9 @@ def key_transform(k): def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} + for k in sd: + out[k] = sd[k] return out From 5403f5bb76c97f413f7acd05dbad533f4149b1d9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:13:47 +0000 Subject: [PATCH 69/86] fix tests --- tests/modeldiffs/ogbg_gelu/compare.py | 62 +++++++++++---------- tests/modeldiffs/ogbg_model_size/compare.py | 61 +++++++++++--------- tests/modeldiffs/ogbg_silu/compare.py | 62 +++++++++++---------- 3 files changed, 102 insertions(+), 83 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index f6175e99d..f58c58bde 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -16,53 +16,59 @@ 
from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -110,4 +116,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None) \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 3818598ed..4df4d67aa 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -16,53 +16,59 @@ from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 
'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -111,3 +117,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) + \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 420ee9020..5fa9eabc9 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -16,53 +16,59 @@ from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * 
hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -110,4 +116,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None) \ No newline at end of file From 20b8cfce440f1fc3dc743b7ed7b43114f9a7907b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:14:51 +0000 Subject: [PATCH 70/86] fix --- tests/modeldiffs/ogbg_gelu/compare.py | 2 +- tests/modeldiffs/ogbg_model_size/compare.py | 3 +-- tests/modeldiffs/ogbg_silu/compare.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index f58c58bde..027e772d5 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgGeluWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgGeluWorkload as PytWorkload + OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4df4d67aa..4734a0d0d 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgModelSizeWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgModelSizeWorkload as PytWorkload + OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -117,4 +117,3 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) - \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 5fa9eabc9..52eee4aa8 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgSiluWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgSiluWorkload as PytWorkload + OgbgSiluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff From 
528d44f992e970d8285a3d75a9d910b3b186fc51 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:28:07 +0000 Subject: [PATCH 71/86] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 7537362ff..03f7451dc 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index + 1 i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index + 1 i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 3046a6c12dc09f205a3c58971e6a7d15a131b7d8 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:32:33 +0000 Subject: [PATCH 72/86] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 03f7451dc..d22499636 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + 1 + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + 1 + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 44cb147b3ca11e6591ee2d16d7d296722d72b66a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:34:33 +0000 Subject: [PATCH 73/86] fix --- tests/modeldiffs/ogbg_gelu/compare.py | 4 ++-- tests/modeldiffs/ogbg_model_size/compare.py | 4 ++-- tests/modeldiffs/ogbg_silu/compare.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 027e772d5..b3158d6c4 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4734a0d0d..b7799411c 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -41,7 
+41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 52eee4aa8..675eb4215 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 303bf1ae6b7e096876b552ea5a200e3ed8f42f40 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:46:19 +0000 Subject: [PATCH 74/86] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 52cb8e053..e6015196a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -14,11 +14,11 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() - for dim in hidden_dims: - layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) - layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('activation_fn', activation_fn()) - layers.add_module('dropout', nn.Dropout(dropout_rate)) + for i, dim in enumerate(hidden_dims): + layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) + layers.add_module(f'activation_fn_{i}', activation_fn()) + layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) return layers From 8fa2b44b1c78a5880daac52f0460cb2a9356dfa6 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:49:17 +0000 Subject: [PATCH 75/86] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- tests/modeldiffs/ogbg_gelu/compare.py | 4 ++-- tests/modeldiffs/ogbg_model_size/compare.py | 4 ++-- tests/modeldiffs/ogbg_silu/compare.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index d22499636..7537362ff 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count 
= graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index b3158d6c4..027e772d5 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index b7799411c..4734a0d0d 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 675eb4215..52eee4aa8 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From c3650de0610b5294ae901ca7bf8dffd7289a3539 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:54:17 +0000 Subject: [PATCH 76/86] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index e6015196a..31a025732 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ 
b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -19,6 +19,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + in_dim = dim return layers From 759fc179d35f0166f3ba5caf2e26ee714b46e53b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:41:00 +0000 Subject: [PATCH 77/86] ogbg variant fix --- .../workloads/ogbg/ogbg_pytorch/workload.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index aa0e7ae5e..cd4b3e0a0 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -148,7 +148,8 @@ def init_model_fn( dropout_rate=dropout_rate, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -257,6 +258,7 @@ def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'silu' + class OgbgModelSizeWorkload(OgbgWorkload): @property From 2b758c4a69485f24647cfcdb2e55aace6cb26918 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:45:50 +0000 Subject: [PATCH 78/86] fix ogbg activation fn pytorch models --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 31a025732..458ceff48 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -51,9 +51,9 @@ def __init__(self, if activation_fn_name == 'relu': activation_fn = nn.ReLU elif activation_fn_name == 'gelu': - activation_fn = partial(nn.GeLU, approximate='tanh') + activation_fn = partial(nn.GELU, approximate='tanh') elif activation_fn_name == 'silu': - activation_fn = nn.Silu + activation_fn = nn.SiLU else: raise ValueError( f'Invalid activation function name: {self.activation_fn_name}') From 16d6ae46121c5f241f1be08ca77ba7775a1b746c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:56:31 +0000 Subject: [PATCH 79/86] clean up debugging statements --- .../workloads/ogbg/ogbg_jax/workload.py | 8 -------- .../workloads/ogbg/ogbg_pytorch/workload.py | 1 - tests/modeldiffs/ogbg/compare.py | 5 ++--- tests/modeldiffs/ogbg_gelu/compare.py | 5 ++--- tests/modeldiffs/ogbg_model_size/compare.py | 5 ++--- tests/modeldiffs/ogbg_silu/compare.py | 8 ++++---- tests/modeldiffs/torch2jax_utils.py | 4 ---- 7 files changed, 10 insertions(+), 26 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index e4ee57fb7..0ff7f158a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -44,14 +44,6 @@ def init_model_fn( receivers=jnp.asarray([0])) 
params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) params = params['params'] - tabulate_fn = nn.tabulate( - self._model, - jax.random.PRNGKey(0), - console_kwargs={ - 'force_terminal': False, 'force_jupyter': False, 'width': 240 - }, - ) - print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(params), None diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index cd4b3e0a0..513d6a269 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -158,7 +158,6 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) - print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 7537362ff..18980d9f4 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 027e772d5..0c7b1e0d4 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4734a0d0d..022e05b94 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git 
a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 52eee4aa8..feb141057 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) @@ -116,4 +115,5 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) \ No newline at end of file + out_transform=None) + \ No newline at end of file diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 560a071d6..d9264b400 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,10 +77,6 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): - print('pytorch sd') - print(self.pytorch_sd.keys()) - print('jax sd') - print(self.flattened_jax_model.keys()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 15fd5b113a9676410943fb5f20143577c4d84564 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:57:52 +0000 Subject: [PATCH 80/86] formatting --- .../workloads/ogbg/ogbg_jax/workload.py | 2 +- .../workloads/ogbg/ogbg_pytorch/models.py | 24 ++++++++++++------- .../workloads/ogbg/ogbg_pytorch/workload.py | 15 ++++++------ .../workloads/ogbg/workload.py | 2 +- algorithmic_efficiency/workloads/workloads.py | 3 ++- tests/modeldiffs/ogbg/compare.py | 5 ++-- tests/modeldiffs/ogbg_gelu/compare.py | 7 +++--- tests/modeldiffs/ogbg_model_size/compare.py | 5 ++-- tests/modeldiffs/ogbg_silu/compare.py | 6 ++--- 9 files changed, 37 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 0ff7f158a..7201a2d90 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -139,6 +139,7 @@ def activation_fn_name(self) -> str: """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" return 'silu' + class OgbgModelSizeWorkload(OgbgWorkload): @property @@ -152,4 +153,3 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: return 3 - \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 458ceff48..d93013b87 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -15,7 +15,8 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'dense_{i}', + nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) @@ -64,7 +65,8 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[ + -1] * 2 + self.num_outputs last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 @@ -73,12 +75,18 @@ def __init__(self, graph_network_layers.append( GraphNetwork( - update_edge_fn=_make_mlp( - in_dim_edge_fn, self.hidden_dims, dropout_rate, activation_fn), - update_node_fn=_make_mlp( - in_dim_node_fn, self.hidden_dims, dropout_rate, activation_fn), - update_global_fn=_make_mlp( - last_in_dim, self.hidden_dims, dropout_rate, activation_fn))) + update_edge_fn=_make_mlp(in_dim_edge_fn, + self.hidden_dims, + dropout_rate, + activation_fn), + update_node_fn=_make_mlp(in_dim_node_fn, + self.hidden_dims, + dropout_rate, + activation_fn), + update_global_fn=_make_mlp(last_in_dim, + self.hidden_dims, + dropout_rate, + activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) self.decoder = nn.Linear( diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 513d6a269..beb518e0f 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -144,12 +144,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate torch.random.manual_seed(rng[0]) - model = GNN(num_outputs=self._num_outputs, - dropout_rate=dropout_rate, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps, - activation_fn_name=self.activation_fn_name) + model = GNN( + num_outputs=self._num_outputs, + dropout_rate=dropout_rate, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -259,7 +260,7 @@ def activation_fn_name(self) -> str: class OgbgModelSizeWorkload(OgbgWorkload): - + @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git 
a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index ade91b35d..a32f385cb 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -22,7 +22,7 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'mean_average_precision' - @property + @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'relu' diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 09ddfabfd..a9cbec1e8 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -103,7 +103,8 @@ 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' }, 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgModelSizeWorkload' + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload' }, 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, 'wmt_post_ln': { diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 18980d9f4..53a500085 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -15,10 +15,9 @@ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 0c7b1e0d4..964032da7 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -15,10 +15,9 @@ OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue @@ -115,4 +114,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) \ No newline at end of file + out_transform=None) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 022e05b94..b90e3d8a8 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -15,10 +15,9 @@ OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = 
JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index feb141057..10bc79f57 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -15,10 +15,9 @@ OgbgSiluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue @@ -116,4 +115,3 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) - \ No newline at end of file From a40e0851531d3736c02cba36e69c9cc7d6e2777f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:11:22 +0000 Subject: [PATCH 81/86] remove unused import --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 7201a2d90..9fc24552d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -3,7 +3,6 @@ from typing import Any, Dict, Optional, Tuple from flax import jax_utils -import flax.linen as nn import jax import jax.numpy as jnp import jraph From ea5bd5ef37aebcd56a499f14b307dadc317b2746 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:18:12 +0000 Subject: [PATCH 82/86] formatting --- tests/modeldiffs/ogbg/compare.py | 9 ++++++--- tests/modeldiffs/ogbg_gelu/compare.py | 9 ++++++--- tests/modeldiffs/ogbg_model_size/compare.py | 7 +++++-- tests/modeldiffs/ogbg_silu/compare.py | 9 ++++++--- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 53a500085..40b92ce4f 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py 
b/tests/modeldiffs/ogbg_gelu/compare.py index 964032da7..4c87366f3 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index b90e3d8a8..11a74b26a 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -40,7 +40,9 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 10bc79f57..bb47cb4be 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 59156c700294383484352d1d22a65fa1d7b83e8d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:44:17 +0000 Subject: [PATCH 83/86] lint fix --- tests/modeldiffs/ogbg/compare.py | 2 +- tests/modeldiffs/ogbg_gelu/compare.py | 2 +- tests/modeldiffs/ogbg_model_size/compare.py | 2 +- tests/modeldiffs/ogbg_silu/compare.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 40b92ce4f..56316ba12 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: 
seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 4c87366f3..b58bcd461 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 11a74b26a..f32d53171 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue elif 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index bb47cb4be..2922b7046 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: From 28b8dfb1e5f419c37b62e7281ce2b17060a6b358 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:52:53 +0000 Subject: [PATCH 84/86] lint fix --- tests/modeldiffs/ogbg_model_size/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index f32d53171..62443bbb5 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -37,7 +37,7 @@ def key_transform(k): if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: count = ( From c92075e8c02e105d6a04456badd95ad0da1e10de Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 5 Jan 2024 11:34:29 -0800 Subject: [PATCH 85/86] Update README.md Fix typo for self_tuning commands. 
--- .../prize_qualification_baselines/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md index 8276887da..100555964 100644 --- a/reference_algorithms/prize_qualification_baselines/README.md +++ b/reference_algorithms/prize_qualification_baselines/README.md @@ -50,8 +50,8 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc The prize qualification baseline submissionss for jax are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py` Example command: @@ -62,7 +62,7 @@ python3 submission_runner.py \ --experiment_dir= \ --experiment_name= \ --workload= \ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \ --tuning_ruleset=self ``` @@ -70,8 +70,8 @@ python3 submission_runner.py \ The prize qualification baseline submissionss for PyTorch are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py` Example command: @@ -82,6 +82,6 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir= \ --experiment_name=t \ --workload=\ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \ --tuning_ruleset=self ``` From 45402699b8bc4f9f9673569fbf9f882b7e7cdd97 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 11 Jan 2024 10:52:53 +0100 Subject: [PATCH 86/86] Fix typo (size of LibriSpeech) --- datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/README.md b/datasets/README.md index 37480f4f8..c68a5cc6b 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -423,7 +423,7 @@ $DATA_DIR │ │ ├── [...] ``` -In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`). +In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 388 GB (via `du -sch librispeech/`).
#### Training SPM Tokenizer
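Most of the ogbg patches above ([PATCH 57/86] through [PATCH 84/86]) revolve around one piece of arithmetic: the modeldiffs tests flatten the PyTorch `GNN` module tree into key tuples such as `('Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight')` and must rename each entry to the matching Flax/Linen parameter (`Dense_N`, `LayerNorm_N`, `node_embedding`, ...). The sketch below restates only that index mapping outside the test harness, as a reading aid; it is not part of the patch series. The function name `map_linear` and the concrete values of `hidden_dims` and `num_graphs` are placeholders chosen for this example, whereas the real `compare.py` files derive them from `JaxWorkload().hidden_dims` and `JaxWorkload().num_message_passing_steps`.

```python
# Standalone sketch of the Dense-index arithmetic used by the final
# tests/modeldiffs/ogbg*/compare.py key_transform (illustrative values only).
hidden_dims = 3  # stands in for len(JaxWorkload().hidden_dims)
num_graphs = 5   # stands in for JaxWorkload().num_message_passing_steps


def map_linear(graph_index: int, seq_index: int, layer_index: int) -> str:
  """Returns the Flax name for a Linear layer inside a GraphNetwork step.

  Each message-passing step owns three Sequential MLPs (the edge, node and
  global update functions), and each MLP holds `hidden_dims` Linear layers,
  so the flat Flax counter advances by 3 * hidden_dims per step and by
  hidden_dims per Sequential within a step.
  """
  count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
  return f'Dense_{count}'


# Second Linear of the second MLP in the first message-passing step:
print(map_linear(graph_index=0, seq_index=1, layer_index=1))  # Dense_4
# Top-level Linear_0 / Linear_1 map to the node/edge embeddings, and Linear_2
# (the decoder) lands after every message-passing MLP:
print(f'Dense_{num_graphs * 3 * hidden_dims}')  # Dense_45
```

The same stride also explains why [PATCH 74/86] and [PATCH 76/86] touch `_make_mlp` in the PyTorch model: giving each `dense`/`norm`/`activation_fn`/`dropout` module an indexed name makes every Linear and LayerNorm of a multi-layer MLP appear as a distinct key for this mapping to count, and advancing `in_dim` between layers keeps the chained layer shapes consistent with the Jax model.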