From ce99901e2ea31528c19161da35798f09c424c8e5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:31:40 +0530 Subject: [PATCH 01/63] feat: package updates with python311 --- algorithmic_efficiency/random_utils.py | 8 +-- docker/Dockerfile | 29 ++++++++- setup.cfg | 86 +++++++++++++------------- 3 files changed, 75 insertions(+), 48 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index cf1ea6c32..31317047e 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,8 +18,8 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_UINT32 = 2**31 +MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..24d05b495 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ + && tar -xvzf Python-3.11.0.tgz \ + && cd Python-3.11.0 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz diff --git a/setup.cfg b/setup.cfg index 4afefd164..deeb1c6c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -34,22 +35,22 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==1.4.0 + absl-py==2.1.1 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. 
- networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 + networkx==3.2.1 + docker==7.1.0 + numpy>=1.26.4 + pandas==2.2.3 + tensorflow==2.18.0 + tensorflow-datasets==4.9.7 + tensorflow-addons==0.23.0 gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 + psutil==6.1.0 + clu==0.0.12 + matplotlib>=3.9.2 tabulate==0.9.0 -python_requires = >=3.8 + wandb==0.18.7 +python_requires = >=3.11 ############################################################################### @@ -79,78 +80,77 @@ full_dev = # Dependencies for developing the package dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 + isort==5.13.2 + pylint==3.3.1 + pytest==8.3.3 + yapf==0.43.0 + pre-commit==4.0.1 # Workloads # criteo1tb = - scikit-learn==1.2.2 + scikit-learn==1.5.2 fastmri = - h5py==3.8.0 - scikit-image==0.20.0 + h5py==3.12.1 + scikit-image==0.24.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.2.2 + scikit-learn==1.5.2 librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 + sacrebleu==2.4.3 # Frameworks # # JAX Core jax_core_deps = - flax==0.6.10 - optax==0.1.5 + flax==0.10.1 + optax==0.2.4 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 + chex==0.1.87 + ml_dtypes==0.4.1 + protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 + jax==0.4.35 + jaxlib==0.4.35 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 + jax==0.4.35 + jaxlib==0.4.35 + jax-cuda12-plugin[with_cuda]==0.4.35 + jax-cuda12-pjrt==0.4.35 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. 
pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 -# wandb -wandb = - wandb==0.16.5 ############################################################################### # Linting Configurations # From 21fb3f902d5744c8331be89f896c2376977f7f12 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:46:17 +0530 Subject: [PATCH 02/63] fix: absl package version change --- docker/Dockerfile | 12 +++++++----- setup.cfg | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 24d05b495..497ffb2c1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -28,9 +28,9 @@ RUN apt-get update && apt-get install -y \ # Download and install Python 3.11 RUN cd /tmp \ - && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ - && tar -xvzf Python-3.11.0.tgz \ - && cd Python-3.11.0 \ + && wget https://www.python.org/ftp/python/3.11.10/Python-3.11.10.tgz \ + && tar -xvzf Python-3.11.10.tgz \ + && cd Python-3.11.10 \ && ./configure --enable-optimizations \ && make -j$(nproc) \ && make altinstall @@ -55,11 +55,13 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ +RUN pip install --upgrade pip + # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="main" +ARG branch="python311" ARG framework="both" -ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git +ARG git_url=https://github.com/init-22/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch diff --git a/setup.cfg b/setup.cfg index deeb1c6c4..e952513df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==2.1.1 + absl-py==2.1.0 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 From 67b9f15108486a1a29b348031e1b50a82fa55b40 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:09:04 +0530 Subject: [PATCH 03/63] fix: pytorch version change --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index e952513df..a74faa197 100644 --- a/setup.cfg +++ b/setup.cfg @@ -141,15 +141,15 @@ jax_gpu = # PyTorch CPU pytorch_cpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. 
pytorch_gpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 ############################################################################### From 78df36f2f0f173ad651b81527cda8d55f85028b0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:42:26 +0530 Subject: [PATCH 04/63] fix: tf version to use numpy < 2 --- docker/Dockerfile | 2 -- setup.cfg | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 497ffb2c1..88fc55243 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -87,8 +87,6 @@ RUN if [ "$framework" = "jax" ] ; then \ RUN cd /algorithmic-efficiency && pip install -e '.[full]' -RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/setup.cfg b/setup.cfg index a74faa197..2a300469a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ install_requires = docker==7.1.0 numpy>=1.26.4 pandas==2.2.3 - tensorflow==2.18.0 + tensorflow==2.17.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -105,7 +105,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 sacrebleu==2.4.3 # Frameworks # From 2584416e8cc82bb61ef7a1d2a395a25da919f93f Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:09:40 +0530 Subject: [PATCH 05/63] fix: librispeech requirement of tf-text rolled back to v2.17 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2a300469a..078b694b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -100,7 +100,7 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 pydub==0.25.1 wmt = From d603ce921b211918ce0e3d27742032f5e7ece674 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:11:38 +0530 Subject: [PATCH 06/63] fix: using the main repo and branch for testing --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 88fc55243..ee9136cbf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -59,9 +59,9 @@ RUN pip install --upgrade pip # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="python311" +ARG branch="main" ARG framework="both" -ARG git_url=https://github.com/init-22/algorithmic-efficiency.git +ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch From be68f8cbf4a528804c78eff886ffd7e36e04fca8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 13:57:11 +0530 Subject: [PATCH 07/63] fix: overflow error resolved and PRNGKey to key --- algorithmic_efficiency/checkpoint_utils.py | 2 +- algorithmic_efficiency/random_utils.py | 10 +++++----- setup.cfg | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..04dad0eb7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: 
checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 31317047e..93dc263bd 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**31 +MAX_UINT32 = 2**32-1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) diff --git a/setup.cfg b/setup.cfg index 078b694b8..6e6a1c957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,9 +39,9 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=1.26.4 + numpy>=2.1.3 pandas==2.2.3 - tensorflow==2.17.0 + tensorflow==2.18.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -100,12 +100,12 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 sacrebleu==2.4.3 # Frameworks # From e890c893297a6e64cbfdc6d63f87ee7f7b4d385a Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 20 Nov 2024 19:13:50 +0530 Subject: [PATCH 08/63] fix: minor changes in docs --- GETTING_STARTED.md | 2 +- algorithmic_efficiency/logger_utils.py | 2 +- setup.cfg | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..aa493bc9f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -35,7 +35,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > -> - Python minimum requirement >= 3.8 +> - Python minimum requirement >= 3.11 > - CUDA 12.1 > - NVIDIA Driver version 535.104.05 diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 609d996e6..155e55356 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict: system_software_info['os_platform'] = \ platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' system_software_info['python_version'] = platform.python_version( - ) # Ex. '3.8.10' + ) # Ex. '3.11.10' system_software_info['python_compiler'] = platform.python_compiler( ) # Ex. 
'GCC 9.3.0' # Note: do not store hostname as that may be sensitive diff --git a/setup.cfg b/setup.cfg index 6e6a1c957..5023f1ba6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,9 +21,6 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: Apache Software License Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence From 1bc2a7b2d5de45309bbcab035bff587c9f19ef27 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 13:07:46 +0530 Subject: [PATCH 09/63] fix: changing the python versions in workflow to pass the tests --- .github/workflows/CI.yml | 48 +++++++++++++------------- .github/workflows/linting.yml | 12 +++---- .github/workflows/traindiffs_tests.yml | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..fe2441bfe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -25,10 +25,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -42,10 +42,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -59,10 +59,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -96,10 +96,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -113,10 +113,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. 
cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -130,10 +130,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -148,10 +148,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -166,10 +166,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -184,10 +184,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest @@ -208,10 +208,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..628fc012b 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install pylint run: | python -m pip install --upgrade pip @@ -27,10 +27,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install isort run: | python -m pip install --upgrade pip @@ -43,10 +43,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install yapf run: | python -m pip install --upgrade pip diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index 382f0dfe1..a2fdcb453 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'main' + - 'python311' jobs: build_and_push_docker_image: From 7a0fee3224e3d4e8602a2aca2819358bf97acf00 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 22:59:52 +0530 Subject: [PATCH 10/63] fix: changing numpy compatible version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5023f1ba6..0aa4dce49 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 
+36,7 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=2.1.3 + numpy>=2.0.2 pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 From 7cdea1638ceb2a3c0019e95c0a63f0c36605064a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 23:07:52 +0530 Subject: [PATCH 11/63] adding key_data to check the CI tests --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -210,7 +210,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -336,7 +336,7 @@ def train_once( step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, input_queue, From 7264c3f80d0bd38a1c50f107d715765a7c76dcdc Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 14:41:05 +0530 Subject: [PATCH 12/63] fix: updated packge of sacrebleu changed the way it used to work, hence using the corpus_bleu from the main package --- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 3 ++- setup.cfg | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..327ca34ad 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,6 +5,7 @@ from absl import logging import jax +import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -162,7 +163,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index 0aa4dce49..23e86a13b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,6 @@ wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 sacrebleu==2.4.3 - # Frameworks # # JAX Core From abbdc8262917fd8e38ba954f8cdaf478a5d8d1c7 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 16:11:01 +0530 Subject: [PATCH 13/63] fix: temporarily commenting tfa --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..d0bbecb8f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +#from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. 
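The PRNGKey-to-key switch in PATCH 07 and the key_data calls added in PATCH 11 move the code onto JAX's new-style typed PRNG keys. As a minimal sketch of the distinction (not taken from this repository; it only assumes jax >= 0.4.16, and these patches pin jax 0.4.35):

    import jax

    legacy_key = jax.random.PRNGKey(0)   # old style: a plain uint32[2] array
    typed_key = jax.random.key(0)        # new style: a scalar, typed key array

    # Typed keys work with the usual key operations:
    subkeys = jax.random.split(typed_key, 2)
    noise = jax.random.normal(typed_key, (4,))

    # Code that treats a key as a raw uint32 array (for example, deriving seeds
    # for NumPy or tf.data input pipelines) needs the underlying data back,
    # which is what jax.random.key_data provides:
    raw = jax.random.key_data(typed_key)  # uint32 array of shape (2,)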
From 86029a742094a653e5bf9a6f17f0d42c0990671d Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:10:24 +0530 Subject: [PATCH 14/63] fix: explicitly using mask kwarg to use MultiHeadDotProductAttention and also using sacrebleu --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 1 + algorithmic_efficiency/workloads/wmt/wmt_jax/models.py | 6 +++--- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index d0bbecb8f..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + #from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index e4b5cd014..7bbc0b168 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -224,7 +224,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * x, x, - encoder_mask) + mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +288,7 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -311,7 +311,7 @@ def __call__(self, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * y, encoded, - encoder_decoder_mask) + mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..442c85899 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,6 +12,7 @@ import jax.numpy as jnp import numpy as np import optax +import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -203,7 +204,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. 
- bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( From aca45a2b1e1df7e42a5108df8e30d49baf6ef6e2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:42:21 +0530 Subject: [PATCH 15/63] fix: using flax.core.pop instead of variables.pop, better way to update batch_stats --- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 7 ++++--- .../workloads/imagenet_vit/imagenet_jax/workload.py | 3 ++- .../librispeech_conformer/librispeech_jax/workload.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..8ab4adbb9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -79,8 +80,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() # Create a shallow copy + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -111,7 +112,7 @@ def init_model_fn( input_shape = (1, 224, 224, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..5f826d035 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -4,6 +4,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax import jax.numpy as jnp @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState, variables = jax.jit( model.init)({'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") return params, model_state def init_model_fn( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..d805e8b17 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils +from flax.core import pop import flax.linen as nn import jax from jax import lax @@ -89,7 +90,7 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state, params = 
variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -374,8 +375,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state From 2618c5e6b1dcbdf48c2625f4cfbdca93fdc53993 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:50:27 +0530 Subject: [PATCH 16/63] fix: changing the traindiffs_tests branch to main again --- .github/workflows/traindiffs_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index a2fdcb453..382f0dfe1 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'python311' + - 'main' jobs: build_and_push_docker_image: From 8c9062564c920e7fea8c3ee6abc8fce51d663c82 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 23:23:09 +0530 Subject: [PATCH 17/63] fix: unfreeze() in test_param_shapes expect FrozenDict also added flax.core.pop instead of variables.pop --- .../workloads/cifar/cifar_jax/workload.py | 7 ++++--- tests/test_param_shapes.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..6ec90b99a 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -75,8 +76,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -93,7 +94,7 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..4ad56c873 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -3,6 +3,7 @@ import jax import numpy as np import pytest +from flax.core import FrozenDict # isort: skip_file # pylint:disable=line-too-long @@ -51,8 +52,11 @@ def test_param_shapes(workload): jax_workload, pytorch_workload = get_workload(workload) # Compare number of parameter tensors of both models. 
+ jax_workload_param_shapes = jax_workload.param_shapes + if isinstance(jax_workload_param_shapes, dict): + jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload.param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) if workload == 'wmt': From 1b587b75890c39c3b3ebf5359b7f82b260e06bc6 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:30:41 +0530 Subject: [PATCH 18/63] fix: formatting changes with yapf --- algorithmic_efficiency/profiler.py | 4 +-- algorithmic_efficiency/random_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../fastmri/fastmri_pytorch/workload.py | 4 +-- .../imagenet_jax/randaugment.py | 8 ++---- .../imagenet_pytorch/workload.py | 4 +-- .../librispeech_jax/models.py | 10 +++---- .../librispeech_jax/spectrum_augmenter.py | 4 +-- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 9 +++--- .../librispeech_jax/models.py | 10 +++---- .../workloads/mnist/workload.py | 7 ++--- .../workloads/wmt/wmt_jax/models.py | 13 ++++----- .../workloads/wmt/wmt_pytorch/models.py | 4 +-- .../external_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../jax_nadamw_target_setting.py | 10 ++++--- .../self_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../self_tuning/jax_nadamw_target_setting.py | 10 ++++--- .../paper_baselines/nadamw/jax/submission.py | 10 ++++--- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 +++++++------------ .../target_setting_algorithms/jax_nadamw.py | 10 ++++--- submission_runner.py | 4 +-- tests/modeldiffs/wmt/compare.py | 6 ++-- .../modeldiffs/wmt_attention_temp/compare.py | 6 ++-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 ++-- tests/modeldiffs/wmt_post_ln/compare.py | 6 ++-- 27 files changed, 98 insertions(+), 109 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index fa2a1bee2..d73efd964 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) + for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..b5b30ce22 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32-1 +MAX_UINT32 = 2**32 - 1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6ec90b99a..dd4643a60 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -76,7 +76,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. 
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..a2f0828e3 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,9 +252,7 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..94c66033a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,8 +313,7 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, + tf.equal(step, 0), lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -549,7 +548,6 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..0ed944191 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,9 +309,7 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..db8cbc70a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + 
self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index d805e8b17..64e41989f 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -375,7 +375,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. 
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..31d069e88 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,8 +260,9 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + fin_result.shape[1], + device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( + -1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -329,9 +330,7 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..c2fe540a6 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..ad950b869 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index 
7bbc0b168..97fee032f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -222,9 +222,8 @@ def __call__(self, inputs, encoder_mask=None): use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * x, - x, - mask=encoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +287,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -309,9 +309,8 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * y, - encoded, - mask=encoder_decoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..089f1bfbb 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) + >= cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. 
seq_len = k.size(1) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..ad4d8e6f5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..bde851468 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..4122be181 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 
+149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..6b5faa6b8 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..ad4d8e6f5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..d33daadb8 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,8 +67,9 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. 
updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, + params, + updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -80,8 +81,7 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..722dab06b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,17 +1809,13 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1923,8 +1919,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors + >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. 
new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2442,9 +2438,7 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2453,9 +2447,7 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..fc866f80a 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,8 +108,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -124,8 +125,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/submission_runner.py b/submission_runner.py index 0024c35d4..a6bea1aa8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
- if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..8f9154f53 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..ff7103d43 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..d24d818a2 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..7d0556345 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From c65d93e5b4adfa6e493e6101048738afd8dc15d9 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:37:32 +0530 Subject: [PATCH 19/63] fix: running yapf again with 0.32, earlier using 0.43 --- algorithmic_efficiency/profiler.py | 4 ++-- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 +++- .../imagenet_resnet/imagenet_jax/randaugment.py | 8 +++++--- .../imagenet_resnet/imagenet_pytorch/workload.py | 4 +++- .../librispeech_conformer/librispeech_jax/models.py | 10 ++++++---- .../librispeech_jax/spectrum_augmenter.py | 4 ++-- .../librispeech_pytorch/workload.py | 9 +++++---- .../librispeech_deepspeech/librispeech_jax/models.py | 10 ++++++---- algorithmic_efficiency/workloads/mnist/workload.py | 7 +++++-- .../workloads/wmt/wmt_pytorch/models.py | 4 ++-- setup.cfg | 2 +- submission_runner.py | 4 ++-- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 16 files changed, 54 insertions(+), 40 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py 
index d73efd964..fa2a1bee2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) - for a, d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) for a, + d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index a2f0828e3..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,7 +252,9 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 94c66033a..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,7 +313,8 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. 
result = tf.cond( - tf.equal(step, 0), lambda: im, + tf.equal(step, 0), + lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -548,6 +549,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), lambda: image) + lambda selected_func=func, + selected_args=args: selected_func(image, *selected_args), + lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 0ed944191..3549911fa 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,7 +309,9 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db8cbc70a..ed05f4335 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,10 +442,12 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index c16740629..2a6f73d4d 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights - < multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights < + multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py 
b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 31d069e88..155b30920 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,9 +260,8 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], - device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( - -1, 1) + fin_result.shape[1], device=result.device).view( + 1, -1) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -330,7 +329,9 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index c2fe540a6..f9eb732e9 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,10 +273,12 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index ad950b869..dcc195170 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,7 +46,8 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'],}) + 'targets': x['label'], + }) is_train = split == 'train' if cache: @@ -213,6 +214,8 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 089f1bfbb..a1c7ce15e 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. 
if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) - >= cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/setup.cfg b/setup.cfg index 23e86a13b..e8044fe02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,7 @@ dev = isort==5.13.2 pylint==3.3.1 pytest==8.3.3 - yapf==0.43.0 + yapf==0.32.0 pre-commit==4.0.1 # Workloads # diff --git a/submission_runner.py b/submission_runner.py index a6bea1aa8..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8f9154f53..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index ff7103d43..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index d24d818a2..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 7d0556345..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From 3afd1dff5e6bf0780c5ff77e2e7daedba74928cb Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:39:03 +0530 Subject: [PATCH 20/63] fix: running yapf again with 0.32, earlier using 0.43 --- 
.../external_tuning/jax_nadamw_full_budget.py | 10 +++---- .../jax_nadamw_target_setting.py | 10 +++---- .../self_tuning/jax_nadamw_full_budget.py | 10 +++---- .../self_tuning/jax_nadamw_target_setting.py | 10 +++---- .../paper_baselines/nadamw/jax/submission.py | 10 +++---- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 ++++++++++++------- .../target_setting_algorithms/jax_nadamw.py | 10 +++---- 8 files changed, 46 insertions(+), 50 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ad4d8e6f5..98193f01f 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index bde851468..66fdc4ebb 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4122be181..4f53afb56 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = 
_update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 6b5faa6b8..60a1f784d 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ad4d8e6f5..98193f01f 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py 
b/reference_algorithms/paper_baselines/sam/jax/submission.py index d33daadb8..85b3d7441 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, - params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, + lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 722dab06b..725529cae 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), errors - >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), + errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. 
new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2438,7 +2442,9 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2447,7 +2453,9 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index fc866f80a..21f2a7b2b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,9 +108,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From 6ff2010d884e9d14911beab6dbce1a546a0a6213 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:21:03 +0530 Subject: [PATCH 21/63] fix: latest versions of typing dont support Text instead str is recommended --- algorithmic_efficiency/halton.py | 14 +++++++------- .../workloads/wmt/wmt_jax/workload.py | 2 +- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index 9eb30861d..d710e3fce 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> 
_GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,7 +244,7 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 442c85899..72108c9d9 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 327ca34ad..b554b2ab3 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload From 55bacbd493c425fda147bc59aa97341f73b1ef17 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:24:18 +0530 Subject: [PATCH 22/63] fix: minor yapf --- algorithmic_efficiency/halton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index d710e3fce..1f36b07bf 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -245,7 +245,7 @@ def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( name: str, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, name, From 5eac985fcefc7fa0f93c2e4f28e0d71ca6db7d3d Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 7 Dec 2024 21:07:21 +0530 Subject: [PATCH 23/63] fix: going back to sacrebleu v1.3.1 --- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 5 ++--- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- setup.cfg | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 72108c9d9..046d5e469 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,11 +12,10 @@ import jax.numpy as jnp import numpy as np 
import optax -import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -204,7 +203,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index b554b2ab3..0ba49c2f6 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,7 +5,6 @@ from absl import logging import jax -import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -16,7 +15,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -163,7 +162,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. 
- bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index e8044fe02..a7c224407 100644 --- a/setup.cfg +++ b/setup.cfg @@ -103,7 +103,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==2.4.3 + sacrebleu==1.3.1 # Frameworks # # JAX Core From 786771169b0f9bafe241692ac9411d30fccce62d Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:13:16 +0530 Subject: [PATCH 24/63] feat: custom tf_addons support in TF2.18 --- .../imagenet_jax/custom_tf_addons.py | 433 ++++++++++++++++++ .../imagenet_jax/randaugment.py | 16 +- 2 files changed, 441 insertions(+), 8 deletions(-) create mode 100644 algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py new file mode 100644 index 000000000..eda67d226 --- /dev/null +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -0,0 +1,433 @@ +""" +Note: +The following code is adapted from: +https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image + + +""" + +import math +from typing import Callable, List, Optional, Union + +import numpy as np +import tensorflow as tf + +_IMAGE_DTYPES = { + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, +} + +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +def to_4D_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D `Tensor`. + + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message="`image` must be 2/3/4D tensor") + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4D_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.concat( + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, + ) + return tf.reshape(image, new_shape) + + +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. + + Returns: + `ndims`-D `Tensor` with the same type. 
+ """ + with tf.control_dependencies( + [tf.debugging.assert_rank(image, 4, + message="`image` must be 4D tensor")]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4D_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) + + +def transform( + images: TensorLike, + transforms: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Applies the given transform(s) to the image(s). + + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or "transform"): + image_or_images = tf.convert_to_tensor(images, name="images") + transform_or_transforms = tf.convert_to_tensor( + transforms, name="transforms", dtype=tf.dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." 
% image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + if output_shape is None: + output_shape = tf.shape(images)[1:3] + + output_shape = tf.convert_to_tensor( + output_shape, tf.dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise ValueError("transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + transforms = transform_or_transforms + raise ValueError("transforms should have rank 1 or 2, but got rank %d" % + len(transforms.get_shape())) + + fill_value = tf.convert_to_tensor( + fill_value, dtype=tf.float32, name="fill_value") + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def angles_to_projective_transforms( + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, +) -> tf.Tensor: + """Returns projective transform(s) for the given angle(s). + + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or "angles_to_projective_transforms"): + angle_or_angles = tf.convert_to_tensor( + angles, name="angles", dtype=tf.dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise ValueError("angles should have rank 0 or 1.") + cos_angles = tf.math.cos(angles) + sin_angles = tf.math.sin(angles) + x_offset = ((image_width - 1) - + (cos_angles * (image_width - 1) - sin_angles * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (sin_angles * (image_width - 1) + cos_angles * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def rotate( + images: TensorLike, + angles: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Rotate image(s) counterclockwise by the passed angle(s) in radians. + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. 
Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "rotate"): + image_or_images = tf.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] + image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def translations_to_projective_transforms(translations: TensorLike, + name: Optional[str] = None + ) -> tf.Tensor: + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or "translations_to_projective_transforms"): + translation_or_translations = tf.convert_to_tensor( + translations, name="translations", dtype=tf.dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. 
+ return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +@tf.function +def translate( + images: TensorLike, + translations: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Translate image(s) by the passed vectors(s). + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..f3a946245 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,9 @@ import tensorflow as tf -#from tensorflow_addons import image as contrib_image +from .custom_tf_addons import rotate +from .custom_tf_addons import transform +from .custom_tf_addons import translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. @@ -177,19 +179,19 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. 
- image = contrib_image.rotate(wrap(image), radians) + image = rotate(wrap(image), radians) return unwrap(image, replace) def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = contrib_image.translate(wrap(image), [-pixels, 0]) + image = translate(wrap(image), [-pixels, 0]) return unwrap(image, replace) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = contrib_image.translate(wrap(image), [0, -pixels]) + image = translate(wrap(image), [0, -pixels]) return unwrap(image, replace) @@ -199,8 +201,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) return unwrap(image, replace) @@ -210,8 +211,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) return unwrap(image, replace) From d6dd2e8e16145e73f69664bc81690ac06857319b Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:50:11 +0530 Subject: [PATCH 25/63] fix: resolving pylint issues in custom_tf_addons --- .../imagenet_jax/custom_tf_addons.py | 27 +++++++++---------- .../imagenet_jax/randaugment.py | 4 +-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index eda67d226..79aef6791 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -6,8 +6,7 @@ """ -import math -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import numpy as np import tensorflow as tf @@ -48,7 +47,7 @@ def get_ndims(image): return image.get_shape().ndims or tf.rank(image) -def to_4D_image(image): +def to_4d_image(image): """Convert 2/3/4D image to 4D image. Args: @@ -63,7 +62,7 @@ def to_4D_image(image): ]): ndims = image.get_shape().ndims if ndims is None: - return _dynamic_to_4D_image(image) + return _dynamic_to_4d_image(image) elif ndims == 2: return image[None, :, :, None] elif ndims == 3: @@ -72,7 +71,7 @@ def to_4D_image(image): return image -def _dynamic_to_4D_image(image): +def _dynamic_to_4d_image(image): shape = tf.shape(image) original_rank = tf.rank(image) # 4D image => [N, H, W, C] or [N, C, H, W] @@ -91,7 +90,7 @@ def _dynamic_to_4D_image(image): return tf.reshape(image, new_shape) -def from_4D_image(image, ndims): +def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. 
Args: @@ -105,7 +104,7 @@ def from_4D_image(image, ndims): [tf.debugging.assert_rank(image, 4, message="`image` must be 4D tensor")]): if isinstance(ndims, tf.Tensor): - return _dynamic_from_4D_image(image, ndims) + return _dynamic_from_4d_image(image, ndims) elif ndims == 2: return tf.squeeze(image, [0, 3]) elif ndims == 3: @@ -114,7 +113,7 @@ def from_4D_image(image, ndims): return image -def _dynamic_from_4D_image(image, original_rank): +def _dynamic_from_4d_image(image, original_rank): shape = tf.shape(image) # 4D image <= [N, H, W, C] or [N, C, H, W] # 3D image <= [1, H, W, C] or [1, C, H, W] @@ -183,7 +182,7 @@ def transform( transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) if output_shape is None: @@ -217,7 +216,7 @@ def transform( fill_mode=fill_mode.upper(), fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def angles_to_projective_transforms( @@ -271,7 +270,7 @@ def angles_to_projective_transforms( ) -def rotate( +def rotate_img( images: TensorLike, angles: TensorLike, interpolation: str = "nearest", @@ -286,7 +285,7 @@ def rotate( `(num_images, num_rows, num_columns, num_channels)` (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + angles: A scalar angle to rotate all images by (if `images` has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "nearest", @@ -317,7 +316,7 @@ def rotate( image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] @@ -329,7 +328,7 @@ def rotate( fill_mode=fill_mode, fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def translations_to_projective_transforms(translations: TensorLike, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index f3a946245..dd00146cd 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,7 @@ import tensorflow as tf -from .custom_tf_addons import rotate +from .custom_tf_addons import rotate_img from .custom_tf_addons import transform from .custom_tf_addons import translate @@ -179,7 +179,7 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. 
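# Illustrative sketch, not part of the patch: the renamed to_4d_image /
# from_4d_image helpers only normalize rank so the transform op always sees an
# NHWC batch; for a statically known rank-2 (H, W) image the round trip is:
import tensorflow as tf

hw_image = tf.zeros([32, 32])              # rank-2 (H, W) input
nhwc = hw_image[None, :, :, None]          # what to_4d_image does for rank 2
restored = tf.squeeze(nhwc, [0, 3])        # what from_4d_image(nhwc, 2) does
assert restored.shape == hw_image.shape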
- image = rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace) From a0b587aed0ccecb794a46e2ba99713c56ed69f93 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 22:04:59 +0530 Subject: [PATCH 26/63] resolved pyline and changed the pylint version to current version of main --- .../imagenet_jax/custom_tf_addons.py | 20 ++++++++++++------- setup.cfg | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 79aef6791..3d6939218 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -241,12 +241,15 @@ def angles_to_projective_transforms( with tf.name_scope(name or "angles_to_projective_transforms"): angle_or_angles = tf.convert_to_tensor( angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles else: - raise ValueError("angles should have rank 0 or 1.") + angles = angle_or_angles + cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) x_offset = ((image_width - 1) - @@ -352,12 +355,15 @@ def translations_to_projective_transforms(translations: TensorLike, if translation_or_translations.get_shape().ndims is None: raise TypeError( "translation_or_translations rank must be statically known") - elif len(translation_or_translations.get_shape()) == 1: + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] - elif len(translation_or_translations.get_shape()) == 2: - translations = translation_or_translations else: - raise TypeError("Translations should have rank 1 or 2.") + translations = translation_or_translations + num_translations = tf.shape(translations)[0] # The translation matrix looks like: # [[1 0 -dx] diff --git a/setup.cfg b/setup.cfg index a7c224407..7977267bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,7 +78,7 @@ full_dev = # Dependencies for developing the package dev = isort==5.13.2 - pylint==3.3.1 + pylint==2.16.1 pytest==8.3.3 yapf==0.32.0 pre-commit==4.0.1 From 9393145ba91b9432c1732f5bd9d8865c2cb232f8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 18 Dec 2024 20:58:42 +0530 Subject: [PATCH 27/63] fix: removing tensorflow addons from setup cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7977267bd..2d246b48b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,6 @@ install_requires = pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 - tensorflow-addons==0.23.0 gputil==1.4.0 psutil==6.1.0 clu==0.0.12 From 53eff1d469635408aff5d80a28f3248c4bd79464 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 20 Dec 2024 00:41:47 +0530 Subject: [PATCH 28/63] fix: adding absolute paths for custom_tf_addons in randaugment --- .../imagenet_resnet/imagenet_jax/randaugment.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py 
b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index dd00146cd..e920331bc 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,9 +9,12 @@ import tensorflow as tf -from .custom_tf_addons import rotate_img -from .custom_tf_addons import transform -from .custom_tf_addons import translate +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. From b29397a6b2c53972edbfd73110d76893cbe5d85b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 21 Dec 2024 01:11:05 +0000 Subject: [PATCH 29/63] update_docker --- docker/Dockerfile | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..47277d440 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,35 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg + +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.10/Python-3.11.10.tgz \ + && tar -xvzf Python-3.11.10.tgz \ + && cd Python-3.11.10 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz @@ -29,6 +57,8 @@ RUN mkdir -p data/ RUN mkdir -p experiment_runs/ # Install Algorithmic efficiency repo +RUN pip install --upgrade pip + RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" ARG framework="both" @@ -58,8 +88,6 @@ RUN if [ "$framework" = "jax" ] ; then \ RUN cd /algorithmic-efficiency && pip install -e '.[full]' -RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull From 3d58d5bb97d03bfd4cecdb12fa71dfb1529eb44d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 21 Dec 2024 01:13:10 +0000 Subject: [PATCH 30/63] add regression tests for target branch python_test_env_upgrade --- .../regression_tests_python_upgrade.yml | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 .github/workflows/regression_tests_python_upgrade.yml diff --git a/.github/workflows/regression_tests_python_upgrade.yml b/.github/workflows/regression_tests_python_upgrade.yml new file mode 100644 index 000000000..abbb5b806 --- /dev/null +++ b/.github/workflows/regression_tests_python_upgrade.yml @@ -0,0 +1,183 @@ +name: Containerized Regression Tests + +on: + pull_request: 
+ branches: + - 'python_test_env_upgrade' + +jobs: + build_and_push_jax_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=jax + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + build_and_push_pytorch_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=pytorch + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + fastmri_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + imagenet_resnet_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + imagenet_vit_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs 
--gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + ogbg_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + criteo_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + librispeech_conformer_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + librispeech_deepspeech_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_deepspeech -t 
reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + wmt_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + fastmri_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + imagenet_resnet_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + imagenet_vit_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + ogbg_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: 
actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + criteo_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + exit $? + librispeech_conformer_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + librispeech_deepspeech_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + wmt_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false From 8d966fe8c55e2606353502d835705973bbd376f4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 21 Dec 2024 01:19:00 +0000 Subject: [PATCH 31/63] add regression tests for target branch python_test_env_upgrade --- .github/workflows/regression_tests_python_upgrade.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/regression_tests_python_upgrade.yml b/.github/workflows/regression_tests_python_upgrade.yml index abbb5b806..783395353 100644 --- a/.github/workflows/regression_tests_python_upgrade.yml +++ b/.github/workflows/regression_tests_python_upgrade.yml @@ -1,4 +1,4 @@ -name: Containerized Regression Tests +name: Containerized Regression Tests Python Upgrades on: pull_request: From d21d8205d565c94d82b312709491deac0b31de31 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 16:11:59 +0530 Subject: [PATCH 32/63] fix: changes jax.tree_map to jax.tree.map --- algorithmic_efficiency/data_utils.py | 2 +- algorithmic_efficiency/param_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../imagenet_resnet/imagenet_jax/workload.py | 2 +- .../workloads/mnist/mnist_jax/workload.py | 2 +- .../workloads/ogbg/input_pipeline.py | 4 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 6 ++-- .../workloads/wmt/wmt_jax/decode.py | 8 ++--- .../workloads/wmt/wmt_jax/workload.py | 4 +-- .../workloads/wmt/wmt_pytorch/decode.py | 8 ++--- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../jax_nadamw_target_setting.py | 16 +++++----- .../self_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../self_tuning/jax_nadamw_target_setting.py | 16 +++++----- .../cifar/cifar_jax/submission.py | 2 +- .../mnist/mnist_jax/submission.py | 2 +- .../adafactor/jax/sharded_adafactor.py | 16 +++++----- .../adafactor/jax/submission.py | 6 ++-- .../paper_baselines/adamw/jax/submission.py | 6 ++-- .../paper_baselines/lamb/jax/submission.py | 6 ++-- .../momentum/jax/submission.py | 6 ++-- .../paper_baselines/nadamw/jax/submission.py | 16 +++++----- .../nesterov/jax/submission.py | 6 ++-- .../paper_baselines/sam/jax/submission.py | 10 +++---- .../shampoo/jax/distributed_shampoo.py | 30 +++++++++---------- .../paper_baselines/shampoo/jax/submission.py | 6 ++-- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 12 ++++---- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 4 +-- tests/modeldiffs/vanilla_sgd_jax.py | 2 +- tests/reference_algorithm_tests.py | 4 +-- .../imagenet_jax/workload_test.py | 2 +- 35 files changed, 124 insertions(+), 124 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 901f0b582..38a76381f 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py 
@@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algorithmic_efficiency/param_utils.py b/algorithmic_efficiency/param_utils.py index b430366b1..916eb8728 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algorithmic_efficiency/param_utils.py @@ -66,7 +66,7 @@ def pytorch_param_types( def jax_param_shapes( params: spec.ParameterContainer) -> spec.ParameterShapeTree: - return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params) + return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) def jax_param_types(param_shapes: spec.ParameterShapeTree, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6bbf9c64b..60f15c2f0 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -207,4 +207,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 91cdec60a..4366fcf25 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -264,7 +264,7 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), eval_metrics) return eval_metrics diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index efbd73e33..dcb0b6f36 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -132,4 +132,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py index a301d677a..3cb6f51de 100644 --- a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py @@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): def _to_jraph(example): """Converts an example graph to jraph.GraphsTuple.""" - example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access + example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access edge_feat = example['edge_feat'] node_feat = example['node_feat'] edge_index = example['edge_index'] @@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): if count == num_shards: def f(x): - return jax.tree_map(lambda *vals: 
np.stack(vals, axis=0), x[0], *x[1:]) + return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..e66a7a151 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -20,8 +20,8 @@ def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: - return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) - return jax.tree_map( + return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) + return jax.tree.map( lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), inputs) @@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any: def _shard(inputs: Any) -> Any: if not USE_PYTORCH_DDP: return inputs - return jax.tree_map(lambda tensor: tensor[RANK], inputs) + return jax.tree.map(lambda tensor: tensor[RANK], inputs) def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py index 85d0eaac4..dfead5918 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py @@ -86,7 +86,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state): (batch_size, beam_size, 1))) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. 
# --> [batch * beam, vocab] @@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state): logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..dd6728450 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -94,7 +94,7 @@ def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) + return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) @@ -291,7 +291,7 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py index 0488a144f..078560c36 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py @@ -98,7 +98,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested: Dict[str, Any], @@ -164,7 +164,7 @@ def beam_init(batch_size: int, dtype=torch.bool, device=DEVICE) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. 
# --> [batch * beam, vocab] @@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..9c1c21e93 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -347,7 +347,7 @@ def _normalize_eval_metrics( dist.all_reduce(metric) total_metrics = {k: v.item() for k, v in total_metrics.items()} eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 36e7e5607..30f9068d1 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 07281f540..71b1c5e1e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 0d194ef7a..127e660d0 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60fc25ec4..92c0f599c 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index e8e0bf4ac..055de8569 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -60,7 +60,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optimizer(hyperparameters, workload.num_train_examples) diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index b33c0285b..b7c4dd2f2 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -26,7 +26,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.chain( optax.scale_by_adam( diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index 9f4da9132..ff98464ae 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -316,11 +316,11 @@ def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( count=count, - m=jax.tree_map(lambda o: o.m, result_tree), - m_scale=jax.tree_map(lambda o: o.m_scale, result_tree), - vr=jax.tree_map(lambda o: o.vr, result_tree), - 
vc=jax.tree_map(lambda o: o.vc, result_tree), - v=jax.tree_map(lambda o: o.v, result_tree)) + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree)) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -667,7 +667,7 @@ def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( jnp.zeros([], jnp.int32), - jax.tree_map(sharded_adafactor_helper.init, params)) + jax.tree.map(sharded_adafactor_helper.init, params)) def update_fn(updates, state, params=None): if params is None: @@ -677,7 +677,7 @@ def update_fn(updates, state, params=None): compute_var_and_slot_update_fn = functools.partial( sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree_map(compute_var_and_slot_update_fn, + output = jax.tree.map(compute_var_and_slot_update_fn, updates, state.m, state.m_scale, @@ -685,7 +685,7 @@ def update_fn(updates, state, params=None): state.vc, state.v, params) - updates = jax.tree_map(lambda o: o.update, output) + updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) return updates, updated_states diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 0fcb9da0f..133468aea 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): learning_rate=lr_schedule_fn, beta1=1.0 - hyperparameters.one_minus_beta1, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index e80a29693..60a336250 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: 
jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index ebcdc9914..7a3e1289c 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -53,7 +53,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -102,7 +102,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -110,7 +110,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 271ef860b..182fbe644 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. 
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 36e7e5607..30f9068d1 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index a435643e4..e45d8a854 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. 
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5f45901dd..3f029fbfd 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -24,7 +24,7 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: """ gradient_norm = jnp.sqrt( sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) - normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y) + normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -73,12 +73,12 @@ def update_fn(updates, state, grad_fn_params_tuple): # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), axis_name=batch_axis_name) - updates = jax.tree_map(lambda x: x / n_valid_examples, updates) + updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) - scaled_updates = jax.tree_map( + scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, @@ -136,7 +136,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): base_opt_update_fn=opt_update_fn) # Initialize optimizer state. 
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -186,7 +186,7 @@ def _loss_fn(params, update_batch_norm=True): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -342,7 +342,7 @@ def init_training_metrics( """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( functools.partial(jnp.repeat, repeats=num_statistics), default_training_metrics()) @@ -356,14 +356,14 @@ def init_training_metrics_shapes( num_statistics, generate_training_metrics, ) - return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed) + return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) def init_training_metrics_pspec(generate_training_metrics,): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map(lambda _: jax.sharding.PartitionSpec(), + return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), default_training_metrics()) @@ -1253,7 +1253,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start # pylint:disable=cell-var-from-loop Used immediately. 
- per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics) + per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) # We don't want to update the metrics if we didn't do a new inverse p-th # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured @@ -1808,7 +1808,7 @@ def sharded_update_fn(grads, state, params): local_stat, )) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -1816,7 +1816,7 @@ def sharded_update_fn(grads, state, params): stats_flat, params_flat) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), @@ -1981,7 +1981,7 @@ def _init(param): )) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( @@ -2140,7 +2140,7 @@ def _internal_inverse_pth_root_all(): preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( all_statistics[0], @@ -2149,9 +2149,9 @@ def _internal_inverse_pth_root_all(): _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree_map( + metrics = jax.tree.map( functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2166,7 +2166,7 @@ def _internal_inverse_pth_root_all(): s[:, :precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics().replace( inverse_pth_root_errors=inverse_failure_threshold)) @@ -2215,12 +2215,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. 
metrics_for_state = efficient_cond(perform_step, @@ -2441,7 +2441,7 @@ def update_fn(grads, state, params): if custom_preconditioner and grads_custom is not None: stats_grads = treedef.flatten_up_to(grads_custom) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -2452,7 +2452,7 @@ def update_fn(grads, state, params): new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, state.count) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 294ad2706..4a257d17b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -49,7 +49,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): weight_decay=hyperparameters.weight_decay, batch_axis_name='batch', eigh=False) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -97,7 +97,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -105,7 +105,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index 6d2cfe245..bb85ecf05 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -29,7 +29,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index 08a0f7e9d..c5fc2a0c6 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. 
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..1e6b691fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -96,8 +96,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -108,7 +108,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -124,14 +124,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -156,7 +156,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 6b27e0e2a..e5abde50b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. 
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 7a16c07cb..703310df4 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -53,7 +53,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -61,7 +61,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index d45694bcb..18dce968a 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -21,7 +21,7 @@ def init_optimizer_state(workload: spec.Workload, del rng # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index f107be8d7..6afea8a8e 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -97,9 +97,9 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( -1), diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..49fd85fef 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -10,7 +10,7 @@ def _pytree_total_diff(pytree_a, pytree_b): - pytree_diff = jax.tree_map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) + pytree_diff = jax.tree.map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) pytree_diff = jax.tree_util.tree_leaves(pytree_diff) return jnp.sum(jnp.array(pytree_diff)) From 785d82bff29454a1053cd0bf3e0fdd0354851bd1 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 17:06:16 +0530 Subject: [PATCH 33/63] fix: MultiHeadDotProductAttention and optax ctc_loss changes --- .../workloads/imagenet_vit/imagenet_jax/models.py | 4 ++-- .../librispeech_conformer/librispeech_jax/models.py | 6 +++--- 
.../librispeech_conformer/librispeech_jax/workload.py | 11 ++++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 639800b44..79ad54097 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module): def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, @@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = x + y else: y = x - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index cb6287c5e..85a8d1bb7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train): mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) inputs = LayerNorm(dim=config.encoder_dim)(inputs) - attention_fn = functools.partial( dot_product_attention, temperature=config.attention_temperature) - result = nn.SelfAttention( + result = nn.MultiHeadDotProductAttention( num_heads=config.num_attention_heads, qkv_features=config.encoder_dim, decode=False, @@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train): broadcast_dropout=False, attention_fn=attention_fn, dropout_rate=config.attention_dropout_rate, - deterministic=not train)(inputs, attention_mask) + deterministic=not train)( + inputs_q=inputs, mask=attention_mask) if config.attention_residual_dropout_rate is None: attention_residual_dropout_rate = 0.1 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 05faf1135..f546ef785 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -227,11 +227,12 @@ def ctc_loss(self, labels: spec.Tensor, label_paddings: spec.Tensor, blank_id: int = 0) -> spec.Tensor: - return optax.ctc_loss(logits, - logit_paddings, - labels, - label_paddings, - blank_id) + return optax.ctc_loss( + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138. 
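For reference, the attention call-site change in the patch above boils down to the sketch below. This is an illustration only, not part of the patch: it assumes a recent flax release in which nn.SelfAttention is deprecated and calling nn.MultiHeadDotProductAttention with only inputs_q performs self-attention; the module name SelfAttnBlock and the shapes are made up for the example.

import jax
import jax.numpy as jnp
from flax import linen as nn

class SelfAttnBlock(nn.Module):
  num_heads: int = 4

  @nn.compact
  def __call__(self, x, mask=None):
    # Queries, keys and values all come from x when only inputs_q is passed,
    # which is what replaces the removed nn.SelfAttention call sites.
    return nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=True)(inputs_q=x, mask=mask)

x = jnp.ones((2, 16, 32))  # (batch, sequence length, features)
variables = SelfAttnBlock().init(jax.random.PRNGKey(0), x)
y = SelfAttnBlock().apply(variables, x)  # y.shape == (2, 16, 32)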
From d4aa90a8e8de930deb7981a931f6ff672ca1c9e1 Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 22 Dec 2024 19:18:10 +0530
Subject: [PATCH 34/63] fix: removed the sacrebleu dependency

---
 algorithmic_efficiency/workloads/wmt/bleu.py | 366 ++++++++++++++++++-
 setup.cfg                                    |   2 +-
 2 files changed, 355 insertions(+), 13 deletions(-)

diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py
index 1efc87381..dda6d102a 100644
--- a/algorithmic_efficiency/workloads/wmt/bleu.py
+++ b/algorithmic_efficiency/workloads/wmt/bleu.py
@@ -1,8 +1,20 @@
+"""
+Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file.
+Reference:
+https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
+"""
+
+from collections import Counter
+from collections import namedtuple
 from itertools import zip_longest
-from typing import Sequence
+import logging
+import math
+import re
+import sys
+from typing import List, Sequence
+import unicodedata
 
 from absl import logging
-import sacrebleu
 import torch
 import torch.distributed as dist
 
@@ -10,10 +22,340 @@
 USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup()
 
+NGRAM_ORDER = 4
+# The default floor value to use with `--smooth floor`
+SMOOTH_VALUE_DEFAULT = 0.0
+
+
+def my_log(num):
+  """
+  Floors the log function
+
+  :param num: the number
+  :return: log(num) floored to a very low number
+  """
+
+  if num == 0.0:
+    return -9999999999
+  return math.log(num)
+
+
+def tokenize_13a(line):
+  """
+  Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT.
+
+  :param line: a segment to tokenize
+  :return: the tokenized line
+  """
+
+  norm = line
+
+  # language-independent part:
+  norm = norm.replace('<skipped>', '')
+  norm = norm.replace('-\n', '')
+  norm = norm.replace('\n', ' ')
+  norm = norm.replace('&quot;', '"')
+  norm = norm.replace('&amp;', '&')
+  norm = norm.replace('&lt;', '<')
+  norm = norm.replace('&gt;', '>')
+
+  # language-dependent part (assuming Western languages):
+  norm = " {} ".format(norm)
+  norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm)
+  norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ',
+                norm)  # tokenize period and comma unless preceded by a digit
+  norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2',
+                norm)  # tokenize period and comma unless followed by a digit
+  norm = re.sub(r'([0-9])(-)', '\\1 \\2 ',
+                norm)  # tokenize dash when preceded by a digit
+  norm = re.sub(r'\s+', ' ', norm)  # one space only between words
+  norm = re.sub(r'^\s+', '', norm)  # no leading space
+  norm = re.sub(r'\s+$', '', norm)  # no trailing space
+
+  return norm
+
+
+class UnicodeRegex:
+  """Ad-hoc hack to recognize all punctuation and symbols.
+
+  without depending on https://pypi.python.org/pypi/regex/."""
+
+  def _property_chars(prefix):
+    return ''.join(
+        chr(x)
+        for x in range(sys.maxunicode)
+        if unicodedata.category(chr(x)).startswith(prefix))
+
+  punctuation = _property_chars('P')
+  nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
+  punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
+  symbol_re = re.compile('([' + _property_chars('S') + '])')
+
+
+def tokenize_v14_international(string):
+  r"""Tokenize a string following the official BLEU implementation.
+
+  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+  In our case, the input string is expected to be just one line
+  and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ + string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) + string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) + string = UnicodeRegex.symbol_re.sub(r' \1 ', string) + return string.strip() + + +def tokenize_zh(sentence): + """MIT License + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: separate each Chinese + characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ + + def is_chinese_char(uchar): + """ + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. 
+ """ + if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + return True + elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + return True + elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + return True + elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + return True + elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + return True + elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + return True + elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + return True + elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + return True + elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + return True + elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + return True + elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + return True + elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + return True + elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + return True + elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + return True + elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + return True + elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + return True + elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + return True + elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + return True + elif uchar >= u'\u2600' and uchar <= u'\u26ff': + return True + elif uchar >= u'\u2700' and uchar <= u'\u27bf': + return True + elif uchar >= u'\u3200' and uchar <= u'\u32ff': + return True + elif uchar >= u'\u3300' and uchar <= u'\u33ff': + return True + + return False + + sentence = sentence.strip() + sentence_in_chars = "" + for char in sentence: + if is_chinese_char(char): + sentence_in_chars += " " + sentence_in_chars += char + sentence_in_chars += " " + else: + sentence_in_chars += char + sentence = sentence_in_chars + + # TODO: the code above could probably be replaced with the following line: + # import regex + # sentence = regex.sub(r'(\p{Han})', r' \1 ', sentence) + + # tokenize punctuation + sentence = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sentence) + + # tokenize period and comma unless preceded by a digit + sentence = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sentence) + + # tokenize period and comma unless followed by a digit + sentence = re.sub(r'([\.,])([^0-9])', r' \1 \2', sentence) + + # tokenize dash when preceded by a digit + sentence = re.sub(r'([0-9])(-)', r'\1 \2 ', sentence) + + # one space only between words + sentence = re.sub(r'\s+', r' ', sentence) + + # no leading or trailing spaces + sentence = sentence.strip() + + return sentence + + +TOKENIZERS = { + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, +} +DEFAULT_TOKENIZER = '13a' + + +def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: + """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. 
+ + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ + + ngrams = Counter() + tokens = line.split() + for n in range(min_order, max_order + 1): + for i in range(0, len(tokens) - n + 1): + ngram = ' '.join(tokens[i:i + n]) + ngrams[ngram] += 1 + + return ngrams + + +def ref_stats(output, refs): + ngrams = Counter() + closest_diff = None + closest_len = None + for ref in refs: + tokens = ref.split() + reflen = len(tokens) + diff = abs(len(output.split()) - reflen) + if closest_diff is None or diff < closest_diff: + closest_diff = diff + closest_len = reflen + elif diff == closest_diff: + if reflen < closest_len: + closest_len = reflen + + ngrams_ref = extract_ngrams(ref) + for ngram in ngrams_ref.keys(): + ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) + + return ngrams, closest_diff, closest_len + + +BLEU = namedtuple('BLEU', + 'score, counts, totals, precisions, bp, sys_len, ref_len') + + +def compute_bleu(correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False) -> BLEU: + """Computes BLEU score from its sufficient statistics. Adds smoothing. + + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", + Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ + + precisions = [0 for x in range(NGRAM_ORDER)] + + smooth_mteval = 1. + effective_order = NGRAM_ORDER + for n in range(NGRAM_ORDER): + if smooth_method == 'add-k' and n > 1: + correct[n] += smooth_value + total[n] += smooth_value + if total[n] == 0: + break + + if use_effective_order: + effective_order = n + 1 + + if correct[n] == 0: + if smooth_method == 'exp': + smooth_mteval *= 2 + precisions[n] = 100. / (smooth_mteval * total[n]) + elif smooth_method == 'floor': + precisions[n] = 100. * smooth_value / total[n] + else: + precisions[n] = 100. * correct[n] / total[n] + + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). + # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed + # maximum order. It is only available through the API and off by default + + brevity_penalty = 1.0 + if sys_len < ref_len: + brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 + + bleu = brevity_penalty * math.exp( + sum(map(my_log, precisions[:effective_order])) / effective_order) + + return BLEU._make( + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) + -# Modified (added sync for PyTorch DDP) from -# https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. 
-# Assumes that sacrebleu==1.3.1 is installed. def corpus_bleu(sys_stream: Sequence[str], ref_streams: Sequence[str], smooth_method: str = 'exp', @@ -21,7 +363,7 @@ def corpus_bleu(sys_stream: Sequence[str], force: bool = False, lowercase: bool = False, tokenize: str = '13a', - use_effective_order: bool = False) -> sacrebleu.BLEU: + use_effective_order: bool = False) -> BLEU: """Produces BLEU scores along with its sufficient statistics from a source against one or more references. :param sys_stream: The system stream (a sequence of segments). @@ -44,8 +386,8 @@ def corpus_bleu(sys_stream: Sequence[str], sys_len = 0 ref_len = 0 - correct = [0 for _ in range(sacrebleu.NGRAM_ORDER)] - total = [0 for _ in range(sacrebleu.NGRAM_ORDER)] + correct = [0 for _ in range(NGRAM_ORDER)] + total = [0 for _ in range(NGRAM_ORDER)] # Look for already-tokenized sentences. tokenized_count = 0 @@ -70,14 +412,14 @@ def corpus_bleu(sys_stream: Sequence[str], 'or don\'t care, you can suppress this message with ' '\'--force\'.') - output, *refs = [sacrebleu.TOKENIZERS[tokenize](x.rstrip()) for x in lines] + output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] - ref_ngrams, _, closest_len = sacrebleu.ref_stats(output, refs) + ref_ngrams, _, closest_len = ref_stats(output, refs) sys_len += len(output.split()) ref_len += closest_len - sys_ngrams = sacrebleu.extract_ngrams(output) + sys_ngrams = extract_ngrams(output) for ngram, sys_ngram in sys_ngrams.items(): n = len(ngram.split()) correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0)) @@ -100,7 +442,7 @@ def corpus_bleu(sys_stream: Sequence[str], dist.all_reduce(total) total = total.cpu().numpy().tolist() - return sacrebleu.compute_bleu( + return compute_bleu( correct, total, sys_len, diff --git a/setup.cfg b/setup.cfg index 2d246b48b..8e37acb7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -102,7 +102,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==1.3.1 + # Frameworks # # JAX Core From 5e348e4234b061f1819bddcd8d6a3b70ef9804b2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 23 Dec 2024 00:31:48 +0530 Subject: [PATCH 35/63] fix: resolving pylint errors --- algorithmic_efficiency/workloads/wmt/bleu.py | 132 ++++++++++--------- 1 file changed, 71 insertions(+), 61 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py index dda6d102a..22f6a57e0 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algorithmic_efficiency/workloads/wmt/bleu.py @@ -1,5 +1,6 @@ """ -Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file. +Removing the dependency on sacrebleu, we reimplement the BLEU score computation +in this file. Reference: https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. """ @@ -42,7 +43,8 @@ def my_log(num): def tokenize_13a(line): """ - Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT. + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. :param line: a segment to tokenize :return: the tokenized line @@ -80,6 +82,7 @@ class UnicodeRegex: without depending on https://pypi.python.org/pypi/regex/.""" + @staticmethod def _property_chars(prefix): return ''.join( chr(x) @@ -95,20 +98,23 @@ def _property_chars(prefix): def tokenize_v14_international(string): r"""Tokenize a string following the official BLEU implementation. 
- See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 In our case, the input string is expected to be just one line and no HTML entities de-escaping is needed. So we just tokenize on punctuation and symbols, except when a punctuation is preceded and followed by a digit (e.g. a comma/dot as a thousand/decimal separator). - Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space after each sentence). However, this error is already in the original mteval-v14.pl and we want to be consistent with it. The error is not present in the non-international version, - which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). :param string: the input string :return: a list of tokens @@ -123,26 +129,28 @@ def tokenize_zh(sentence): """MIT License Copyright (c) 2017 - Shujian Huang - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - The tokenization of Chinese text in this script contains two steps: separate each Chinese - characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). Author: Shujian Huang huangsj@nju.edu.cn :param sentence: input sentence @@ -151,54 +159,53 @@ def tokenize_zh(sentence): def is_chinese_char(uchar): """ - :param uchar: input char in unicode - :return: whether the input char is a Chinese character. - """ - if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if "\u3400" <= uchar <= "\u4db5": return True - elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + elif "\u4e00" <= uchar <= "\u9fa5": return True - elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + elif "\u9fa6" <= uchar <= "\u9fbb": return True - elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + elif "\uf900" <= uchar <= "\ufa2d": return True - elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + elif "\ufa30" <= uchar <= "\ufa6a": return True - elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + elif "\ufa70" <= uchar <= "\ufad9": return True - elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + elif "\u20000" <= uchar <= "\u2a6d6": return True - elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + elif "\u2f800" <= uchar <= "\u2fa1d": return True - elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + elif "\uff00" <= uchar <= "\uffef": return True - elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + elif "\u2e80" <= uchar <= "\u2eff": return True - elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + elif "\u3000" <= uchar <= "\u303f": return True - elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + elif "\u31c0" <= uchar <= "\u31ef": return True - elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + elif "\u2f00" <= uchar <= "\u2fdf": return True - elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + elif "\u2ff0" <= uchar <= "\u2fff": return True - elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + elif "\u3100" <= uchar <= "\u312f": return True - elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + elif "\u31a0" <= uchar <= "\u31bf": return True - elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + elif "\ufe10" <= uchar <= "\ufe1f": return True - elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + elif "\ufe30" <= uchar <= "\ufe4f": return True - elif uchar >= u'\u2600' and uchar <= u'\u26ff': + elif "\u2600" <= uchar <= "\u26ff": return True - elif uchar >= u'\u2700' and uchar <= u'\u27bf': + elif "\u2700" <= uchar <= "\u27bf": return True - elif uchar >= u'\u3200' and uchar <= u'\u32ff': + elif "\u3200" <= uchar <= 
"\u32ff": return True - elif uchar >= u'\u3300' and uchar <= u'\u33ff': + elif "\u3300" <= uchar <= "\u33ff": return True - return False sentence = sentence.strip() @@ -280,13 +287,13 @@ def ref_stats(output, refs): closest_len = reflen ngrams_ref = extract_ngrams(ref) - for ngram in ngrams_ref.keys(): + for ngram in ngrams_ref: ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) return ngrams, closest_diff, closest_len -BLEU = namedtuple('BLEU', +BLEU = namedtuple('BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len') @@ -299,8 +306,9 @@ def compute_bleu(correct: List[int], use_effective_order=False) -> BLEU: """Computes BLEU score from its sufficient statistics. Adds smoothing. - Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", - Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) - exp: NIST smoothing method (Method 3) - floor: Method 1 @@ -312,7 +320,7 @@ def compute_bleu(correct: List[int], :param sys_len: The cumulative system length :param ref_len: The cumulative reference length :param smooth: The smoothing method to use - :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param smooth_value: The smoothing value added, if smooth is 'floor' :param use_effective_order: Use effective order. :return: A BLEU object with the score (100-based) and other statistics. """ @@ -340,10 +348,12 @@ def compute_bleu(correct: List[int], else: precisions[n] = 100. * correct[n] / total[n] - # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). - # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit - # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed - # maximum order. It is only available through the API and off by default + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU + # score is 0 (technically undefined). This is a problem for sentence-level + # BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales + # NGRAM_ORDER to the observed maximum order. + # It is only available through the API and off by default brevity_penalty = 1.0 if sys_len < ref_len: @@ -374,7 +384,7 @@ def corpus_bleu(sys_stream: Sequence[str], :param force: Ignore data that looks already tokenized. :param lowercase: Lowercase the data. :param tokenize: The tokenizer to use. - :return: A BLEU object containing everything you'd want. + :return: A BLEU object containing everything yo'd want. """ # Add some robustness to the input arguments. 
From b769e6cc5b877e5ed40ef7e86aaebd0c53d9d5ab Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 9 Jan 2025 17:59:09 +0000 Subject: [PATCH 36/63] fix startup script for python version upgrade --- docker/scripts/startup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 527e8306a..1dbba9565 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -156,7 +156,7 @@ fi if [[ ${TEST} == "true" ]]; then cd algorithmic-efficiency - COMMAND="python3 tests/test_traindiffs.py" + COMMAND="python tests/test_traindiffs.py" echo $COMMAND eval $COMMAND exit @@ -209,7 +209,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" # Set run command prefix depending on framework if [[ "${FRAMEWORK}" == "jax" ]]; then - COMMAND_PREFIX="python3" + COMMAND_PREFIX="python" else COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" fi From b65157eb4a8798029f3b21a38a807dd9a6067fa9 Mon Sep 17 00:00:00 2001 From: Isaac Date: Tue, 14 Jan 2025 15:52:44 +0000 Subject: [PATCH 37/63] fix: getargspec is not supported in python311, using getfullargspec instead --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index e920331bc..41002ff9b 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -482,13 +482,13 @@ def _parse_policy_info(name, # Check to see if prob is passed into function. This is used for operations # where we alter bboxes independently. - if 'prob' in inspect.getargspec(func)[0]: + if 'prob' in inspect.getfullargspec(func)[0]: args = tuple([prob] + list(args)) # Add in replace arg if it is required for the function that is being called. - if 'replace' in inspect.getargspec(func)[0]: + if 'replace' in inspect.getfullargspec(func)[0]: # Make sure replace is the final argument - assert 'replace' == inspect.getargspec(func)[0][-1] + assert 'replace' == inspect.getfullargspec(func)[0][-1] args = tuple(list(args) + [replace_value]) return (func, prob, args) From d9f13ab9b6fd9c5ee0bd99d72ce7eb04851aa1c9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 22:30:09 +0000 Subject: [PATCH 38/63] upgrade_jax --- setup.cfg | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8e37acb7a..43458cb07 100644 --- a/setup.cfg +++ b/setup.cfg @@ -121,17 +121,17 @@ jax_core_deps = # JAX CPU jax_cpu = - jax==0.4.35 - jaxlib==0.4.35 + jax==0.4.38 + jaxlib==0.4.38 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. 
jax_gpu = - jax==0.4.35 - jaxlib==0.4.35 - jax-cuda12-plugin[with_cuda]==0.4.35 - jax-cuda12-pjrt==0.4.35 + jax==0.4.38 + jaxlib==0.4.38 + jax-cuda12-plugin[with_cuda]==0.4.38 + jax-cuda12-pjrt==0.4.38 %(jax_core_deps)s # PyTorch CPU From 01eb8819dbe36d8b54987758706247c63b3f73df Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 23:16:02 +0000 Subject: [PATCH 39/63] change jax version --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 43458cb07..040d1e26a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -128,10 +128,10 @@ jax_cpu = # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.38 - jaxlib==0.4.38 - jax-cuda12-plugin[with_cuda]==0.4.38 - jax-cuda12-pjrt==0.4.38 + jax==0.4.36 + jaxlib==0.4.36 + jax-cuda12-plugin[with_cuda]==0.4.36 + jax-cuda12-pjrt==0.4.36 %(jax_core_deps)s # PyTorch CPU From c9b641158e5ff41918f7e94f5d492c0b30e4ed80 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 23:28:20 +0000 Subject: [PATCH 40/63] change jax python version --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 040d1e26a..8c512d32e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -121,8 +121,8 @@ jax_core_deps = # JAX CPU jax_cpu = - jax==0.4.38 - jaxlib==0.4.38 + jax==0.4.36 + jaxlib==0.4.36 %(jax_core_deps)s # JAX GPU From 7d580f1eb3955d42de328d90f423a09cc0ed25b5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 18 Jan 2025 15:27:26 +0530 Subject: [PATCH 41/63] fix: using jax.random.key_data only when the workload is jax --- submission_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 4d494f607..b371489bd 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -213,7 +213,9 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - data_rng = jax.random.key_data(data_rng) + + if FLAGS.framework == 'jax': + data_rng = jax.random.key_data(data_rng) # Workload setup. 
logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -345,7 +347,9 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - eval_rng = jax.random.key_data(eval_rng) + + if FLAGS.framework == 'jax': + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, From 57156188502a70a50e29164ad29822e093c2b8db Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 21:43:38 +0000 Subject: [PATCH 42/63] revert to use PRNGKey --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 603a644d1..a579976ad 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.key(seed) + return jax_rng.PRNGKey(seed) return _PRNGKey(seed) From 3fb722d5e1ba47ab8d5e00caa29a534ae6b70e89 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 21:46:35 +0000 Subject: [PATCH 43/63] revert changes to submission runner for prng key --- submission_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index b371489bd..228cbc4d7 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -214,8 +214,7 @@ def train_once( _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - if FLAGS.framework == 'jax': - data_rng = jax.random.key_data(data_rng) + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -348,8 +347,7 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - if FLAGS.framework == 'jax': - eval_rng = jax.random.key_data(eval_rng) + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, From 9b7cee40010d45561f38acad0a701b0188e72181 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 23:42:38 +0000 Subject: [PATCH 44/63] remove extracting key_data --- submission_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 228cbc4d7..06963fc9d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -214,7 +214,6 @@ def train_once( _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - data_rng = jax.random.key_data(data_rng) # Workload setup. 
logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -347,8 +346,6 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - eval_rng = jax.random.key_data(eval_rng) - with profiler.profile('Data selection'): batch = data_selection(workload, input_queue, From 5775ed166b241abf12088dd2569ef9894eb9e6de Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 23 Jan 2025 00:58:24 +0000 Subject: [PATCH 45/63] cast np.int32 as int for random.Random arg --- .../workloads/cifar/cifar_pytorch/workload.py | 2 +- .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 2 +- .../librispeech_conformer/librispeech_pytorch/workload.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index 7abcf4d6c..119c6378c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -82,7 +82,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng[0]).shuffle(train_indices) + random.Random(int(data_rng[0])).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..6387a40c0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -120,7 +120,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..83f0a2de7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -166,7 +166,7 @@ def _build_input_queue( ds = LibriSpeechDataset(split=ds_split, data_dir=data_dir) if split == 'eval_train': indices = list(range(len(ds))) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples]) sampler = None From 1352e70ad71fef6b0508bcc95d3af317fe62fa90 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 26 Jan 2025 11:36:58 +0530 Subject: [PATCH 46/63] fix: vim installation --- docker/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 77dac5313..07375dd92 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -24,7 +24,8 @@ RUN apt-get update && apt-get install -y \ libffi-dev \ curl \ libbz2-dev \ - liblzma-dev + liblzma-dev \ + vim # Download and install Python 3.11 RUN cd /tmp \ @@ -91,6 +92,7 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]' RUN cd /algorithmic-efficiency && 
git fetch origin
 RUN cd /algorithmic-efficiency && git pull
+RUN pip install wandb
 
 # Todo: remove this, this is temporary for developing
 COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh

From d7eebf88bf2fe61fc523991d62e5a0095af9fd64 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 30 Jan 2025 01:49:23 +0000
Subject: [PATCH 47/63] use inductor backend to compile deepspeech instead of
 eager

---
 submission_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index 06963fc9d..d2dcb03ac 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -242,8 +242,9 @@ def train_once(
         'ogbg',
         'criteo1tb',
         'imagenet_vit',
+        'librispeech_deepspeech'
     ]
-    eager_backend_workloads = ['librispeech_deepspeech']
+    eager_backend_workloads = []
     aot_eager_backend_workloads = []
     loss_compilation_workloads = [
         'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt'

From 58159c5fd7c80c7dc43a62f9df714baf7d82eadb Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 2 Feb 2025 23:17:14 +0530
Subject: [PATCH 48/63] adding mem_fraction 0.80 for jax workloads to resolve
 OOM of certain workloads

---
 docker/Dockerfile    |  1 -
 submission_runner.py | 13 +++++++++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 07375dd92..76bc5cfe0 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -92,7 +92,6 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]'
 
 RUN cd /algorithmic-efficiency && git fetch origin
 RUN cd /algorithmic-efficiency && git pull
-RUN pip install wandb
 
 # Todo: remove this, this is temporary for developing
 COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
diff --git a/submission_runner.py b/submission_runner.py
index d2dcb03ac..2acc9d33c 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -693,12 +693,21 @@ def main(_):
 
   # Prevent OOM on librispeech conformer.
   base_workload = workloads.get_base_workload_name(FLAGS.workload)
-  if base_workload == 'librispeech_conformer':
-    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
+
+  if base_workload == [
+      'librispeech_conformer',
+      'librispeech_deepspeech',
+      'imagenet_vit',
+      'criteo1tb'
+  ]:
+    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
 
   if FLAGS.set_pytorch_max_split_size:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
+  if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(
       BASE_WORKLOADS_DIR,

From 81bc93d2394762d883058d922ebc524ad69706f5 Mon Sep 17 00:00:00 2001
From: init-22
Date: Mon, 3 Feb 2025 15:26:07 +0530
Subject: [PATCH 49/63] mem fraction typo

---
 submission_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index 2acc9d33c..6024ba1a2 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -694,7 +694,7 @@ def main(_):
 
   # Prevent OOM on librispeech conformer.
base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload == [ + if base_workload in [ 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', From f6ca2bce0593a622cf53f90c1750bf27848eb892 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 3 Feb 2025 16:02:46 +0530 Subject: [PATCH 50/63] env variable for conformer set at the top --- submission_runner.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 6024ba1a2..da4e8371c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -21,6 +21,12 @@ import itertools import json import os + +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads. +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + import struct import time from types import MappingProxyType @@ -30,12 +36,10 @@ from absl import flags from absl import logging import jax +import tensorflow as tf import torch import torch.distributed as dist -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -import tensorflow as tf - # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') @@ -52,9 +56,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -# disable only for deepspeech if it works fine for other workloads. -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -702,12 +703,13 @@ def main(_): ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' + if base_workload != 'librispeech_conformer': + # Remove the environment variable (only for workloads other than librispeech conformer). + del os.environ['PYTORCH_CUDA_ALLOC_CONF'] + if FLAGS.set_pytorch_max_split_size: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' - if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' - # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, From b4ed6cc11d1730b3402e83e5162302d14681a9c9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 4 Feb 2025 01:38:52 +0000 Subject: [PATCH 51/63] set env variables for pytorch before initializing w ddp. --- submission_runner.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index da4e8371c..495fd2039 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -22,11 +22,6 @@ import json import os -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads. -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - import struct import time from types import MappingProxyType @@ -56,6 +51,11 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. 
+# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -681,6 +681,14 @@ def main(_): else: profiler = PassThroughProfiler() + # Set PyTorch environment variables before initializing w DDP + base_workload = workloads.get_base_workload_name(FLAGS.workload) + if base_workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + + if FLAGS.set_pytorch_max_split_size: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.framework == 'pytorch': pytorch_init(USE_PYTORCH_DDP, RANK, profiler) @@ -692,9 +700,6 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] - # Prevent OOM on librispeech conformer. - base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload in [ 'librispeech_conformer', 'librispeech_deepspeech', @@ -703,13 +708,6 @@ def main(_): ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' - if base_workload != 'librispeech_conformer': - # Remove the environment variable (only for workloads other than librispeech conformer). - del os.environ['PYTORCH_CUDA_ALLOC_CONF'] - - if FLAGS.set_pytorch_max_split_size: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' - # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, From ebf0341ab43d78d7f70be88416543a970c327efb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 5 Feb 2025 23:08:59 +0000 Subject: [PATCH 52/63] set jax to 0.4.26 --- setup.cfg | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8c512d32e..3c435c453 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,8 @@ install_requires = tensorflow-datasets==4.9.7 gputil==1.4.0 psutil==6.1.0 - clu==0.0.12 + # clu==0.0.12 + clu matplotlib>=3.9.2 tabulate==0.9.0 wandb==0.18.7 @@ -107,31 +108,34 @@ wmt = # JAX Core jax_core_deps = - flax==0.10.1 - optax==0.2.4 + # flax==0.10.1 + flax + # optax==0.2.4 + optax # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.87 + # chex==0.1.87 + chex ml_dtypes==0.4.1 protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.36 - jaxlib==0.4.36 + jax==0.4.26 + jaxlib==0.4.26 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.36 - jaxlib==0.4.36 - jax-cuda12-plugin[with_cuda]==0.4.36 - jax-cuda12-pjrt==0.4.36 + jax==0.4.26 + jaxlib==0.4.26 + jax-cuda12-plugin[with_cuda]==0.4.26 + jax-cuda12-pjrt==0.4.26 %(jax_core_deps)s # PyTorch CPU From 1ce6deac8cec6b5e06bd161b2deb5392432d43c5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Feb 2025 23:22:38 +0000 Subject: [PATCH 53/63] set jax versions --- setup.cfg | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3c435c453..43b31b536 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,8 +42,7 @@ install_requires = tensorflow-datasets==4.9.7 gputil==1.4.0 psutil==6.1.0 - # clu==0.0.12 - clu + clu==0.0.12 matplotlib>=3.9.2 tabulate==0.9.0 wandb==0.18.7 @@ -108,17 +107,9 @@ wmt = # JAX Core jax_core_deps = - # flax==0.10.1 - flax - # optax==0.2.4 - optax - # Fix chex (optax dependency) version. 
- # Not fixing it can raise dependency issues with our - # jax version. - # Todo(kasimbeg): verify if this is necessary after we - # upgrade jax. - # chex==0.1.87 - chex + flax==0.8.4 + optax==0.2.2 + chex==0.1.86 ml_dtypes==0.4.1 protobuf==4.25.5 @@ -132,10 +123,10 @@ jax_cpu = # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.26 - jaxlib==0.4.26 - jax-cuda12-plugin[with_cuda]==0.4.26 - jax-cuda12-pjrt==0.4.26 + jax==0.4.28 + jaxlib==0.4.28 + jax-cuda12-plugin[with_cuda]==0.4.28 + jax-cuda12-pjrt==0.4.28 %(jax_core_deps)s # PyTorch CPU From 082be0311894082409cb5580ecb8cb2f13821b8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Feb 2025 02:27:45 +0000 Subject: [PATCH 54/63] fix pytorch version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 870cfe99a..6acdd3351 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ jax_gpu = [ ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] pytorch_gpu = [ - "torch==2.5.0", + "torch==2.5.1", "torchvision==0.20.1", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. wandb = ["wandb==0.16.5"] From 39bb87683aa068dad0040ac36c31d2d2d1769575 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Feb 2025 02:28:45 +0000 Subject: [PATCH 55/63] fix jax versions --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6acdd3351..c34ec00f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,13 +103,13 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.26", - "jaxlib==0.4.26", + "jax==0.4.28", + "jaxlib==0.4.28", "algorithmic_efficiency[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.26", - "jaxlib==0.4.26", + "jax==0.4.28", + "jaxlib==0.4.28", "jax-cuda12-plugin[with_cuda]==0.4.28", "jax-cuda12-pjrt==0.4.28", "algorithmic_efficiency[jax_core_deps]", From b45a69b029f7d2d20e100cca5ea800843fec361a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 8 Feb 2025 14:13:13 +0530 Subject: [PATCH 56/63] fix: adding wandb under 'full' section --- pyproject.toml | 6 ++++-- submission_runner.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c34ec00f3..b77adaef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "clu==0.0.12", "matplotlib>=3.9.2", "tabulate==0.9.0", + ] [build-system] @@ -70,7 +71,7 @@ version_file = "algorithmic_efficiency/_version.py" [project.optional-dependencies] # All workloads full = [ - "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,wandb]", ] # All workloads plus development dependencies full_dev = ["algorithmic_efficiency[full,dev]"] @@ -83,6 +84,8 @@ dev = [ "pre-commit==4.0.1", ] +wandb = ["wandb==0.16.5"] + # Workloads criteo1tb = ["scikit-learn==1.5.2"] fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] @@ -119,7 +122,6 @@ pytorch_gpu = [ "torch==2.5.1", "torchvision==0.20.1", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. 
-wandb = ["wandb==0.16.5"] ############################################################################### # Linting Configurations # diff --git a/submission_runner.py b/submission_runner.py index 495fd2039..2753a604b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -21,7 +21,6 @@ import itertools import json import os - import struct import time from types import MappingProxyType @@ -685,7 +684,7 @@ def main(_): base_workload = workloads.get_base_workload_name(FLAGS.workload) if base_workload == 'librispeech_conformer': os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' - + if FLAGS.set_pytorch_max_split_size: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' From b719a6e5a5bb05abf85da768a60ea43183468684 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 10 Feb 2025 21:58:51 +0530 Subject: [PATCH 57/63] fix: wandb version upgrade --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b77adaef7..b4840b35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ dev = [ "pre-commit==4.0.1", ] -wandb = ["wandb==0.16.5"] +wandb = ["wandb==0.19.6"] # Workloads criteo1tb = ["scikit-learn==1.5.2"] From 1ce3e624de7055af44d90b9fb0fab9a28ea268dd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 10 Feb 2025 21:26:54 +0000 Subject: [PATCH 58/63] remove wandb from full --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b4840b35c..9130f733f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algorithmic_efficiency/_version.py" [project.optional-dependencies] # All workloads full = [ - "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,wandb]", + "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", ] # All workloads plus development dependencies full_dev = ["algorithmic_efficiency[full,dev]"] From a12733afad3bf0b1656bea409bf1ee658c9c88fe Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:04:19 +0000 Subject: [PATCH 59/63] revert isort version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2cc9dfdc8..cc404f4b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ full = [ full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package dev = [ - "isort==5.13.0", + "isort==5.12.0", "pylint==2.17.4", "pytest==8.3.3", "yapf==0.32.0", From bee0e3f03a6b072519dcbdb522ec4afcca6fc3cc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:05:11 +0000 Subject: [PATCH 60/63] revert import order changes --- .../imagenet_jax/randaugment.py | 3 +-- algoperf/workloads/imagenet_vit/workload.py | 3 ++- .../librispeech_jax/workload.py | 10 ++++----- .../librispeech_pytorch/workload.py | 21 +++++++++---------- algoperf/workloads/ogbg/workload.py | 3 ++- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 41002ff9b..98e6e0f8e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,14 +7,13 @@ import inspect import math -import tensorflow as tf - from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img from 
algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate +import tensorflow as tf # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 9c885ca7c..f249ddee8 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -3,7 +3,8 @@ from typing import Dict, Iterator, Optional from algoperf import spec -from algoperf.workloads.imagenet_resnet.workload import BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import \ + BaseImagenetResNetWorkload def decode_variant(variant: str) -> Dict[str, int]: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 1cadebf45..d3b616f43 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,15 +1,15 @@ import functools from typing import Dict, Optional, Tuple +from flax import jax_utils import jax import jax.numpy as jnp import numpy as np -from flax import jax_utils -from algoperf import param_utils, spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( - LibriSpeechConformerWorkload, -) +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ + LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index c72c1daee..e5387f5cb 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,18 +3,17 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils, spec +from algoperf import param_utils +from algoperf import spec from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( - initialize, -) -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( - LibriSpeechConformerWorkload, -) -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( - DeepspeechConfig, - DeepspeechEncoderDecoder, -) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + initialize +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ + LibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechEncoderDecoder USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index ca123f885..971e7f0f6 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -9,7 +9,8 @@ from algoperf import random_utils as prng from algoperf import spec -from 
algoperf.workloads.ogbg import input_pipeline, metrics +from algoperf.workloads.ogbg import input_pipeline +from algoperf.workloads.ogbg import metrics class BaseOgbgWorkload(spec.Workload): From 1f72cb3f6df53d5d0330f363576832217ef0f537 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:08:07 +0000 Subject: [PATCH 61/63] remove temporary testing for upgrades --- .../regression_tests_python_upgrade.yml | 183 ------------------ 1 file changed, 183 deletions(-) delete mode 100644 .github/workflows/regression_tests_python_upgrade.yml diff --git a/.github/workflows/regression_tests_python_upgrade.yml b/.github/workflows/regression_tests_python_upgrade.yml deleted file mode 100644 index 783395353..000000000 --- a/.github/workflows/regression_tests_python_upgrade.yml +++ /dev/null @@ -1,183 +0,0 @@ -name: Containerized Regression Tests Python Upgrades - -on: - pull_request: - branches: - - 'python_test_env_upgrade' - -jobs: - build_and_push_jax_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=jax - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - build_and_push_pytorch_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=pytorch - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? 
- if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - fastmri_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_jax: - 
runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_conformer_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - fastmri_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v 
$HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - exit $? - librispeech_conformer_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false From 4dadef0b93b8b08606c51571389efbcdd2552698 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:34:02 +0000 Subject: [PATCH 62/63] update import path in randaugment.py --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 98e6e0f8e..accd9b4a9 100644 --- 
a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,11 +7,11 @@ import inspect import math -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate import tensorflow as tf From f375099648d9d9904d16d2008146644a793b3bd7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 01:33:02 +0000 Subject: [PATCH 63/63] isort changes --- algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index accd9b4a9..c68e2de33 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,13 +7,14 @@ import inspect import math +import tensorflow as tf + from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate -import tensorflow as tf # This signifies the max integer that the controller RNN could predict for the # augmentation scheme.