diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt
index 2717ed4ef..1886ffa21 100644
--- a/.pin/constraints-cuda-torch.txt
+++ b/.pin/constraints-cuda-torch.txt
@@ -32,9 +32,9 @@ accelerate==0.34.2
     #   -r benchmarks/rlhf/requirements.in
     #   diffusers
     #   trl
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   datasets
     #   fsspec
@@ -50,7 +50,7 @@ argklass==1.4.4
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/llm/requirements.in
     #   -r benchmarks/purejaxrl/requirements.in
-astroid==3.2.4
+astroid==3.3.4
     # via pylint
 asttokens==2.4.1
     # via giving
@@ -58,7 +58,7 @@ async-timeout==4.0.3
     # via aiohttp
 attrs==24.2.0
     # via aiohttp
-beartype==0.18.5
+beartype==0.19.0
     # via -r benchmarks/vjepa/requirements.in
 black==24.8.0
     # via navix
@@ -70,7 +70,7 @@ blobfile==3.0.0
     #   torchtune
 blosc2==2.7.1
     # via tables
-botorch==0.11.3
+botorch==0.12.0
     # via -r benchmarks/recursiongfn/requirements.in
 braceexpand==0.1.7
     # via
@@ -88,7 +88,7 @@ certifi==2024.8.30
     #   sentry-sdk
 charset-normalizer==3.3.2
     # via requests
-chex==0.1.86
+chex==0.1.87
     # via
     #   distrax
     #   evosax
@@ -117,7 +117,7 @@ cvxopt==1.3.2
     # via -r benchmarks/recursiongfn/requirements.in
 cycler==0.12.1
     # via matplotlib
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/llama/requirements.in
@@ -129,7 +129,7 @@ decorator==5.1.1
     # via tensorflow-probability
 decord==0.6.0
     # via -r benchmarks/vjepa/requirements.in
-diffusers[torch]==0.30.2
+diffusers[torch]==0.30.3
     # via -r benchmarks/diffusion/requirements.in
 dill==0.3.8
     # via
@@ -179,7 +179,7 @@ fairscale==0.4.13
     #   -r benchmarks/llm/requirements.txt
 farama-notifications==0.0.4
     # via gymnasium
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   blobfile
     #   datasets
@@ -188,7 +188,7 @@ filelock==3.16.0
     #   torch
     #   transformers
     #   triton
-fire==0.6.0
+fire==0.7.0
     # via
     #   -r benchmarks/llama/requirements.in
     #   -r benchmarks/llm/requirements.txt
@@ -210,7 +210,7 @@ flax==0.9.0
     #   flashbax
     #   gymnax
     #   navix
-fonttools==4.53.1
+fonttools==4.54.1
     # via matplotlib
 frozenlist==1.4.1
     # via
@@ -241,11 +241,11 @@ giving==0.4.3
     #   voir
 glfw==2.7.0
     # via mujoco
-gpytorch==1.12
+gpytorch==1.13
     # via
     #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-grpcio==1.66.1
+grpcio==1.66.2
     # via
     #   brax
     #   tensorboard
@@ -267,7 +267,7 @@ gymnax==0.0.8
     #   -r benchmarks/purejaxrl/requirements.in
 hjson==3.1.0
     # via argklass
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -r benchmarks/timm/requirements.in
     #   accelerate
@@ -301,7 +301,7 @@ isort==5.13.2
     # via pylint
 itsdangerous==2.2.0
     # via flask
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -r benchmarks/brax/requirements.in
     #   -r benchmarks/purejaxrl/requirements.in
@@ -318,11 +318,11 @@ jax[cuda12]==0.4.31
     #   optax
     #   orbax-checkpoint
     #   rlax
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   brax
     #   chex
@@ -338,8 +338,10 @@ jaxlib==0.4.31
     #   rlax
 jaxopt==0.8.3
     # via brax
-jaxtyping==0.2.34
-    # via linear-operator
+jaxtyping==0.2.19
+    # via
+    #   gpytorch
+    #   linear-operator
 jinja2==3.1.4
     # via
     #   brax
@@ -357,7 +359,7 @@ lightning-utilities==0.11.7
     #   lightning
     #   pytorch-lightning
     #   torchmetrics
-linear-operator==0.5.2
+linear-operator==0.5.3
     # via
     #   botorch
     #   gpytorch
@@ -393,17 +395,18 @@ mpmath==1.3.0
     # via
     #   botorch
     #   gpytorch
+    #   linear-operator
     #   sympy
 msgpack==1.1.0
     # via
     #   blosc2
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.2
+mujoco==3.2.3
     # via
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.2
+mujoco-mjx==3.2.3
     # via brax
 multidict==6.1.0
     # via
@@ -417,7 +420,7 @@ mypy-extensions==1.0.0
     # via black
 navix==0.7.0
     # via -r benchmarks/purejaxrl/requirements.in
-ndindex==1.8
+ndindex==1.9.2
     # via blosc2
 nest-asyncio==1.6.0
     # via orbax-checkpoint
@@ -438,7 +441,6 @@ numpy==1.26.4
     #   -r benchmarks/vjepa/requirements.in
     #   accelerate
     #   blosc2
-    #   botorch
     #   brax
     #   chex
     #   contourpy
@@ -457,13 +459,13 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   jaxopt
+    #   jaxtyping
     #   matplotlib
     #   ml-dtypes
     #   mujoco
     #   navix
     #   numexpr
     #   opencv-python
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   pandas
@@ -498,7 +500,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     # via
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via jax-cuda12-plugin
 nvidia-cuda-nvrtc-cu12==12.1.105
     # via torch
@@ -531,7 +533,7 @@ nvidia-nccl-cu12==2.20.5
     # via
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   jax-cuda12-plugin
     #   nvidia-cusolver-cu12
@@ -546,7 +548,7 @@ omegaconf==2.3.0
     #   voir
 opencv-python==4.10.0.84
     # via -r benchmarks/vjepa/requirements.in
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   jax
     #   pyro-ppl
@@ -557,7 +559,7 @@ optax==0.2.3
     #   flax
 optree==0.12.1
     # via envpool
-orbax-checkpoint==0.6.3
+orbax-checkpoint==0.6.4
     # via
     #   brax
     #   flax
@@ -581,7 +583,7 @@ packaging==24.1
     #   tensorboardx
     #   torchmetrics
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/recursiongfn/requirements.in
@@ -601,7 +603,7 @@ pillow==10.4.0
     #   navix
     #   rdkit
     #   torchvision
-platformdirs==4.3.3
+platformdirs==4.3.6
     # via
     #   black
     #   pylint
@@ -610,7 +612,7 @@ pluggy==1.5.0
     # via pytest
 portalocker==2.10.1
     # via iopath
-protobuf==5.28.1
+protobuf==5.28.2
     # via
     #   orbax-checkpoint
     #   tensorboard
@@ -634,13 +636,13 @@ pyarrow==17.0.0
     #   datasets
 pycodestyle==2.12.1
     # via flake8
-pycryptodomex==3.20.0
+pycryptodomex==3.21.0
     # via blobfile
 pyflakes==3.2.0
     # via flake8
 pygments==2.18.0
     # via rich
-pylint==3.2.7
+pylint==3.3.1
     # via navix
 pyopengl==3.1.7
     # via mujoco
@@ -708,7 +710,7 @@ requests==2.32.3
     #   torch-geometric
     #   transformers
     #   wandb
-rich==13.8.1
+rich==13.9.1
     # via
     #   flax
     #   tyro
@@ -746,7 +748,7 @@ sentencepiece==0.2.0
     # via
     #   -r benchmarks/llama/requirements.in
     #   torchtune
-sentry-sdk==2.14.0
+sentry-sdk==2.15.0
     # via wandb
 setproctitle==1.3.3
     # via wandb
@@ -758,24 +760,23 @@ six==1.16.0
     # via
     #   asttokens
     #   docker-pycreds
-    #   fire
     #   ml-collections
     #   python-dateutil
     #   tensorboard
     #   tensorflow-probability
 smmap==5.0.1
     # via gitdb
-submitit==1.5.1
+submitit==1.5.2
     # via
     #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/vjepa/requirements.in
-sympy==1.13.2
+sympy==1.13.3
     # via torch
 tables==3.10.1
     # via -r benchmarks/recursiongfn/requirements.in
 tabulate==0.9.0
     # via fvcore
-tensorboard==2.17.1
+tensorboard==2.18.0
     # via
     #   -r benchmarks/recursiongfn/requirements.in
     #   -r benchmarks/torchatari/requirements.in
@@ -785,7 +786,7 @@ tensorboardx==2.6.2.2
     # via brax
 tensorflow-probability==0.24.0
     # via distrax
-tensorstore==0.1.65
+tensorstore==0.1.66
     # via
     #   flashbax
     #   flax
@@ -802,7 +803,7 @@ timm==1.0.9
     # via -r benchmarks/vjepa/requirements.in
 tokenizers==0.19.1
     # via transformers
-tomli==2.0.1
+tomli==2.0.2
     # via
     #   black
     #   pylint
@@ -849,7 +850,7 @@ torch-cluster==1.6.3+pt24cu121
     # via
     #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/recursiongfn/requirements.in
-torch-geometric==2.6.0
+torch-geometric==2.6.1
     # via
     #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/recursiongfn/requirements.in
@@ -862,7 +863,10 @@ torch-sparse==0.6.18+pt24cu121
     #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/recursiongfn/requirements.in
 torchao==0.3.1+cu121
-    # via torchtune
+    # via
+    #   -c .pin/../constraints/cuda.txt
+    #   -r benchmarks/llm/requirements.in
+    #   torchtune
 torchcompat==1.1.4
     # via
     #   -c .pin/../constraints/cuda.txt
@@ -877,7 +881,9 @@ torchmetrics==1.4.2
     #   lightning
     #   pytorch-lightning
 torchtune==0.2.1+cu121
-    # via -r benchmarks/llm/requirements.in
+    # via
+    #   -c .pin/../constraints/cuda.txt
+    #   -r benchmarks/llm/requirements.in
 torchvision==0.19.0+cu121
     # via
     #   -r benchmarks/diffusion/requirements.in
@@ -907,6 +913,7 @@ tqdm==4.66.5
     #   transformers
 transformers==4.44.2
     # via
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/huggingface/requirements.in
     #   -r benchmarks/llama/requirements.in
@@ -921,17 +928,18 @@ trimesh==4.4.9
 triton==3.0.0
     # via torch
 trl==0.10.1
-    # via -r benchmarks/rlhf/requirements.in
-typeguard==2.13.3
     # via
-    #   jaxtyping
-    #   linear-operator
-types-protobuf==5.27.0.20240907
+    #   -c .pin/../constraints/cuda.txt
+    #   -r benchmarks/rlhf/requirements.in
+typeguard==4.3.0
+    # via jaxtyping
+types-protobuf==5.28.0.20240924
     # via envpool
 typing-extensions==4.12.2
     # via
     #   astroid
     #   black
+    #   botorch
     #   brax
     #   chex
     #   envpool
@@ -941,6 +949,7 @@ typing-extensions==4.12.2
     #   gymnasium
     #   huggingface-hub
     #   iopath
+    #   jaxtyping
     #   lightning
     #   lightning-utilities
     #   multidict
@@ -949,16 +958,18 @@ typing-extensions==4.12.2
     #   orbax-checkpoint
     #   pytorch-lightning
     #   reactivex
+    #   rich
     #   submitit
     #   tables
     #   torch
+    #   typeguard
     #   tyro
-tyro==0.8.10
+tyro==0.8.11
     # via
     #   -r benchmarks/torchatari/requirements.in
     #   navix
     #   trl
-tzdata==2024.1
+tzdata==2024.2
     # via pandas
 urllib3==2.2.3
     # via
@@ -988,7 +999,7 @@ voir==0.2.19
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
     #   -r benchmarks/vjepa/requirements.in
-wandb==0.18.0
+wandb==0.18.3
     # via
     #   -r benchmarks/recursiongfn/requirements.in
     #   navix
@@ -1006,7 +1017,7 @@ xxhash==3.5.0
     # via datasets
 yacs==0.1.8
     # via fvcore
-yarl==1.11.1
+yarl==1.13.1
     # via aiohttp
 zipp==3.20.2
     # via
diff --git a/.pin/constraints-hpu-torch.txt b/.pin/constraints-hpu-torch.txt
index 6481e8c67..92a55858c 100644
--- a/.pin/constraints-hpu-torch.txt
+++ b/.pin/constraints-hpu-torch.txt
@@ -2,204 +2,359 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=.pin/constraints-hpu-torch.txt .pin/tmp-constraints.txt benchmarks/accelerate_opt/requirements.in benchmarks/brax/requirements.in benchmarks/dlrm/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/llama/requirements.in benchmarks/stargan/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in
+#    pip-compile --output-file=.pin/constraints-hpu-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/geo_gnn/requirements-pre.in benchmarks/geo_gnn/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llama/requirements.in benchmarks/llava/requirements.in benchmarks/llm/requirements.in benchmarks/purejaxrl/requirements.in benchmarks/recursiongfn/requirements.in benchmarks/rlhf/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in benchmarks/vjepa/requirements.in constraints/extra/torch.hpu.txt
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 absl-py==2.1.0
     # via
     #   brax
     #   chex
+    #   distrax
     #   dm-env
     #   ml-collections
     #   mujoco
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
+    #   rlax
     #   tensorboard
-accelerate==0.32.1
-    # via -r benchmarks/accelerate_opt/requirements.in
-aiohttp==3.9.5
+    #   tensorflow-probability
+accelerate==0.34.2
+    # via
+    #   -r benchmarks/diffusion/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
+    #   diffusers
+    #   trl
+aiohappyeyeballs==2.4.3
+    # via aiohttp
+aiohttp==3.10.8
     # via
     #   datasets
     #   fsspec
+    #   torch-geometric
 aiosignal==1.3.1
     # via aiohttp
-annotated-types==0.7.0
-    # via pydantic
 antlr4-python3-runtime==4.9.3
     # via omegaconf
+appdirs==1.4.4
+    # via cantilever
+argklass==1.4.4
+    # via
+    #   -r benchmarks/diffusion/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+astroid==3.3.4
+    # via pylint
 asttokens==2.4.1
     # via giving
 async-timeout==4.0.3
     # via aiohttp
-attrs==23.2.0
+attrs==24.2.0
     # via aiohttp
-beautifulsoup4==4.12.3
-    # via gdown
+beartype==0.19.0
+    # via -r benchmarks/vjepa/requirements.in
+black==24.8.0
+    # via navix
 blinker==1.8.2
     # via flask
+blobfile==3.0.0
+    # via
+    #   -r benchmarks/llm/requirements.txt
+    #   torchtune
+blosc2==2.7.1
+    # via tables
+botorch==0.12.0
+    # via -r benchmarks/recursiongfn/requirements.in
+braceexpand==0.1.7
+    # via
+    #   -r benchmarks/vjepa/requirements.in
+    #   webdataset
 brax==0.10.5
-    # via -r benchmarks/brax/requirements.in
-certifi==2024.6.2
-    # via requests
+    # via
+    #   -r benchmarks/brax/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+cantilever==0.1.0
+    # via -r benchmarks/torchatari/requirements.in
+certifi==2024.8.30
+    # via
+    #   requests
+    #   sentry-sdk
 charset-normalizer==3.3.2
     # via requests
-chex==0.1.86
-    # via optax
+chex==0.1.87
+    # via
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   optax
+    #   rlax
 click==8.1.7
-    # via flask
+    # via
+    #   black
+    #   flask
+    #   wandb
 cloudpickle==3.0.0
-    # via gym
-codefind==0.1.6
+    # via
+    #   gym
+    #   gymnasium
+    #   submitit
+    #   tensorflow-probability
+codefind==0.1.7
     # via ptera
 contextlib2==21.6.0
     # via ml-collections
-datasets==2.20.0
-    # via
-    #   -r benchmarks/accelerate_opt/requirements.in
+contourpy==1.3.0
+    # via matplotlib
+cvxopt==1.3.2
+    # via -r benchmarks/recursiongfn/requirements.in
+cycler==0.12.1
+    # via matplotlib
+datasets==3.0.1
+    # via
+    #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/llama/requirements.in
-    #   evaluate
-deepspeed==0.14.4
-    # via -r benchmarks/accelerate_opt/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
+    #   torchtune
+    #   trl
+decorator==5.1.1
+    # via tensorflow-probability
+decord==0.6.0
+    # via -r benchmarks/vjepa/requirements.in
+diffusers[torch]==0.30.3
+    # via -r benchmarks/diffusion/requirements.in
 dill==0.3.8
     # via
     #   datasets
-    #   evaluate
     #   multiprocess
+    #   pylint
+distrax==0.1.5
+    # via
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   rlax
 dm-env==1.6
-    # via brax
+    # via
+    #   brax
+    #   envpool
+    #   rlax
 dm-tree==0.1.8
-    # via dm-env
-docker==7.1.0
-    # via torchx
+    # via
+    #   dm-env
+    #   tensorflow-probability
+docker-pycreds==0.4.0
+    # via wandb
 docstring-parser==0.16
-    # via torchx
-etils[epath,epy]==1.7.0
+    # via tyro
+dotmap==1.3.30
+    # via evosax
+einops==0.8.0
+    # via -r benchmarks/vjepa/requirements.in
+envpool==0.8.4
+    # via -r benchmarks/torchatari/requirements.in
+etils[epath,epy]==1.9.4
     # via
     #   brax
     #   mujoco
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-evaluate==0.4.2
-    # via -r benchmarks/accelerate_opt/requirements.in
-executing==1.2.0
+evosax==0.1.6
+    # via -r benchmarks/purejaxrl/requirements.in
+exceptiongroup==1.2.2
+    # via pytest
+executing==2.1.0
     # via varname
 fairscale==0.4.13
-    # via -r benchmarks/llama/requirements.in
-fbgemm-gpu==0.7.0
-    # via torchrec
-filelock==3.15.4
     # via
+    #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+farama-notifications==0.0.4
+    # via gymnasium
+filelock==3.16.1
+    # via
+    #   blobfile
     #   datasets
-    #   gdown
+    #   diffusers
     #   huggingface-hub
     #   torch
-    #   torchx
     #   transformers
     #   triton
-fire==0.6.0
-    # via -r benchmarks/llama/requirements.in
+fire==0.7.0
+    # via
+    #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+flake8==7.1.1
+    # via navix
+flashbax==0.1.2
+    # via -r benchmarks/purejaxrl/requirements.in
 flask==3.0.3
     # via
     #   brax
     #   flask-cors
-flask-cors==4.0.1
-    # via brax
-flax==0.8.5
+flask-cors==5.0.0
     # via brax
+flax==0.9.0
+    # via
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   navix
+fonttools==4.54.1
+    # via matplotlib
 frozenlist==1.4.1
     # via
     #   aiohttp
     #   aiosignal
-fsspec[http]==2024.5.0
+fsspec[http]==2024.6.1
     # via
     #   datasets
     #   etils
-    #   evaluate
     #   huggingface-hub
+    #   lightning
+    #   pytorch-lightning
     #   torch
-    #   torchx
-future==1.0.0
-    # via -r benchmarks/dlrm/requirements.in
-gdown==5.2.0
-    # via -r benchmarks/stargan/requirements.in
-giving==0.4.2
+    #   torch-geometric
+fvcore==0.1.5.post20221221
+    # via -r benchmarks/dinov2/requirements.in
+gast==0.6.0
+    # via tensorflow-probability
+gitdb==4.0.11
+    # via gitpython
+gitpython==3.1.43
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   wandb
+giving==0.4.3
     # via
     #   ptera
     #   voir
 glfw==2.7.0
     # via mujoco
-graphviz==0.20.3
-    # via torchviz
-grpcio==1.65.1
+gpytorch==1.13
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+grpcio==1.66.2
     # via
     #   brax
     #   tensorboard
 gym==0.26.2
-    # via brax
+    # via
+    #   -r benchmarks/torchatari/requirements.in
+    #   brax
+    #   envpool
+    #   gymnax
 gym-notices==0.0.8
     # via gym
+gymnasium==0.29.1
+    # via
+    #   envpool
+    #   gymnax
+gymnax==0.0.8
+    # via
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/purejaxrl/requirements.in
 hjson==3.1.0
-    # via deepspeed
-huggingface-hub==0.24.0
+    # via argklass
+huggingface-hub==0.25.1
     # via
     #   -r benchmarks/timm/requirements.in
     #   accelerate
     #   datasets
-    #   evaluate
+    #   diffusers
+    #   timm
     #   tokenizers
+    #   torchtune
     #   transformers
-idna==3.7
+humanize==4.10.0
+    # via orbax-checkpoint
+idna==3.10
     # via
     #   requests
     #   yarl
-importlib-metadata==8.0.0
-    # via torchx
-importlib-resources==6.4.0
+importlib-metadata==8.5.0
+    # via diffusers
+importlib-resources==6.4.5
     # via
+    #   argklass
+    #   cantilever
     #   etils
     #   torchcompat
+iniconfig==2.0.0
+    # via pytest
+iopath==0.1.10
+    # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   fvcore
+isort==5.13.2
+    # via pylint
 itsdangerous==2.2.0
     # via flask
-jax[cuda12]==0.4.28
+jax==0.4.33
     # via
     #   -r benchmarks/brax/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
     #   brax
     #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
     #   flax
+    #   gymnax
     #   jaxopt
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-jax-cuda12-pjrt==0.4.28
-    # via jax-cuda12-plugin
-jax-cuda12-plugin==0.4.28
-    # via jax
-jaxlib==0.4.28+cuda12.cudnn89
+    #   rlax
+jaxlib==0.4.33
     # via
     #   brax
     #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
     #   jax
     #   jaxopt
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
+    #   rlax
 jaxopt==0.8.3
     # via brax
+jaxtyping==0.2.19
+    # via
+    #   gpytorch
+    #   linear-operator
 jinja2==3.1.4
     # via
     #   brax
     #   flask
     #   torch
+    #   torch-geometric
 joblib==1.4.2
     # via scikit-learn
-lightning-utilities==0.11.5
-    # via torchmetrics
-markdown==3.6
+kiwisolver==1.4.7
+    # via matplotlib
+lightning==2.4.0
+    # via -r benchmarks/lightning/requirements.in
+lightning-utilities==0.11.7
+    # via
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+linear-operator==0.5.3
+    # via
+    #   botorch
+    #   gpytorch
+lxml==5.3.0
+    # via blobfile
+markdown==3.7
     # via tensorboard
 markdown-it-py==3.0.0
     # via rich
@@ -207,410 +362,634 @@ markupsafe==2.1.5
     # via
     #   jinja2
     #   werkzeug
+matplotlib==3.9.2
+    # via
+    #   evosax
+    #   gymnax
+    #   seaborn
+mccabe==0.7.0
+    # via
+    #   flake8
+    #   pylint
 mdurl==0.1.2
     # via markdown-it-py
 ml-collections==0.1.1
     # via brax
-ml-dtypes==0.4.0
+ml-dtypes==0.5.0
     # via
     #   jax
     #   jaxlib
     #   tensorstore
 mpmath==1.3.0
-    # via sympy
-msgpack==1.0.8
     # via
+    #   botorch
+    #   gpytorch
+    #   linear-operator
+    #   sympy
+msgpack==1.1.0
+    # via
+    #   blosc2
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.0
+mujoco==3.2.3
     # via
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.0
+mujoco-mjx==3.2.3
     # via brax
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   aiohttp
     #   yarl
+multipledispatch==1.0.0
+    # via botorch
 multiprocess==0.70.16
-    # via
-    #   datasets
-    #   evaluate
+    # via datasets
 mypy-extensions==1.0.0
-    # via typing-inspect
+    # via black
+navix==0.7.0
+    # via -r benchmarks/purejaxrl/requirements.in
+ndindex==1.9.2
+    # via blosc2
 nest-asyncio==1.6.0
     # via orbax-checkpoint
 networkx==3.3
-    # via torch
-ninja==1.11.1.1
-    # via deepspeed
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   torch
+numexpr==2.10.1
+    # via
+    #   blosc2
+    #   tables
 numpy==1.26.4
     # via
-    #   -r benchmarks/dlrm/requirements.in
-    #   -r benchmarks/stargan/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
+    #   blosc2
     #   brax
     #   chex
+    #   contourpy
     #   datasets
-    #   deepspeed
+    #   decord
+    #   diffusers
+    #   distrax
     #   dm-env
-    #   evaluate
+    #   envpool
+    #   evosax
     #   fairscale
-    #   fbgemm-gpu
-    #   flax
+    #   flashbax
+    #   fvcore
     #   gym
+    #   gymnasium
     #   jax
     #   jaxlib
     #   jaxopt
+    #   jaxtyping
+    #   matplotlib
     #   ml-dtypes
     #   mujoco
-    #   onnx
+    #   navix
+    #   numexpr
     #   opencv-python
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   pandas
     #   pyarrow
+    #   pyro-ppl
+    #   rdkit
+    #   rlax
     #   scikit-learn
     #   scipy
+    #   seaborn
+    #   tables
     #   tensorboard
     #   tensorboardx
+    #   tensorflow-probability
     #   tensorstore
+    #   torch-geometric
     #   torchmetrics
+    #   torchtune
     #   torchvision
     #   transformers
     #   trimesh
+    #   trl
+    #   webdataset
+    #   xformers
 nvidia-cublas-cu12==12.1.3.1
     # via
-    #   jax
     #   nvidia-cudnn-cu12
     #   nvidia-cusolver-cu12
     #   torch
 nvidia-cuda-cupti-cu12==12.1.105
-    # via
-    #   jax
-    #   torch
-nvidia-cuda-nvcc-cu12==12.5.82
-    # via
-    #   jax
-    #   jax-cuda12-plugin
+    # via torch
 nvidia-cuda-nvrtc-cu12==12.1.105
     # via torch
 nvidia-cuda-runtime-cu12==12.1.105
-    # via
-    #   jax
-    #   torch
-nvidia-cudnn-cu12==8.9.2.26
-    # via
-    #   jax
-    #   torch
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
 nvidia-cufft-cu12==11.0.2.54
-    # via
-    #   jax
-    #   torch
+    # via torch
 nvidia-curand-cu12==10.3.2.106
     # via torch
 nvidia-cusolver-cu12==11.4.5.107
-    # via
-    #   jax
-    #   torch
+    # via torch
 nvidia-cusparse-cu12==12.1.0.106
     # via
-    #   jax
     #   nvidia-cusolver-cu12
     #   torch
-nvidia-ml-py==12.555.43
-    # via deepspeed
+nvidia-ml-py==12.560.30
+    # via voir
 nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.6.77
     # via
-    #   jax
-    #   torch
-nvidia-nvjitlink-cu12==12.5.82
-    # via
-    #   jax
     #   nvidia-cusolver-cu12
     #   nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
     # via torch
 omegaconf==2.3.0
-    # via voir
-onnx==1.16.1
-    # via -r benchmarks/dlrm/requirements.in
+    # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   torchtune
+    #   voir
 opencv-python==4.10.0.84
-    # via -r benchmarks/super-slomo/requirements.in
-opt-einsum==3.3.0
-    # via jax
+    # via -r benchmarks/vjepa/requirements.in
+opt-einsum==3.4.0
+    # via
+    #   jax
+    #   pyro-ppl
 optax==0.2.3
     # via
+    #   -r benchmarks/purejaxrl/requirements.in
     #   brax
     #   flax
-orbax-checkpoint==0.5.21
+optree==0.13.0
+    # via envpool
+orbax-checkpoint==0.6.4
     # via
     #   brax
     #   flax
-ovld==0.3.5
+ovld==0.3.9
     # via voir
 packaging==24.1
     # via
     #   accelerate
+    #   black
     #   datasets
-    #   deepspeed
-    #   evaluate
+    #   envpool
     #   huggingface-hub
+    #   lightning
     #   lightning-utilities
+    #   matplotlib
+    #   pytest
+    #   pytorch-lightning
+    #   setuptools-scm
+    #   tables
+    #   tensorboard
     #   tensorboardx
     #   torchmetrics
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   datasets
-    #   evaluate
+    #   seaborn
+pathspec==0.12.1
+    # via black
 pillow==10.4.0
     # via
+    #   -r benchmarks/huggingface/requirements.in
+    #   -r benchmarks/llava/requirements.in
     #   brax
+    #   diffusers
+    #   fvcore
+    #   matplotlib
+    #   navix
+    #   rdkit
     #   torchvision
-protobuf==4.25.3
+platformdirs==4.3.6
+    # via
+    #   black
+    #   pylint
+    #   wandb
+pluggy==1.5.0
+    # via pytest
+portalocker==2.10.1
+    # via iopath
+protobuf==5.28.2
     # via
-    #   onnx
     #   orbax-checkpoint
     #   tensorboard
     #   tensorboardx
+    #   wandb
 psutil==5.9.8
     # via
     #   accelerate
-    #   deepspeed
+    #   torch-geometric
     #   voir
+    #   wandb
 ptera==1.4.1
     # via voir
 py-cpuinfo==9.0.0
-    # via deepspeed
+    # via
+    #   blosc2
+    #   tables
 pyarrow==17.0.0
-    # via datasets
-pyarrow-hotfix==0.6
-    # via datasets
-pydantic==2.7.4
-    # via deepspeed
-pydantic-core==2.18.4
-    # via pydantic
-pydot==3.0.1
-    # via -r benchmarks/dlrm/requirements.in
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   datasets
+pycodestyle==2.12.1
+    # via flake8
+pycryptodomex==3.21.0
+    # via blobfile
+pyflakes==3.2.0
+    # via flake8
 pygments==2.18.0
     # via rich
-pynvml==11.5.3
-    # via voir
+pylint==3.3.1
+    # via navix
 pyopengl==3.1.7
     # via mujoco
-pyparsing==3.1.2
-    # via pydot
-pyre-extensions==0.0.30
-    # via torchx
-pysocks==1.7.1
-    # via requests
+pyparsing==3.1.4
+    # via
+    #   matplotlib
+    #   torch-geometric
+pyro-api==0.1.2
+    # via pyro-ppl
+pyro-ppl==1.9.1
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+pytest==8.3.3
+    # via navix
 python-dateutil==2.9.0.post0
-    # via pandas
+    # via
+    #   matplotlib
+    #   pandas
 pytinyrenderer==0.0.14
     # via brax
-pytz==2024.1
+pytorch-lightning==2.4.0
+    # via lightning
+pytz==2024.2
     # via pandas
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
+    #   -r benchmarks/llm/requirements.in
     #   -r benchmarks/timm/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
     #   datasets
+    #   evosax
     #   flax
+    #   fvcore
+    #   gymnax
     #   huggingface-hub
+    #   lightning
     #   ml-collections
     #   omegaconf
     #   orbax-checkpoint
-    #   torchx
+    #   pytorch-lightning
+    #   timm
     #   transformers
+    #   wandb
+    #   webdataset
+    #   yacs
+rdkit==2024.3.5
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
 reactivex==4.0.4
     # via giving
-regex==2024.5.15
-    # via transformers
-requests[socks]==2.32.3
+regex==2024.9.11
+    # via
+    #   diffusers
+    #   tiktoken
+    #   transformers
+requests==2.32.3
     # via
     #   datasets
-    #   docker
-    #   evaluate
-    #   gdown
+    #   diffusers
     #   huggingface-hub
+    #   tiktoken
+    #   torch-geometric
     #   transformers
-rich==13.7.1
+    #   wandb
+rich==13.9.1
     # via
-    #   -r benchmarks/accelerate_opt/requirements.in
     #   flax
+    #   tyro
     #   voir
-safetensors==0.4.3
+rlax==0.1.6
+    # via navix
+safetensors==0.4.5
     # via
     #   -r benchmarks/timm/requirements.in
     #   accelerate
+    #   diffusers
+    #   timm
+    #   torchtune
     #   transformers
-scikit-learn==1.5.1
-    # via -r benchmarks/dlrm/requirements.in
-scipy==1.14.0
+scikit-learn==1.5.2
+    # via gpytorch
+scipy==1.14.1
     # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
     #   brax
+    #   gpytorch
     #   jax
     #   jaxlib
     #   jaxopt
+    #   linear-operator
     #   mujoco-mjx
     #   scikit-learn
+    #   torch-cluster
+    #   torch-sparse
+seaborn==0.13.2
+    # via gymnax
 sentencepiece==0.2.0
-    # via -r benchmarks/llama/requirements.in
+    # via
+    #   -r benchmarks/llama/requirements.in
+    #   torchtune
+sentry-sdk==2.15.0
+    # via wandb
+setproctitle==1.3.3
+    # via wandb
+setuptools-scm==8.1.0
+    # via navix
+shtab==1.7.1
+    # via tyro
 six==1.16.0
     # via
     #   asttokens
-    #   fire
+    #   docker-pycreds
     #   ml-collections
     #   python-dateutil
     #   tensorboard
-soupsieve==2.5
-    # via beautifulsoup4
-sympy==1.13.0
+    #   tensorflow-probability
+smmap==5.0.1
+    # via gitdb
+submitit==1.5.2
+    # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
+sympy==1.13.3
     # via torch
+tables==3.10.1
+    # via -r benchmarks/recursiongfn/requirements.in
 tabulate==0.9.0
-    # via torchx
-tensorboard==2.17.0
-    # via -r benchmarks/dlrm/requirements.in
+    # via fvcore
+tensorboard==2.18.0
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
 tensorboard-data-server==0.7.2
     # via tensorboard
 tensorboardx==2.6.2.2
     # via brax
-tensorstore==0.1.63
+tensorflow-probability==0.24.0
+    # via distrax
+tensorstore==0.1.66
     # via
+    #   flashbax
     #   flax
     #   orbax-checkpoint
 termcolor==2.4.0
-    # via fire
+    # via
+    #   fire
+    #   fvcore
 threadpoolctl==3.5.0
     # via scikit-learn
+tiktoken==0.7.0
+    # via torchtune
+timm==1.0.9
+    # via -r benchmarks/vjepa/requirements.in
 tokenizers==0.19.1
     # via transformers
+tomli==2.0.2
+    # via
+    #   black
+    #   pylint
+    #   pytest
+    #   setuptools-scm
+tomlkit==0.13.2
+    # via pylint
 toolz==0.12.1
     # via chex
-torch==2.3.1
+torch==2.4.1
     # via
-    #   -r benchmarks/accelerate_opt/requirements.in
     #   -r benchmarks/brax/requirements.in
-    #   -r benchmarks/dlrm/requirements.in
+    #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
+    #   -r benchmarks/geo_gnn/requirements-pre.in
     #   -r benchmarks/huggingface/requirements.in
+    #   -r benchmarks/lightning/requirements.in
     #   -r benchmarks/llama/requirements.in
-    #   -r benchmarks/stargan/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   -r benchmarks/timm/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
-    #   deepspeed
+    #   botorch
+    #   diffusers
     #   fairscale
-    #   torchaudio
+    #   lightning
+    #   linear-operator
+    #   pyro-ppl
+    #   pytorch-lightning
+    #   timm
     #   torchmetrics
     #   torchvision
-    #   torchviz
-torchaudio==2.3.1
-    # via -r benchmarks/accelerate_opt/requirements.in
+    #   trl
+    #   xformers
+torch-cluster==1.6.3
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-geometric==2.6.1
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-scatter==2.1.2
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-sparse==0.6.18
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torchao==0.3.1
+    # via
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+    #   torchtune
 torchcompat==1.1.4
     # via
     #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/flops/requirements.in
+    #   -r benchmarks/lightning/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
-torchmetrics==1.0.3
-    # via torchrec
-torchrec==0.7.0
-    # via -r benchmarks/dlrm/requirements.in
-torchvision==0.18.1
+torchmetrics==1.4.2
+    # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   lightning
+    #   pytorch-lightning
+torchtune==0.2.1
+    # via
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+torchvision==0.19.1
     # via
-    #   -r benchmarks/accelerate_opt/requirements.in
+    #   -r benchmarks/diffusion/requirements.in
+    #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
-    #   -r benchmarks/stargan/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/lightning/requirements.in
     #   -r benchmarks/timm/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
-torchviz==0.0.2
-    # via -r benchmarks/dlrm/requirements.in
-torchx==0.7.0
-    # via -r benchmarks/dlrm/requirements.in
-tqdm==4.66.4
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
+tqdm==4.66.5
     # via
-    #   -r benchmarks/dlrm/requirements.in
+    #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/flops/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
     #   datasets
-    #   deepspeed
-    #   evaluate
-    #   gdown
+    #   fvcore
     #   huggingface-hub
-    #   torchrec
+    #   iopath
+    #   lightning
+    #   pyro-ppl
+    #   pytorch-lightning
+    #   torch-geometric
+    #   torchtune
     #   transformers
-transformers==4.42.4
+transformers==4.44.2
     # via
-    #   -r benchmarks/accelerate_opt/requirements.in
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/huggingface/requirements.in
     #   -r benchmarks/llama/requirements.in
-trimesh==4.4.3
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+trimesh==4.4.9
     # via
     #   brax
     #   mujoco-mjx
-triton==2.3.1
+triton==3.0.0
     # via torch
+trl==0.10.1
+    # via
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/rlhf/requirements.in
+typeguard==4.3.0
+    # via jaxtyping
+types-protobuf==5.28.0.20240924
+    # via envpool
 typing-extensions==4.12.2
     # via
+    #   astroid
+    #   black
+    #   botorch
     #   brax
     #   chex
+    #   envpool
     #   etils
+    #   flashbax
     #   flax
+    #   gymnasium
     #   huggingface-hub
+    #   iopath
+    #   jaxtyping
+    #   lightning
     #   lightning-utilities
+    #   multidict
+    #   navix
+    #   optree
     #   orbax-checkpoint
-    #   pydantic
-    #   pydantic-core
-    #   pyre-extensions
+    #   pytorch-lightning
     #   reactivex
+    #   rich
+    #   submitit
+    #   tables
     #   torch
-    #   typing-inspect
-typing-inspect==0.9.0
-    # via pyre-extensions
-tzdata==2024.1
+    #   typeguard
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -r benchmarks/torchatari/requirements.in
+    #   navix
+    #   trl
+tzdata==2024.2
     # via pandas
-urllib3==1.26.19
+urllib3==2.2.3
     # via
-    #   docker
+    #   blobfile
     #   requests
-    #   torchx
-varname==0.10.0
+    #   sentry-sdk
+varname==0.13.3
     # via giving
 voir==0.2.19
     # via
     #   -c .pin/../constraints/hpu.txt
-    #   -r benchmarks/accelerate_opt/requirements.in
     #   -r benchmarks/brax/requirements.in
-    #   -r benchmarks/dlrm/requirements.in
+    #   -r benchmarks/diffusion/requirements.in
+    #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
+    #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/huggingface/requirements.in
+    #   -r benchmarks/lightning/requirements.in
     #   -r benchmarks/llama/requirements.in
-    #   -r benchmarks/stargan/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   -r benchmarks/timm/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
-werkzeug==3.0.3
+    #   -r benchmarks/vjepa/requirements.in
+wandb==0.18.3
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   navix
+webdataset==0.2.100
+    # via -r benchmarks/vjepa/requirements.in
+werkzeug==3.0.4
     # via
     #   flask
     #   tensorboard
-xxhash==3.4.1
-    # via
-    #   datasets
-    #   evaluate
-yarl==1.9.4
+xformers==0.0.28.post1
+    # via -r benchmarks/dinov2/requirements.in
+xxhash==3.5.0
+    # via datasets
+yacs==0.1.8
+    # via fvcore
+yarl==1.13.1
     # via aiohttp
-zipp==3.19.2
+zipp==3.20.2
     # via
     #   etils
     #   importlib-metadata
diff --git a/.pin/constraints-rocm-torch.txt b/.pin/constraints-rocm-torch.txt
index 4fe6ae9da..ecc49d511 100644
--- a/.pin/constraints-rocm-torch.txt
+++ b/.pin/constraints-rocm-torch.txt
@@ -2,31 +2,39 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=.pin/constraints-rocm-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llama/requirements.in benchmarks/llm/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in
+#    pip-compile --output-file=.pin/constraints-rocm-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/geo_gnn/requirements-pre.in benchmarks/geo_gnn/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llama/requirements.in benchmarks/llava/requirements.in benchmarks/llm/requirements.in benchmarks/purejaxrl/requirements.in benchmarks/recursiongfn/requirements.in benchmarks/rlhf/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in benchmarks/vjepa/requirements.in constraints/extra/torch.rocm.txt
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 absl-py==2.1.0
     # via
     #   brax
     #   chex
+    #   distrax
     #   dm-env
     #   ml-collections
     #   mujoco
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
+    #   rlax
     #   tensorboard
-accelerate==0.33.0
+    #   tensorflow-probability
+accelerate==0.34.2
     # via
     #   -r benchmarks/diffusion/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   diffusers
-aiohappyeyeballs==2.4.0
+    #   trl
+aiohappyeyeballs==2.4.3
     # via aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   datasets
     #   fsspec
+    #   torch-geometric
 aiosignal==1.3.1
     # via aiohttp
 antlr4-python3-runtime==4.9.3
@@ -37,72 +45,137 @@ argklass==1.4.4
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+astroid==3.3.4
+    # via pylint
 asttokens==2.4.1
     # via giving
 async-timeout==4.0.3
     # via aiohttp
 attrs==24.2.0
     # via aiohttp
+beartype==0.19.0
+    # via -r benchmarks/vjepa/requirements.in
+black==24.8.0
+    # via navix
 blinker==1.8.2
     # via flask
-blobfile==2.1.1
-    # via torchtune
+blobfile==3.0.0
+    # via
+    #   -r benchmarks/llm/requirements.txt
+    #   torchtune
+blosc2==2.7.1
+    # via tables
+botorch==0.12.0
+    # via -r benchmarks/recursiongfn/requirements.in
+braceexpand==0.1.7
+    # via
+    #   -r benchmarks/vjepa/requirements.in
+    #   webdataset
 brax==0.10.5
-    # via -r benchmarks/brax/requirements.in
+    # via
+    #   -r benchmarks/brax/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
 cantilever==0.1.0
     # via -r benchmarks/torchatari/requirements.in
-certifi==2024.7.4
-    # via requests
+certifi==2024.8.30
+    # via
+    #   requests
+    #   sentry-sdk
 charset-normalizer==3.3.2
     # via requests
-chex==0.1.86
-    # via optax
+chex==0.1.87
+    # via
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   optax
+    #   rlax
 click==8.1.7
-    # via flask
+    # via
+    #   black
+    #   flask
+    #   wandb
 cloudpickle==3.0.0
     # via
     #   gym
     #   gymnasium
     #   submitit
-codefind==0.1.6
+    #   tensorflow-probability
+codefind==0.1.7
     # via ptera
 contextlib2==21.6.0
     # via ml-collections
-datasets==2.21.0
+contourpy==1.3.0
+    # via matplotlib
+cvxopt==1.3.2
+    # via -r benchmarks/recursiongfn/requirements.in
+cycler==0.12.1
+    # via matplotlib
+datasets==3.0.1
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   torchtune
-diffusers[torch]==0.30.0
+    #   trl
+decorator==5.1.1
+    # via tensorflow-probability
+decord==0.6.0
+    # via -r benchmarks/vjepa/requirements.in
+diffusers[torch]==0.30.3
     # via -r benchmarks/diffusion/requirements.in
 dill==0.3.8
     # via
     #   datasets
     #   multiprocess
+    #   pylint
+distrax==0.1.5
+    # via
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   rlax
 dm-env==1.6
     # via
     #   brax
     #   envpool
+    #   rlax
 dm-tree==0.1.8
-    # via dm-env
+    # via
+    #   dm-env
+    #   tensorflow-probability
+docker-pycreds==0.4.0
+    # via wandb
 docstring-parser==0.16
     # via tyro
+dotmap==1.3.30
+    # via evosax
+einops==0.8.0
+    # via -r benchmarks/vjepa/requirements.in
 envpool==0.8.4
     # via -r benchmarks/torchatari/requirements.in
-etils[epath,epy]==1.7.0
+etils[epath,epy]==1.9.4
     # via
     #   brax
     #   mujoco
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-executing==1.2.0
+evosax==0.1.6
+    # via -r benchmarks/purejaxrl/requirements.in
+exceptiongroup==1.2.2
+    # via pytest
+executing==2.1.0
     # via varname
 fairscale==0.4.13
-    # via -r benchmarks/llama/requirements.in
+    # via
+    #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
 farama-notifications==0.0.4
     # via gymnasium
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   blobfile
     #   datasets
@@ -111,16 +184,30 @@ filelock==3.15.4
     #   pytorch-triton-rocm
     #   torch
     #   transformers
-fire==0.6.0
-    # via -r benchmarks/llama/requirements.in
+fire==0.7.0
+    # via
+    #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+flake8==7.1.1
+    # via navix
+flashbax==0.1.2
+    # via -r benchmarks/purejaxrl/requirements.in
 flask==3.0.3
     # via
     #   brax
     #   flask-cors
-flask-cors==4.0.1
-    # via brax
-flax==0.8.5
+flask-cors==5.0.0
     # via brax
+flax==0.9.0
+    # via
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   navix
+fonttools==4.54.1
+    # via matplotlib
 frozenlist==1.4.1
     # via
     #   aiohttp
@@ -133,92 +220,141 @@ fsspec[http]==2024.6.1
     #   lightning
     #   pytorch-lightning
     #   torch
+    #   torch-geometric
 fvcore==0.1.5.post20221221
     # via -r benchmarks/dinov2/requirements.in
-giving==0.4.2
+gast==0.6.0
+    # via tensorflow-probability
+gitdb==4.0.11
+    # via gitpython
+gitpython==3.1.43
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   wandb
+giving==0.4.3
     # via
     #   ptera
     #   voir
 glfw==2.7.0
     # via mujoco
-grpcio==1.65.5
+gpytorch==1.13
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+grpcio==1.66.2
     # via
     #   brax
     #   tensorboard
-gym==0.23.1
+gym==0.26.2
     # via
     #   -r benchmarks/torchatari/requirements.in
     #   brax
     #   envpool
+    #   gymnax
 gym-notices==0.0.8
     # via gym
 gymnasium==0.29.1
-    # via envpool
+    # via
+    #   envpool
+    #   gymnax
+gymnax==0.0.8
+    # via
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/purejaxrl/requirements.in
 hjson==3.1.0
     # via argklass
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -r benchmarks/timm/requirements.in
     #   accelerate
     #   datasets
     #   diffusers
+    #   timm
     #   tokenizers
     #   torchtune
     #   transformers
 humanize==4.10.0
     # via orbax-checkpoint
-idna==3.7
+idna==3.10
     # via
     #   requests
     #   yarl
-importlib-metadata==8.4.0
+importlib-metadata==8.5.0
     # via diffusers
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   argklass
     #   cantilever
     #   etils
     #   torchcompat
+iniconfig==2.0.0
+    # via pytest
 iopath==0.1.10
     # via
     #   -r benchmarks/dinov2/requirements.in
     #   fvcore
+isort==5.13.2
+    # via pylint
 itsdangerous==2.2.0
     # via flask
-jax==0.4.31
+jax==0.4.33
     # via
     #   -r benchmarks/brax/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
     #   brax
     #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
     #   flax
+    #   gymnax
     #   jaxopt
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-jaxlib==0.4.31
+    #   rlax
+jaxlib==0.4.33
     # via
     #   brax
     #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
     #   jax
     #   jaxopt
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
+    #   rlax
 jaxopt==0.8.3
     # via brax
+jaxtyping==0.2.19
+    # via
+    #   gpytorch
+    #   linear-operator
 jinja2==3.1.4
     # via
     #   brax
     #   flask
     #   torch
+    #   torch-geometric
+joblib==1.4.2
+    # via scikit-learn
+kiwisolver==1.4.7
+    # via matplotlib
 lightning==2.4.0
     # via -r benchmarks/lightning/requirements.in
-lightning-utilities==0.11.6
+lightning-utilities==0.11.7
     # via
     #   lightning
     #   pytorch-lightning
     #   torchmetrics
-lxml==4.9.4
+linear-operator==0.5.3
+    # via
+    #   botorch
+    #   gpytorch
+lxml==5.3.0
     # via blobfile
 markdown==3.7
     # via tensorboard
@@ -228,169 +364,284 @@ markupsafe==2.1.5
     # via
     #   jinja2
     #   werkzeug
+matplotlib==3.9.2
+    # via
+    #   evosax
+    #   gymnax
+    #   seaborn
+mccabe==0.7.0
+    # via
+    #   flake8
+    #   pylint
 mdurl==0.1.2
     # via markdown-it-py
 ml-collections==0.1.1
     # via brax
-ml-dtypes==0.4.0
+ml-dtypes==0.5.0
     # via
     #   jax
     #   jaxlib
     #   tensorstore
 mpmath==1.3.0
-    # via sympy
-msgpack==1.0.8
     # via
+    #   botorch
+    #   gpytorch
+    #   linear-operator
+    #   sympy
+msgpack==1.1.0
+    # via
+    #   blosc2
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.2
+mujoco==3.2.3
     # via
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.2
+mujoco-mjx==3.2.3
     # via brax
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   aiohttp
     #   yarl
+multipledispatch==1.0.0
+    # via botorch
 multiprocess==0.70.16
     # via datasets
+mypy-extensions==1.0.0
+    # via black
+navix==0.7.0
+    # via -r benchmarks/purejaxrl/requirements.in
+ndindex==1.9.2
+    # via blosc2
 nest-asyncio==1.6.0
     # via orbax-checkpoint
 networkx==3.3
-    # via torch
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   torch
+numexpr==2.10.1
+    # via
+    #   blosc2
+    #   tables
 numpy==1.26.4
     # via
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
     #   -r benchmarks/torchatari/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
+    #   blosc2
     #   brax
     #   chex
+    #   contourpy
     #   datasets
+    #   decord
     #   diffusers
+    #   distrax
     #   dm-env
     #   envpool
+    #   evosax
     #   fairscale
-    #   flax
+    #   flashbax
     #   fvcore
     #   gym
     #   gymnasium
     #   jax
     #   jaxlib
     #   jaxopt
+    #   jaxtyping
+    #   matplotlib
     #   ml-dtypes
     #   mujoco
+    #   navix
+    #   numexpr
     #   opencv-python
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   pandas
     #   pyarrow
+    #   pyro-ppl
+    #   rdkit
+    #   rlax
+    #   scikit-learn
     #   scipy
+    #   seaborn
+    #   tables
     #   tensorboard
     #   tensorboardx
+    #   tensorflow-probability
     #   tensorstore
+    #   torch-geometric
     #   torchmetrics
     #   torchtune
     #   torchvision
     #   transformers
     #   trimesh
+    #   trl
+    #   webdataset
     #   xformers
+nvidia-ml-py==12.560.30
+    # via voir
 omegaconf==2.3.0
     # via
     #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
     #   torchtune
     #   voir
 opencv-python==4.10.0.84
-    # via -r benchmarks/super-slomo/requirements.in
-opt-einsum==3.3.0
-    # via jax
+    # via -r benchmarks/vjepa/requirements.in
+opt-einsum==3.4.0
+    # via
+    #   jax
+    #   pyro-ppl
 optax==0.2.3
     # via
+    #   -r benchmarks/purejaxrl/requirements.in
     #   brax
     #   flax
-optree==0.12.1
+optree==0.13.0
     # via envpool
-orbax-checkpoint==0.6.0
+orbax-checkpoint==0.6.4
     # via
     #   brax
     #   flax
-ovld==0.3.8
+ovld==0.3.9
     # via voir
 packaging==24.1
     # via
     #   accelerate
+    #   black
     #   datasets
     #   envpool
     #   huggingface-hub
     #   lightning
     #   lightning-utilities
+    #   matplotlib
+    #   pytest
     #   pytorch-lightning
+    #   setuptools-scm
+    #   tables
     #   tensorboard
     #   tensorboardx
     #   torchmetrics
     #   transformers
-pandas==2.2.2
-    # via datasets
+pandas==2.2.3
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
+    #   datasets
+    #   seaborn
+pathspec==0.12.1
+    # via black
 pillow==10.4.0
     # via
     #   -r benchmarks/huggingface/requirements.in
+    #   -r benchmarks/llava/requirements.in
     #   brax
     #   diffusers
     #   fvcore
+    #   matplotlib
+    #   navix
+    #   rdkit
     #   torchvision
+platformdirs==4.3.6
+    # via
+    #   black
+    #   pylint
+    #   wandb
+pluggy==1.5.0
+    # via pytest
 portalocker==2.10.1
     # via iopath
-protobuf==5.27.3
+protobuf==5.28.2
     # via
     #   orbax-checkpoint
     #   tensorboard
     #   tensorboardx
+    #   wandb
 psutil==5.9.8
     # via
     #   accelerate
+    #   torch-geometric
     #   voir
+    #   wandb
 ptera==1.4.1
     # via voir
+py-cpuinfo==9.0.0
+    # via
+    #   blosc2
+    #   tables
 pyarrow==17.0.0
-    # via datasets
-pycryptodomex==3.20.0
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   datasets
+pycodestyle==2.12.1
+    # via flake8
+pycryptodomex==3.21.0
     # via blobfile
+pyflakes==3.2.0
+    # via flake8
 pygments==2.18.0
     # via rich
-pynvml==11.5.3
-    # via voir
+pylint==3.3.1
+    # via navix
 pyopengl==3.1.7
     # via mujoco
+pyparsing==3.1.4
+    # via
+    #   matplotlib
+    #   torch-geometric
+pyro-api==0.1.2
+    # via pyro-ppl
+pyro-ppl==1.9.1
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+pytest==8.3.3
+    # via navix
 python-dateutil==2.9.0.post0
-    # via pandas
+    # via
+    #   matplotlib
+    #   pandas
 pytinyrenderer==0.0.14
     # via brax
 pytorch-lightning==2.4.0
     # via lightning
 pytorch-triton-rocm==3.0.0
     # via torch
-pytz==2024.1
+pytz==2024.2
     # via pandas
 pyyaml==6.0.2
     # via
     #   -r benchmarks/llm/requirements.in
     #   -r benchmarks/timm/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
     #   datasets
+    #   evosax
     #   flax
     #   fvcore
+    #   gymnax
     #   huggingface-hub
     #   lightning
     #   ml-collections
     #   omegaconf
     #   orbax-checkpoint
     #   pytorch-lightning
+    #   timm
     #   transformers
+    #   wandb
+    #   webdataset
     #   yacs
+rdkit==2024.3.5
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
 reactivex==4.0.4
     # via giving
-regex==2024.7.24
+regex==2024.9.11
     # via
     #   diffusers
     #   tiktoken
@@ -401,90 +652,166 @@ requests==2.32.3
     #   diffusers
     #   huggingface-hub
     #   tiktoken
+    #   torch-geometric
     #   transformers
-rich==13.7.1
+    #   wandb
+rich==13.9.1
     # via
     #   flax
     #   tyro
     #   voir
-safetensors==0.4.4
+rlax==0.1.6
+    # via navix
+safetensors==0.4.5
     # via
     #   -r benchmarks/timm/requirements.in
     #   accelerate
     #   diffusers
+    #   timm
     #   torchtune
     #   transformers
-scipy==1.14.0
+scikit-learn==1.5.2
+    # via gpytorch
+scipy==1.14.1
     # via
     #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
     #   brax
+    #   gpytorch
     #   jax
     #   jaxlib
     #   jaxopt
+    #   linear-operator
     #   mujoco-mjx
+    #   scikit-learn
+    #   torch-cluster
+    #   torch-sparse
+seaborn==0.13.2
+    # via gymnax
 sentencepiece==0.2.0
     # via
     #   -r benchmarks/llama/requirements.in
     #   torchtune
+sentry-sdk==2.15.0
+    # via wandb
+setproctitle==1.3.3
+    # via wandb
+setuptools-scm==8.1.0
+    # via navix
 shtab==1.7.1
     # via tyro
 six==1.16.0
     # via
     #   asttokens
-    #   fire
+    #   docker-pycreds
     #   ml-collections
     #   python-dateutil
     #   tensorboard
-submitit==1.5.1
-    # via -r benchmarks/dinov2/requirements.in
-sympy==1.13.2
+    #   tensorflow-probability
+smmap==5.0.1
+    # via gitdb
+submitit==1.5.2
+    # via
+    #   -r benchmarks/dinov2/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
+sympy==1.13.3
     # via torch
+tables==3.10.1
+    # via -r benchmarks/recursiongfn/requirements.in
 tabulate==0.9.0
     # via fvcore
-tensorboard==2.17.1
-    # via -r benchmarks/torchatari/requirements.in
+tensorboard==2.18.0
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/torchatari/requirements.in
 tensorboard-data-server==0.7.2
     # via tensorboard
 tensorboardx==2.6.2.2
     # via brax
-tensorstore==0.1.64
+tensorflow-probability==0.24.0
+    # via distrax
+tensorstore==0.1.66
     # via
+    #   flashbax
     #   flax
     #   orbax-checkpoint
 termcolor==2.4.0
     # via
     #   fire
     #   fvcore
+threadpoolctl==3.5.0
+    # via scikit-learn
 tiktoken==0.7.0
     # via torchtune
+timm==1.0.9
+    # via -r benchmarks/vjepa/requirements.in
 tokenizers==0.19.1
     # via transformers
+tomli==2.0.2
+    # via
+    #   black
+    #   pylint
+    #   pytest
+    #   setuptools-scm
+tomlkit==0.13.2
+    # via pylint
 toolz==0.12.1
     # via chex
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -r benchmarks/brax/requirements.in
     #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
+    #   -r benchmarks/geo_gnn/requirements-pre.in
     #   -r benchmarks/huggingface/requirements.in
     #   -r benchmarks/lightning/requirements.in
     #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llava/requirements.in
     #   -r benchmarks/llm/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   -r benchmarks/timm/requirements.in
     #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
     #   accelerate
+    #   botorch
     #   diffusers
     #   fairscale
     #   lightning
+    #   linear-operator
+    #   pyro-ppl
     #   pytorch-lightning
+    #   timm
     #   torchmetrics
     #   torchvision
+    #   trl
     #   xformers
+torch-cluster==1.6.3
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-geometric==2.6.1
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-scatter==2.1.2
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-sparse==0.6.18
+    # via
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
 torchao==0.3.1
-    # via torchtune
+    # via
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llm/requirements.in
+    #   torchtune
 torchcompat==1.1.4
     # via
     #   -c .pin/../constraints/rocm.txt
@@ -493,28 +820,30 @@ torchcompat==1.1.4
     #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
-torchmetrics==1.4.1
+torchmetrics==1.4.2
     # via
     #   -r benchmarks/dinov2/requirements.in
     #   lightning
     #   pytorch-lightning
 torchtune==0.2.1
-    # via -r benchmarks/llm/requirements.in
-torchvision==0.19.0+rocm6.0
+    # via
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llm/requirements.in
+torchvision==0.19.1+rocm6.1
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
     #   -r benchmarks/lightning/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
     #   -r benchmarks/timm/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
 tqdm==4.66.5
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/flops/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
     #   datasets
@@ -522,48 +851,75 @@ tqdm==4.66.5
     #   huggingface-hub
     #   iopath
     #   lightning
+    #   pyro-ppl
     #   pytorch-lightning
+    #   torch-geometric
     #   torchtune
     #   transformers
-transformers==4.44.1
+transformers==4.44.2
     # via
+    #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/huggingface/requirements.in
     #   -r benchmarks/llama/requirements.in
-trimesh==4.4.7
+    #   -r benchmarks/llava/requirements.in
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+trimesh==4.4.9
     # via
     #   brax
     #   mujoco-mjx
-types-protobuf==5.27.0.20240626
+trl==0.10.1
+    # via
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/rlhf/requirements.in
+typeguard==4.3.0
+    # via jaxtyping
+types-protobuf==5.28.0.20240924
     # via envpool
 typing-extensions==4.12.2
     # via
+    #   astroid
+    #   black
+    #   botorch
     #   brax
     #   chex
     #   envpool
     #   etils
+    #   flashbax
     #   flax
     #   gymnasium
     #   huggingface-hub
     #   iopath
+    #   jaxtyping
     #   lightning
     #   lightning-utilities
+    #   multidict
+    #   navix
     #   optree
     #   orbax-checkpoint
     #   pytorch-lightning
     #   reactivex
+    #   rich
     #   submitit
+    #   tables
     #   torch
+    #   typeguard
     #   tyro
-tyro==0.8.8
-    # via -r benchmarks/torchatari/requirements.in
-tzdata==2024.1
+tyro==0.8.11
+    # via
+    #   -r benchmarks/torchatari/requirements.in
+    #   navix
+    #   trl
+tzdata==2024.2
     # via pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   blobfile
     #   requests
-varname==0.10.0
+    #   sentry-sdk
+varname==0.13.3
     # via giving
 voir==0.2.19
     # via
@@ -572,28 +928,39 @@ voir==0.2.19
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/dinov2/requirements.in
     #   -r benchmarks/flops/requirements.in
+    #   -r benchmarks/geo_gnn/requirements.in
     #   -r benchmarks/huggingface/requirements.in
     #   -r benchmarks/lightning/requirements.in
     #   -r benchmarks/llama/requirements.in
+    #   -r benchmarks/llava/requirements.in
     #   -r benchmarks/llm/requirements.in
-    #   -r benchmarks/super-slomo/requirements.in
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   -r benchmarks/rlhf/requirements.in
     #   -r benchmarks/timm/requirements.in
     #   -r benchmarks/torchatari/requirements.in
     #   -r benchmarks/torchvision/requirements.in
     #   -r benchmarks/torchvision_ddp/requirements.in
-werkzeug==3.0.3
+    #   -r benchmarks/vjepa/requirements.in
+wandb==0.18.3
+    # via
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   navix
+webdataset==0.2.100
+    # via -r benchmarks/vjepa/requirements.in
+werkzeug==3.0.4
     # via
     #   flask
     #   tensorboard
-xformers==0.0.27.post2
+xformers==0.0.28.post1
     # via -r benchmarks/dinov2/requirements.in
 xxhash==3.5.0
     # via datasets
 yacs==0.1.8
     # via fvcore
-yarl==1.9.4
+yarl==1.13.1
     # via aiohttp
-zipp==3.20.0
+zipp==3.20.2
     # via
     #   etils
     #   importlib-metadata
diff --git a/README.md b/README.md
index 526398938..a2f8ce50b 100644
--- a/README.md
+++ b/README.md
@@ -20,62 +20,23 @@ evaluating current and future hardware in a research environment.
 * Focussed on training
 * Ease of use
 * Pytorch focused
-* ROCm & NVIDIA
+* ROCm, NVIDIA, Intel OneAPI, Habana Gaudi (Synapse)
 * Independent 
 
 ## Getting Started
 
-The easiest way to run milabbench is to run it with one of its docker image.
-It will include all of the necessary data
-
-
-    # Choose the image you want to use
-    export MILABENCH_IMAGE=ghcr.io/mila-iqia/milabench:cuda-nightly
-
-    # Pull the image we are going to run
-    docker pull $MILABENCH_IMAGE
-
-    # Run milabench
-    docker run -it --rm --ipc=host --gpus=all      \
-          -v $(pwd)/results:/milabench/envs/runs   \
-          $MILABENCH_IMAGE                         \
-          bash -c "milabench prepare && milabench run"
-
-    =================
-    Benchmark results
-    =================
-                             fail n       perf   sem%   std% peak_memory          score weight
-    bert-fp16                   0 8     155.08   0.3%   4.3%       24552    1241.260310   0.00
-    bert-fp32                   0 8      29.52   0.0%   0.5%       31524     236.337218   0.00
-    bert-tf32                   0 8     120.46   0.4%   6.1%       31524     964.713297   0.00
-    bert-tf32-fp16              0 8     154.76   0.3%   4.1%       24552    1238.477257   3.00
-    convnext_large-fp16         0 8     337.48   0.9%  14.0%       27658    2741.604444   0.00
-    convnext_large-fp32         0 8      44.61   0.8%  12.6%       49786     354.207225   0.00
-    convnext_large-tf32         0 8     135.99   0.7%  11.2%       49786    1089.394916   0.00
-    convnext_large-tf32-fp16    0 8     338.58   0.8%  13.0%       27658    2744.325170   3.00
-    davit_large                 0 8     312.79   0.3%   6.7%       35058    2515.326450   1.00
-    davit_large-multi           0 1    2401.65   1.0%   7.7%       42232    2401.651720   5.00
-    dlrm                        0 1  188777.20   1.8%  14.0%        3194  188777.203190   1.00
-    focalnet                    0 8     400.47   0.2%   5.4%       26604    3215.431924   2.00
-    opt-1_3b                    0 1      26.71   0.1%   0.4%       44116      26.714365   5.00
-    opt-1_3b-multinode          0 2      34.62   0.2%   1.0%       43552      34.618292  10.00
-    opt-6_7b                    0 1      14.32   0.0%   0.1%       55750      14.319587   5.00
-    opt-6_7b-multinode          0 2      10.79   0.1%   0.7%       49380      10.792595  10.00
-    reformer                    0 8      61.70   0.0%   0.9%       25376     494.110834   1.00
-    regnet_y_128gf              0 8      99.96   0.2%   5.0%       31840     803.012507   2.00
-    resnet152                   0 8     710.18   0.3%   6.2%       36732    5710.828608   1.00
-    resnet152-multi             0 1    5367.34   1.0%   8.1%       38638    5367.338469   5.00
-    resnet50                    0 8     984.43   0.9%  19.1%        5026    7927.257351   1.00
-    rwkv                        0 8     428.65   0.2%   3.8%        5546    3435.097716   1.00
-    stargan                     0 8      51.32   1.8%  40.8%       37848     413.238870   1.00
-    super-slomo                 0 8      41.63   0.1%   2.3%       34082     332.395065   1.00
-    t5                          0 8      48.05   0.2%   3.9%       35466     384.317023   2.00
-    whisper                     0 8     248.16   0.0%   0.6%       37006    1985.861017   1.00
+  
+    git clone https://github.com/mila-iqia/milabench.git
     
-    Scores
-    ------
-    Failure rate:       0.00% (PASS)
-    Score:             219.06
+    pip install -e milabench
+  
+    export MILABENCH_GPU_ARCH=cuda
+  
+    milabench install --base workspace --config milabench/config/standard.yaml --select fp32
+    
+    milabench prepare --base workspace --config milabench/config/standard.yaml --select fp32
+    
+    milabench run --base workspace --config milabench/config/standard.yaml --select fp32
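+
+The `MILABENCH_GPU_ARCH` value selects which set of pinned requirements gets
+installed. For ROCm or Gaudi machines the same commands should apply with a
+different value; the pinned files in this repository suggest `rocm` and `hpu`
+as the corresponding choices, e.g.:
+
+    export MILABENCH_GPU_ARCH=rocm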
 
 
 ## Details
@@ -84,13 +45,77 @@ The benchmark suite has been validated on the following configurations:
 
 | Python version |          GPU                   |   Configuration file |
 |       -        |        -                       |           -          |
-| 3.10   (conda) | 2 node x 8xNVIDIA A100 80GB    | config/standard.yaml |
-| 3.9.12 (conda) | 8x NVIDIA RTX8000 48GB         | config/standard.yaml |
-| 3.9.16 (conda) | 2x NVIDIA K80                  | config/ci.yaml       |
-| 3.9.16 (conda) | 2x AMD MI100                   | config/ci.yaml       |
-| 3.9.16 (conda) | 4x AMD MI250                   | config/standard.yaml |
+| 3.10           | 2 node x 8xNVIDIA A100 80GB    | config/standard.yaml |
+| 3.10           | 2 node x 8xMI300X              | config/standard.yaml |
+| 3.10           | 1 node x 8xGaudi2              | config/standard.yaml |
 
 We are working on validating it on more configurations and will update the above table as we do.
 
-
-
+## Report
+  
+    =================
+    Benchmark results
+    =================
+  
+    System
+    ------
+    cpu:      AMD EPYC 7742 64-Core Processor
+    n_cpu:    128
+    product:  NVIDIA A100-SXM4-80GB
+    n_gpu:    8
+    memory:   81920.0
+  
+    Breakdown
+    ---------
+    bench                    | fail |   n | ngpu |           perf |   sem% |   std% | peak_memory |           score | weight
+    brax                     |    0 |   1 |    8 |      730035.71 |   0.1% |   0.4% |        2670 |       730035.71 |   1.00
+    diffusion-gpus           |    0 |   1 |    8 |         117.67 |   1.5% |  11.7% |       59944 |          117.67 |   1.00
+    diffusion-single         |    0 |   8 |    1 |          25.02 |   0.8% |  17.9% |       53994 |          202.10 |   1.00
+    dimenet                  |    0 |   8 |    1 |         366.85 |   0.7% |  16.2% |        2302 |         2973.32 |   1.00
+    dinov2-giant-gpus        |    0 |   1 |    8 |         445.68 |   0.4% |   3.0% |       69614 |          445.68 |   1.00
+    dinov2-giant-single      |    0 |   8 |    1 |          53.54 |   0.4% |   9.5% |       74646 |          432.65 |   1.00
+    dqn                      |    0 |   8 |    1 | 23089954554.91 |   1.1% |  89.9% |       62106 | 184480810548.20 |   1.00
+    bf16                     |    0 |   8 |    1 |         293.43 |   0.2% |   6.3% |        1788 |         2361.16 |   0.00
+    fp16                     |    0 |   8 |    1 |         289.26 |   0.1% |   3.6% |        1788 |         2321.65 |   0.00
+    fp32                     |    0 |   8 |    1 |          19.14 |   0.0% |   0.7% |        2166 |          153.21 |   0.00
+    tf32                     |    0 |   8 |    1 |         146.63 |   0.1% |   3.6% |        2166 |         1177.04 |   0.00
+    bert-fp16                |    0 |   8 |    1 |         263.73 |   1.1% |  16.7% |         nan |         2165.37 |   0.00
+    bert-fp32                |    0 |   8 |    1 |          44.84 |   0.6% |   9.6% |       21170 |          364.52 |   0.00
+    bert-tf32                |    0 |   8 |    1 |         141.95 |   0.9% |  14.1% |        1764 |         1162.94 |   0.00
+    bert-tf32-fp16           |    0 |   8 |    1 |         265.04 |   1.0% |  15.6% |         nan |         2175.59 |   3.00
+    reformer                 |    0 |   8 |    1 |          62.29 |   0.3% |   6.0% |       25404 |          501.89 |   1.00
+    t5                       |    0 |   8 |    1 |          51.40 |   0.5% |   9.9% |       34390 |          416.14 |   2.00
+    whisper                  |    0 |   8 |    1 |         481.95 |   1.0% |  21.4% |        8520 |         3897.53 |   1.00
+    lightning                |    0 |   8 |    1 |         680.22 |   1.0% |  22.7% |       27360 |         5506.90 |   1.00
+    lightning-gpus           |    0 |   1 |    8 |        3504.74 |   7.9% |  62.9% |       28184 |         3504.74 |   1.00
+    llava-single             |    1 |   8 |    1 |           2.28 |   0.4% |   9.6% |       72556 |           14.12 |   1.00
+    llama                    |    0 |   8 |    1 |         484.86 |   4.4% |  80.0% |       27820 |         3680.86 |   1.00
+    llm-full-mp-gpus         |    0 |   1 |    8 |         193.92 |   3.1% |  16.2% |       48470 |          193.92 |   1.00
+    llm-lora-ddp-gpus        |    0 |   1 |    8 |       16738.58 |   0.4% |   2.0% |       36988 |        16738.58 |   1.00
+    llm-lora-mp-gpus         |    0 |   1 |    8 |        1980.63 |   2.2% |  11.8% |       55972 |         1980.63 |   1.00
+    llm-lora-single          |    0 |   8 |    1 |        2724.95 |   0.2% |   3.0% |       49926 |        21861.99 |   1.00
+    ppo                      |    0 |   8 |    1 |     3114264.32 |   1.6% |  57.2% |       62206 |     24915954.98 |   1.00
+    recursiongfn             |    0 |   8 |    1 |        7080.67 |   1.2% |  27.1% |       10292 |        57038.34 |   1.00
+    rlhf-gpus                |    0 |   1 |    8 |        6314.94 |   2.1% |  11.2% |       21730 |         6314.94 |   1.00
+    rlhf-single              |    0 |   8 |    1 |        1143.72 |   0.4% |   8.4% |       19566 |         9174.52 |   1.00
+    focalnet                 |    0 |   8 |    1 |         375.07 |   0.7% |  14.9% |       23536 |         3038.83 |   2.00
+    torchatari               |    0 |   8 |    1 |        5848.88 |   0.6% |  12.7% |        3834 |        46613.34 |   1.00
+    convnext_large-fp16      |    0 |   8 |    1 |         330.93 |   1.5% |  22.9% |       27376 |         2711.46 |   0.00
+    convnext_large-fp32      |    0 |   8 |    1 |          59.49 |   0.6% |   9.8% |       55950 |          483.84 |   0.00
+    convnext_large-tf32      |    0 |   8 |    1 |         155.41 |   0.9% |  14.3% |       49650 |         1273.31 |   0.00
+    convnext_large-tf32-fp16 |    0 |   8 |    1 |         322.28 |   1.6% |  24.5% |       27376 |         2637.88 |   3.00
+    regnet_y_128gf           |    0 |   8 |    1 |         119.46 |   0.5% |  10.0% |       29762 |          966.96 |   2.00
+    resnet152-ddp-gpus       |    0 |   1 |    8 |        3843.06 |   5.2% |  39.3% |       27980 |         3843.06 |   0.00
+    resnet50                 |    0 |   8 |    1 |         932.95 |   2.4% |  52.2% |       14848 |         7524.25 |   1.00
+    resnet50-noio            |    0 |   8 |    1 |        1163.88 |   0.3% |   6.7% |       27480 |         9385.35 |   0.00
+    vjepa-gpus               |    0 |   1 |    8 |         130.13 |   5.9% |  46.8% |       64244 |          130.13 |   1.00
+    vjepa-single             |    0 |   8 |    1 |          21.29 |   1.0% |  22.4% |       58552 |          172.11 |   1.00
+  
+    Scores
+    ------
+    Failure rate:       0.38% (PASS)
+    Score:            4175.57
+
+    Errors
+    ------
+    1 errors, details in HTML report.
diff --git a/benchmarks/_templates/simple/dev.yaml b/benchmarks/_templates/simple/dev.yaml
index e3aa94673..affcc977f 100644
--- a/benchmarks/_templates/simple/dev.yaml
+++ b/benchmarks/_templates/simple/dev.yaml
@@ -6,3 +6,5 @@ template:
   install_group: torch
   plan:
     method: per_gpu
+  tags:
+    - monogpu
diff --git a/benchmarks/_templates/stdout/dev.yaml b/benchmarks/_templates/stdout/dev.yaml
index 2b7e75a34..24c7b8131 100644
--- a/benchmarks/_templates/stdout/dev.yaml
+++ b/benchmarks/_templates/stdout/dev.yaml
@@ -3,7 +3,8 @@ _template:
   definition: .
   install-variant: unpinned
   install_group: torch
-
+  tags:
+    - monogpu
   #argv:
   #  --train_batch_size: 32
   #  --num_epochs: 5
diff --git a/benchmarks/_templates/voir/dev.yaml b/benchmarks/_templates/voir/dev.yaml
index e3aa94673..affcc977f 100644
--- a/benchmarks/_templates/voir/dev.yaml
+++ b/benchmarks/_templates/voir/dev.yaml
@@ -6,3 +6,5 @@ template:
   install_group: torch
   plan:
     method: per_gpu
+  tags:
+    - monogpu
diff --git a/benchmarks/brax/benchfile.py b/benchmarks/brax/benchfile.py
index 0388956d6..c33128138 100644
--- a/benchmarks/brax/benchfile.py
+++ b/benchmarks/brax/benchfile.py
@@ -5,5 +5,9 @@ class BraxBenchmark(Package):
     base_requirements = "requirements.in"
     main_script = "main.py"
 
-
+    def make_env(self):
+        env = super().make_env()
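+        # Stop JAX/XLA from preallocating the bulk of GPU memory at startup;
+        # memory is then allocated on demand instead.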
+        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
+        return env
+
 __pack__ = BraxBenchmark
diff --git a/benchmarks/brax/main.py b/benchmarks/brax/main.py
index 572ce739c..6625bcd04 100644
--- a/benchmarks/brax/main.py
+++ b/benchmarks/brax/main.py
@@ -85,6 +85,9 @@ def run():
 
     args = parser.parse_args()
 
+    # args.num_envs = (args.batch_size * args.num_minibatches)
+
+
     train(
         environment=envs.get_environment(env_name=args.env),
         num_timesteps=args.num_timesteps,
diff --git a/benchmarks/brax/requirements.cuda.txt b/benchmarks/brax/requirements.cuda.txt
index aa883171c..aae485613 100644
--- a/benchmarks/brax/requirements.cuda.txt
+++ b/benchmarks/brax/requirements.cuda.txt
@@ -37,7 +37,7 @@ brax==0.10.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/brax/requirements.in
-chex==0.1.86
+chex==0.1.87
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   optax
@@ -77,7 +77,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -109,7 +109,7 @@ glfw==2.7.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   mujoco
-grpcio==1.66.1
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -133,7 +133,7 @@ itsdangerous==2.2.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flask
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
@@ -145,15 +145,15 @@ jax[cuda12]==0.4.31
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -205,12 +205,12 @@ msgpack==1.1.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.2
+mujoco==3.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.2
+mujoco-mjx==3.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -234,7 +234,6 @@ numpy==1.26.4
     #   jaxopt
     #   ml-dtypes
     #   mujoco
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   scipy
@@ -254,7 +253,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -301,7 +300,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -315,7 +314,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -324,7 +323,7 @@ optax==0.2.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
     #   flax
-orbax-checkpoint==0.6.3
+orbax-checkpoint==0.6.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -341,7 +340,7 @@ pillow==10.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
-protobuf==5.28.1
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   orbax-checkpoint
@@ -377,7 +376,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flax
@@ -395,7 +394,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   ml-collections
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -403,7 +402,7 @@ tensorboardx==2.6.2.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
-tensorstore==0.1.65
+tensorstore==0.1.66
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flax
@@ -435,6 +434,7 @@ typing-extensions==4.12.2
     #   flax
     #   orbax-checkpoint
     #   reactivex
+    #   rich
     #   torch
 varname==0.13.3
     # via
diff --git a/benchmarks/brax/requirements.hpu.txt b/benchmarks/brax/requirements.hpu.txt
index cae1147c6..b02ff745f 100644
--- a/benchmarks/brax/requirements.hpu.txt
+++ b/benchmarks/brax/requirements.hpu.txt
@@ -4,10 +4,6 @@
 #
 #    pip-compile --output-file=benchmarks/brax/requirements.hpu.txt .pin/tmp-constraints-hpu-brax.txt benchmarks/brax/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 absl-py==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -35,7 +31,7 @@ brax==0.10.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/brax/requirements.in
-chex==0.1.86
+chex==0.1.87
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   optax
@@ -47,7 +43,7 @@ cloudpickle==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   gym
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
@@ -63,7 +59,7 @@ dm-tree==0.1.8
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   dm-env
-etils[epath,epy]==1.7.0
+etils[epath,epy]==1.9.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
@@ -71,11 +67,11 @@ etils[epath,epy]==1.7.0
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -85,20 +81,20 @@ flask==3.0.3
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
     #   flask-cors
-flask-cors==4.0.1
+flask-cors==5.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-flax==0.8.5
+flax==0.9.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-fsspec==2024.5.0
+fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   etils
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
@@ -107,7 +103,7 @@ glfw==2.7.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   mujoco
-grpcio==1.65.1
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
@@ -119,7 +115,11 @@ gym-notices==0.0.8
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   gym
-importlib-resources==6.4.0
+humanize==4.10.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   orbax-checkpoint
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   etils
@@ -127,7 +127,7 @@ itsdangerous==2.2.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flask
-jax[cuda12]==0.4.28
+jax==0.4.33
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/brax/requirements.in
@@ -138,15 +138,7 @@ jax[cuda12]==0.4.28
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-jax-cuda12-pjrt==0.4.28
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax-cuda12-plugin
-jax-cuda12-plugin==0.4.28
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
-jaxlib==0.4.28+cuda12.cudnn89
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
@@ -183,7 +175,7 @@ ml-collections==0.1.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-ml-dtypes==0.4.0
+ml-dtypes==0.5.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   jax
@@ -193,17 +185,17 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   sympy
-msgpack==1.0.8
+msgpack==1.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.0
+mujoco==3.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.0
+mujoco-mjx==3.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
@@ -221,14 +213,12 @@ numpy==1.26.4
     #   brax
     #   chex
     #   dm-env
-    #   flax
     #   gym
     #   jax
     #   jaxlib
     #   jaxopt
     #   ml-dtypes
     #   mujoco
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   scipy
@@ -238,19 +228,13 @@ numpy==1.26.4
 nvidia-cublas-cu12==12.1.3.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   nvidia-cudnn-cu12
     #   nvidia-cusolver-cu12
     #   torch
 nvidia-cuda-cupti-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
-nvidia-cuda-nvcc-cu12==12.5.82
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
 nvidia-cuda-nvrtc-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -258,17 +242,14 @@ nvidia-cuda-nvrtc-cu12==12.1.105
 nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
 nvidia-cufft-cu12==11.0.2.54
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
 nvidia-curand-cu12==10.3.2.106
     # via
@@ -277,23 +258,23 @@ nvidia-curand-cu12==10.3.2.106
 nvidia-cusolver-cu12==11.4.5.107
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
 nvidia-cusparse-cu12==12.1.0.106
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   jax
     #   nvidia-cusolver-cu12
     #   nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
@@ -304,7 +285,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   jax
@@ -313,12 +294,12 @@ optax==0.2.3
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
     #   flax
-orbax-checkpoint==0.5.21
+orbax-checkpoint==0.6.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
     #   flax
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -330,7 +311,7 @@ pillow==10.4.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-protobuf==4.25.3
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   orbax-checkpoint
@@ -347,10 +328,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
 pyopengl==3.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -359,7 +336,7 @@ pytinyrenderer==0.0.14
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flax
@@ -370,12 +347,12 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flax
     #   voir
-scipy==1.14.0
+scipy==1.14.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
@@ -388,7 +365,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
     #   ml-collections
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -396,7 +373,7 @@ tensorboardx==2.6.2.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
-tensorstore==0.1.63
+tensorstore==0.1.66
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flax
@@ -405,16 +382,16 @@ toolz==0.12.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   chex
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/brax/requirements.in
-trimesh==4.4.3
+trimesh==4.4.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   brax
     #   mujoco-mjx
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -427,8 +404,9 @@ typing-extensions==4.12.2
     #   flax
     #   orbax-checkpoint
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
@@ -437,11 +415,11 @@ voir==0.2.19
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/brax/requirements.in
-werkzeug==3.0.3
+werkzeug==3.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   flask
-zipp==3.19.2
+zipp==3.20.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   etils
diff --git a/benchmarks/brax/requirements.rocm.txt b/benchmarks/brax/requirements.rocm.txt
index 0c14e04d9..a1923520c 100644
--- a/benchmarks/brax/requirements.rocm.txt
+++ b/benchmarks/brax/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/brax/requirements.rocm.txt .pin/tmp-constraints-rocm-brax.txt benchmarks/brax/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 absl-py==2.1.0
     # via
@@ -33,7 +33,7 @@ brax==0.10.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/brax/requirements.in
-chex==0.1.86
+chex==0.1.87
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   optax
@@ -45,7 +45,7 @@ cloudpickle==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gym
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -61,7 +61,7 @@ dm-tree==0.1.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   dm-env
-etils[epath,epy]==1.7.0
+etils[epath,epy]==1.9.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -69,11 +69,11 @@ etils[epath,epy]==1.7.0
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -83,11 +83,11 @@ flask==3.0.3
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
     #   flask-cors
-flask-cors==4.0.1
+flask-cors==5.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
-flax==0.8.5
+flax==0.9.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -96,7 +96,7 @@ fsspec==2024.6.1
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   etils
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -105,11 +105,11 @@ glfw==2.7.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   mujoco
-grpcio==1.65.5
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
-gym==0.23.1
+gym==0.26.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -121,7 +121,7 @@ humanize==4.10.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   orbax-checkpoint
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   etils
@@ -129,7 +129,7 @@ itsdangerous==2.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   flask
-jax==0.4.31
+jax==0.4.33
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/brax/requirements.in
@@ -140,7 +140,7 @@ jax==0.4.31
     #   mujoco-mjx
     #   optax
     #   orbax-checkpoint
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -177,7 +177,7 @@ ml-collections==0.1.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
-ml-dtypes==0.4.0
+ml-dtypes==0.5.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   jax
@@ -187,17 +187,17 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
-msgpack==1.0.8
+msgpack==1.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.2
+mujoco==3.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.2
+mujoco-mjx==3.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -215,25 +215,27 @@ numpy==1.26.4
     #   brax
     #   chex
     #   dm-env
-    #   flax
     #   gym
     #   jax
     #   jaxlib
     #   jaxopt
     #   ml-dtypes
     #   mujoco
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   scipy
     #   tensorboardx
     #   tensorstore
     #   trimesh
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   jax
@@ -242,12 +244,12 @@ optax==0.2.3
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
     #   flax
-orbax-checkpoint==0.6.0
+orbax-checkpoint==0.6.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
     #   flax
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -259,7 +261,7 @@ pillow==10.4.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
-protobuf==5.27.3
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   orbax-checkpoint
@@ -276,10 +278,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pyopengl==3.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -303,12 +301,12 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   flax
     #   voir
-scipy==1.14.0
+scipy==1.14.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -321,7 +319,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   ml-collections
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -329,7 +327,7 @@ tensorboardx==2.6.2.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
-tensorstore==0.1.64
+tensorstore==0.1.66
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   flax
@@ -338,11 +336,11 @@ toolz==0.12.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   chex
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/brax/requirements.in
-trimesh==4.4.7
+trimesh==4.4.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   brax
@@ -356,8 +354,9 @@ typing-extensions==4.12.2
     #   flax
     #   orbax-checkpoint
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
@@ -366,11 +365,11 @@ voir==0.2.19
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/brax/requirements.in
-werkzeug==3.0.3
+werkzeug==3.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   flask
-zipp==3.20.0
+zipp==3.20.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   etils
diff --git a/benchmarks/brax/voirfile.py b/benchmarks/brax/voirfile.py
index fce6f66d0..3397dcb31 100644
--- a/benchmarks/brax/voirfile.py
+++ b/benchmarks/brax/voirfile.py
@@ -20,10 +20,10 @@ class Config:
     skip: int = 5
 
     # Number of rates to log before stopping
-    stop: int = 20
+    stop: int = 60
 
     # Number of seconds between each gpu poll
-    gpu_poll: int = 3
+    gpu_poll: int = 1
 
 
 @configurable
diff --git a/benchmarks/diffusion/requirements.cuda.txt b/benchmarks/diffusion/requirements.cuda.txt
index 6a062a7a0..676489f43 100644
--- a/benchmarks/diffusion/requirements.cuda.txt
+++ b/benchmarks/diffusion/requirements.cuda.txt
@@ -15,11 +15,11 @@ accelerate==0.34.2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/diffusion/requirements.in
     #   diffusers
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -60,11 +60,11 @@ codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/diffusion/requirements.in
-diffusers[torch]==0.30.2
+diffusers[torch]==0.30.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/diffusion/requirements.in
@@ -77,7 +77,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -106,7 +106,7 @@ hjson==3.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   argklass
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   accelerate
@@ -127,19 +127,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   argklass
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -190,7 +190,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   scipy
@@ -209,7 +208,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -256,7 +255,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -270,7 +269,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -285,7 +284,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -343,7 +342,7 @@ requests==2.32.3
     #   diffusers
     #   huggingface-hub
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -363,7 +362,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -392,6 +391,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/diffusion/requirements.in
 triton==3.0.0
     # via
@@ -403,8 +403,9 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -429,7 +430,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/diffusion/requirements.hpu.txt b/benchmarks/diffusion/requirements.hpu.txt
new file mode 100644
index 000000000..88ccd569e
--- /dev/null
+++ b/benchmarks/diffusion/requirements.hpu.txt
@@ -0,0 +1,381 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/diffusion/requirements.hpu.txt .pin/tmp-constraints-hpu-diffusion-nodes.txt benchmarks/diffusion/requirements.in
+#
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+    #   diffusers
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+argklass==1.4.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+diffusers[torch]==0.30.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   multiprocess
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   diffusers
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+hjson==3.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   diffusers
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+importlib-metadata==8.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   diffusers
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   diffusers
+    #   pandas
+    #   pyarrow
+    #   torchvision
+    #   transformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   diffusers
+    #   torchvision
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   diffusers
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   diffusers
+    #   huggingface-hub
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   diffusers
+    #   transformers
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   diffusers
+    #   torchvision
+torchvision==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/diffusion/requirements.in
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/diffusion/requirements.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/diffusion/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+zipp==3.20.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   importlib-metadata
diff --git a/benchmarks/diffusion/requirements.rocm.txt b/benchmarks/diffusion/requirements.rocm.txt
index 5d0fd6e3f..ecedcbb4e 100644
--- a/benchmarks/diffusion/requirements.rocm.txt
+++ b/benchmarks/diffusion/requirements.rocm.txt
@@ -4,18 +4,18 @@
 #
 #    pip-compile --output-file=benchmarks/diffusion/requirements.rocm.txt .pin/tmp-constraints-rocm-diffusion-nodes.txt benchmarks/diffusion/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-accelerate==0.33.0
+accelerate==0.34.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/diffusion/requirements.in
     #   diffusers
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -44,7 +44,7 @@ attrs==24.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-certifi==2024.7.4
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -52,15 +52,15 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-datasets==2.21.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/diffusion/requirements.in
-diffusers[torch]==0.30.0
+diffusers[torch]==0.30.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/diffusion/requirements.in
@@ -69,11 +69,11 @@ dill==0.3.8
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
     #   multiprocess
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -93,7 +93,7 @@ fsspec[http]==2024.6.1
     #   datasets
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -102,7 +102,7 @@ hjson==3.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   argklass
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   accelerate
@@ -110,16 +110,16 @@ huggingface-hub==0.24.6
     #   diffusers
     #   tokenizers
     #   transformers
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   yarl
-importlib-metadata==8.4.0
+importlib-metadata==8.5.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   diffusers
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   argklass
@@ -143,7 +143,7 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
@@ -166,11 +166,15 @@ numpy==1.26.4
     #   pyarrow
     #   torchvision
     #   transformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -181,7 +185,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -207,10 +211,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 python-dateutil==2.9.0.post0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -219,7 +219,7 @@ pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-pytz==2024.1
+pytz==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
@@ -235,7 +235,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-regex==2024.7.24
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   diffusers
@@ -247,11 +247,11 @@ requests==2.32.3
     #   diffusers
     #   huggingface-hub
     #   transformers
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-safetensors==0.4.4
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   accelerate
@@ -262,7 +262,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -270,13 +270,13 @@ tokenizers==0.19.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   accelerate
     #   diffusers
     #   torchvision
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/diffusion/requirements.in
@@ -287,29 +287,32 @@ tqdm==4.66.5
     #   datasets
     #   huggingface-hub
     #   transformers
-transformers==4.44.1
+transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/diffusion/requirements.in
 typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
+    #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
@@ -318,11 +321,11 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
-yarl==1.9.4
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-zipp==3.20.0
+zipp==3.20.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   importlib-metadata
diff --git a/benchmarks/dinov2/requirements.cuda.txt b/benchmarks/dinov2/requirements.cuda.txt
index aef36dbf3..bb0535894 100644
--- a/benchmarks/dinov2/requirements.cuda.txt
+++ b/benchmarks/dinov2/requirements.cuda.txt
@@ -30,7 +30,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -53,19 +53,19 @@ iopath==0.1.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/dinov2/requirements.in
     #   fvcore
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -109,7 +109,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchmetrics
     #   torchvision
@@ -126,7 +125,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -173,7 +172,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -188,7 +187,7 @@ omegaconf==2.3.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/dinov2/requirements.in
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -232,7 +231,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -246,11 +245,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-submitit==1.5.1
+submitit==1.5.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/dinov2/requirements.in
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -292,6 +291,7 @@ typing-extensions==4.12.2
     #   iopath
     #   lightning-utilities
     #   reactivex
+    #   rich
     #   submitit
     #   torch
 varname==0.13.3
diff --git a/benchmarks/dinov2/requirements.hpu.txt b/benchmarks/dinov2/requirements.hpu.txt
new file mode 100644
index 000000000..4a11ccfbc
--- /dev/null
+++ b/benchmarks/dinov2/requirements.hpu.txt
@@ -0,0 +1,267 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/dinov2/requirements.hpu.txt .pin/tmp-constraints-hpu-dinov2-giant-gpus.txt benchmarks/dinov2/requirements.in
+#
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   submitit
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+fvcore==0.1.5.post20221221
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+iopath==0.1.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+    #   fvcore
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+lightning-utilities==0.11.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchmetrics
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+    #   scipy
+    #   torchmetrics
+    #   torchvision
+    #   xformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning-utilities
+    #   torchmetrics
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+    #   torchvision
+portalocker==2.10.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   iopath
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+    #   omegaconf
+    #   yacs
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+scipy==1.14.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+submitit==1.5.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tabulate==0.9.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+termcolor==2.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+    #   torchmetrics
+    #   torchvision
+    #   xformers
+torchmetrics==1.4.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+torchvision==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+    #   iopath
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   iopath
+    #   lightning-utilities
+    #   reactivex
+    #   rich
+    #   submitit
+    #   torch
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/dinov2/requirements.in
+xformers==0.0.28.post1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/dinov2/requirements.in
+yacs==0.1.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fvcore
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/dinov2/requirements.rocm.txt b/benchmarks/dinov2/requirements.rocm.txt
index c46ba9819..f8b7f43e2 100644
--- a/benchmarks/dinov2/requirements.rocm.txt
+++ b/benchmarks/dinov2/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/dinov2/requirements.rocm.txt .pin/tmp-constraints-rocm-dinov2-giant-gpus.txt benchmarks/dinov2/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -18,15 +18,15 @@ cloudpickle==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   submitit
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -39,7 +39,7 @@ fvcore==0.1.5.post20221221
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -53,7 +53,7 @@ jinja2==3.1.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-lightning-utilities==0.11.6
+lightning-utilities==0.11.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchmetrics
@@ -85,12 +85,16 @@ numpy==1.26.4
     #   torchmetrics
     #   torchvision
     #   xformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -120,10 +124,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -138,11 +138,11 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-scipy==1.14.0
+scipy==1.14.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
@@ -150,11 +150,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-submitit==1.5.1
+submitit==1.5.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -166,18 +166,18 @@ termcolor==2.4.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   fvcore
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
     #   torchmetrics
     #   torchvision
     #   xformers
-torchmetrics==1.4.1
+torchmetrics==1.4.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
@@ -192,18 +192,19 @@ typing-extensions==4.12.2
     #   iopath
     #   lightning-utilities
     #   reactivex
+    #   rich
     #   submitit
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/dinov2/requirements.in
-xformers==0.0.27.post2
+xformers==0.0.28.post1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/dinov2/requirements.in
diff --git a/benchmarks/flops/benchfile.py b/benchmarks/flops/benchfile.py
index 59c5c4a7f..17f360193 100644
--- a/benchmarks/flops/benchfile.py
+++ b/benchmarks/flops/benchfile.py
@@ -9,10 +9,15 @@ class FlopsBenchmarch(Package):
     def build_run_plan(self) -> "Command":
         import milabench.commands as cmd
 
+        main = self.dirs.code / self.main_script
         pack = cmd.PackCommand(self, *self.argv, lazy=True)
-        # pack = cmd.VoirCommand(pack, cwd=main.parent)
-        pack = cmd.ActivatorCommand(pack)
-        return pack.use_stdout()
-
+
+        use_stdout = True
+
+        if use_stdout:
+            return pack.use_stdout()
+        else:
+            pack = cmd.VoirCommand(pack, cwd=main.parent)
+            return pack
 
 __pack__ = FlopsBenchmarch
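Note on the hunk above: `use_stdout` is hard-coded to True, so the VoirCommand branch is currently dead code kept for quick manual switching. A minimal sketch, assuming the toggle should instead be flippable per run without editing the file — `MILABENCH_USE_STDOUT` is a hypothetical variable name, not part of this diff or of milabench:

    import os
    import milabench.commands as cmd

    def build_run_plan(self) -> "Command":
        # Hypothetical variant: drive the toggle from the environment.
        # MILABENCH_USE_STDOUT is an assumed name, not defined by milabench.
        pack = cmd.PackCommand(self, *self.argv, lazy=True)
        if os.environ.get("MILABENCH_USE_STDOUT", "1") != "0":
            return pack.use_stdout()
        main = self.dirs.code / self.main_script
        return cmd.VoirCommand(pack, cwd=main.parent)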
diff --git a/benchmarks/flops/requirements.cuda.txt b/benchmarks/flops/requirements.cuda.txt
index afb7ff130..fd027a8fb 100644
--- a/benchmarks/flops/requirements.cuda.txt
+++ b/benchmarks/flops/requirements.cuda.txt
@@ -26,7 +26,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -44,19 +44,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchcompat
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -95,7 +95,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchvision
     #   xformers
@@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -172,7 +171,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -204,7 +203,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -217,7 +216,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -248,6 +247,7 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   reactivex
+    #   rich
     #   torch
 varname==0.13.3
     # via
diff --git a/benchmarks/flops/requirements.hpu.txt b/benchmarks/flops/requirements.hpu.txt
index 77595d5f7..91e5677fe 100644
--- a/benchmarks/flops/requirements.hpu.txt
+++ b/benchmarks/flops/requirements.hpu.txt
@@ -4,10 +4,6 @@
 #
 #    pip-compile --output-file=benchmarks/flops/requirements.hpu.txt .pin/tmp-constraints-hpu-flops.txt benchmarks/flops/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 antlr4-python3-runtime==4.9.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -16,29 +12,29 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
     #   triton
-fsspec==2024.5.0
+fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
     #   voir
-importlib-resources==6.4.0
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torchcompat
@@ -88,7 +84,7 @@ nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -109,11 +105,15 @@ nvidia-cusparse-cu12==12.1.0.106
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
@@ -126,7 +126,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -146,11 +146,7 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   omegaconf
@@ -158,7 +154,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -166,11 +162,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/flops/requirements.in
@@ -180,15 +176,15 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/flops/requirements.in
-torchvision==0.18.1
+torchvision==0.19.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/flops/requirements.in
-tqdm==4.66.4
+tqdm==4.66.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/flops/requirements.in
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -196,8 +192,9 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
diff --git a/benchmarks/flops/requirements.rocm.txt b/benchmarks/flops/requirements.rocm.txt
index d9ac15eb5..fbc8952d1 100644
--- a/benchmarks/flops/requirements.rocm.txt
+++ b/benchmarks/flops/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/flops/requirements.rocm.txt .pin/tmp-constraints-rocm-flops.txt benchmarks/flops/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -14,15 +14,15 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -31,12 +31,12 @@ fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchcompat
@@ -68,11 +68,15 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchvision
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -92,10 +96,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -108,7 +108,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -116,11 +116,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/flops/requirements.in
@@ -130,7 +130,7 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/flops/requirements.in
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/flops/requirements.in
@@ -142,8 +142,9 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
diff --git a/benchmarks/geo_gnn/dev.yaml b/benchmarks/geo_gnn/dev.yaml
index 6f261c895..67cb5bd2d 100644
--- a/benchmarks/geo_gnn/dev.yaml
+++ b/benchmarks/geo_gnn/dev.yaml
@@ -19,6 +19,6 @@ dimenet:
     method: per_gpu
   argv:
     --model: 'DimeNet'
-    --num-samples: 10000
+    --num-samples: 100000
     --use3d: True
     --batch-size: 512
\ No newline at end of file
diff --git a/benchmarks/geo_gnn/main.py b/benchmarks/geo_gnn/main.py
index 71e1c8827..b8875d2bf 100644
--- a/benchmarks/geo_gnn/main.py
+++ b/benchmarks/geo_gnn/main.py
@@ -78,14 +78,20 @@ def train_degree(train_dataset):
     # Compute the maximum in-degree in the training data.
     max_degree = -1
     for data in train_dataset:
-        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
-        max_degree = max(max_degree, int(d.max()))
+        try:
+            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
+            max_degree = max(max_degree, int(d.max()))
+        except TypeError:
+            pass  # skip samples whose edge_index cannot be consumed by degree()
 
     # Compute the in-degree histogram tensor
     deg = torch.zeros(max_degree + 1, dtype=torch.long)
     for data in train_dataset:
-        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
-        deg += torch.bincount(d, minlength=deg.numel())
+        try:
+            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
+            deg += torch.bincount(d, minlength=deg.numel())
+        except TypeError:
+            pass  # same as above: tolerate samples with missing/malformed edge_index
 
     return deg
 
@@ -109,13 +115,14 @@ def batch_size(x):
     observer = BenchObserver(batch_size_fn=batch_size)
 
     train_dataset = PCQM4Mv2Subset(args.num_samples, args.root)
+    deg_hist = train_degree(train_dataset)  # full-dataset scan, done once up front
 
     sample = next(iter(train_dataset))
 
     info = models[args.model](
         args,
         sample=sample,
-        degree=lambda: train_degree(train_dataset),
+        degree=lambda: deg_hist,
     )
 
     TRAIN_mean, TRAIN_std = (
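The main.py hunks above do two things: they tolerate samples whose `edge_index` cannot be consumed by `degree()` (the TypeError is swallowed), and they hoist `train_degree()` out of the lambda so the full-dataset scan runs once instead of on every callback invocation. A minimal, self-contained sketch of that second point, using toy tensors rather than the real PCQM4Mv2 data:

    import torch

    def histogram(dataset):
        # one full pass over the data; this is the expensive part
        deg = torch.zeros(8, dtype=torch.long)
        for dst in dataset:
            deg += torch.bincount(dst, minlength=deg.numel())
        return deg

    dataset = [torch.tensor([0, 1, 1, 2]), torch.tensor([2, 2, 3])]

    # before: every call to the callback re-scans the dataset
    degree_cb = lambda: histogram(dataset)

    # after: scan once, let the callback close over the result
    deg_hist = histogram(dataset)
    degree_cb = lambda: deg_hist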
diff --git a/benchmarks/geo_gnn/requirements-pre.cuda.txt b/benchmarks/geo_gnn/requirements-pre.cuda.txt
index 0ec4d88dd..f56bb4988 100644
--- a/benchmarks/geo_gnn/requirements-pre.cuda.txt
+++ b/benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -10,7 +10,7 @@
 --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
 --trusted-host pypi.ngc.nvidia.com
 
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -19,19 +19,19 @@ fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -62,7 +62,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   xformers
 nvidia-cublas-cu12==12.1.3.1
@@ -77,7 +76,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -120,7 +119,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -130,7 +129,7 @@ nvidia-nvtx-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -139,7 +138,7 @@ scipy==1.14.1
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
     #   jaxlib
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
diff --git a/benchmarks/geo_gnn/requirements-pre.hpu.txt b/benchmarks/geo_gnn/requirements-pre.hpu.txt
new file mode 100644
index 000000000..db910c1ae
--- /dev/null
+++ b/benchmarks/geo_gnn/requirements-pre.hpu.txt
@@ -0,0 +1,99 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/geo_gnn/requirements-pre.hpu.txt .pin/tmp-constraints-hpu-dimenet.txt benchmarks/geo_gnn/requirements-pre.in
+#
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
diff --git a/benchmarks/geo_gnn/requirements-pre.rocm.txt b/benchmarks/geo_gnn/requirements-pre.rocm.txt
index 3aded346f..9b4cf02fb 100644
--- a/benchmarks/geo_gnn/requirements-pre.rocm.txt
+++ b/benchmarks/geo_gnn/requirements-pre.rocm.txt
@@ -2,48 +2,48 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=benchmarks/geo_gnn/requirements-pre.rocm.txt .pin/tmp-constraints-rocm-geo_gnn.txt benchmarks/geo_gnn/requirements-pre.in
+#    pip-compile --output-file=benchmarks/geo_gnn/requirements-pre.rocm.txt .pin/tmp-constraints-rocm-dimenet.txt benchmarks/geo_gnn/requirements-pre.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-filelock==3.15.4
+filelock==3.16.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
     #   torch
 fsspec==2024.6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
 jinja2==3.1.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
 markupsafe==2.1.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   jinja2
 mpmath==1.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
 networkx==3.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
 pytorch-triton-rocm==3.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-sympy==1.13.2
+sympy==1.13.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.in
 typing-extensions==4.12.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
diff --git a/benchmarks/geo_gnn/requirements.cuda.txt b/benchmarks/geo_gnn/requirements.cuda.txt
index 88e329e6d..c4ffaa639 100644
--- a/benchmarks/geo_gnn/requirements.cuda.txt
+++ b/benchmarks/geo_gnn/requirements.cuda.txt
@@ -10,11 +10,11 @@
 --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
 --trusted-host pypi.ngc.nvidia.com
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch-geometric
@@ -54,7 +54,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -81,22 +81,22 @@ idna==3.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
     #   yarl
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -149,7 +149,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   rdkit
     #   scipy
@@ -169,7 +168,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -225,7 +224,7 @@ nvidia-nccl-cu12==2.20.5
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -241,7 +240,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -250,7 +249,7 @@ ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
@@ -299,7 +298,7 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch-geometric
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -316,7 +315,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
@@ -330,7 +329,7 @@ torch-cluster==1.6.3+pt24cu121
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
-torch-geometric==2.6.0
+torch-geometric==2.6.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
@@ -357,8 +356,9 @@ typing-extensions==4.12.2
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
     #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -380,7 +380,7 @@ xformers==0.0.27.post2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
     #   -r benchmarks/geo_gnn/requirements-pre.cuda.txt
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/geo_gnn/requirements.hpu.txt b/benchmarks/geo_gnn/requirements.hpu.txt
new file mode 100644
index 000000000..9c6bb6d69
--- /dev/null
+++ b/benchmarks/geo_gnn/requirements.hpu.txt
@@ -0,0 +1,321 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/geo_gnn/requirements.hpu.txt .pin/tmp-constraints-hpu-dimenet.txt benchmarks/geo_gnn/requirements-pre.hpu.txt benchmarks/geo_gnn/requirements.in
+#
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+    #   torch-geometric
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+    #   torch-geometric
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+    #   pandas
+    #   rdkit
+    #   scipy
+    #   torch-geometric
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rdkit
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pyparsing==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+rdkit==2024.3.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+scipy==1.14.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-cluster
+    #   torch-sparse
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+torch-cluster==1.6.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+torch-geometric==2.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+torch-scatter==2.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+torch-sparse==0.6.18
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/geo_gnn/requirements-pre.hpu.txt
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/geo_gnn/requirements.in
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
diff --git a/benchmarks/geo_gnn/requirements.rocm.txt b/benchmarks/geo_gnn/requirements.rocm.txt
index 60246f795..8dfacfe7a 100644
--- a/benchmarks/geo_gnn/requirements.rocm.txt
+++ b/benchmarks/geo_gnn/requirements.rocm.txt
@@ -2,271 +2,258 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=benchmarks/geo_gnn/requirements.rocm.txt .pin/tmp-constraints-rocm-geo_gnn.txt benchmarks/geo_gnn/requirements-pre.rocm.txt benchmarks/geo_gnn/requirements.in
+#    pip-compile --output-file=benchmarks/geo_gnn/requirements.rocm.txt .pin/tmp-constraints-rocm-dimenet.txt benchmarks/geo_gnn/requirements-pre.rocm.txt benchmarks/geo_gnn/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
 aiosignal==1.3.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 antlr4-python3-runtime==4.9.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   omegaconf
 asttokens==2.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
 async-timeout==4.0.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 attrs==24.2.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-certifi==2024.7.4
+certifi==2024.8.30
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
 charset-normalizer==3.3.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   pytorch-triton-rocm
     #   torch
 frozenlist==1.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
     #   aiosignal
 fsspec==2024.6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   torch
     #   torch-geometric
-giving==0.4.2
+giving==0.4.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-idna==3.7
+idna==3.10
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   yarl
 jinja2==3.1.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   torch
     #   torch-geometric
-joblib==1.4.2
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   scikit-learn
 markdown-it-py==3.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
 markupsafe==2.1.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   jinja2
 mdurl==0.1.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   markdown-it-py
 mpmath==1.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
     #   yarl
 networkx==3.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   torch
 numpy==1.26.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
     #   pandas
     #   rdkit
-    #   scikit-learn
     #   scipy
     #   torch-geometric
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-pandas==2.2.2
+pandas==2.2.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
 pillow==10.4.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rdkit
 psutil==5.9.8
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
     #   voir
 ptera==1.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
 pygments==2.18.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   voir
-pyparsing==3.1.2
+pyparsing==3.1.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
 python-dateutil==2.9.0.post0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
 pytorch-triton-rocm==3.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   torch
-pytz==2024.1
+pytz==2024.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
 pyyaml==6.0.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   omegaconf
 rdkit==2024.3.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
 reactivex==4.0.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
 requests==2.32.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
-rich==13.7.1
+rich==13.9.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-scikit-learn==1.5.1
+scipy==1.14.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   torch-geometric
-scipy==1.14.0
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   scikit-learn
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-cluster
-    #   torch-geometric
     #   torch-sparse
 six==1.16.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
     #   torch
-threadpoolctl==3.5.0
+torch==2.4.1+rocm6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   scikit-learn
-torch==2.4.0+rocm6.0
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
 torch-cluster==1.6.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
-torch-geometric==2.5.3
+torch-geometric==2.6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
 torch-scatter==2.1.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
 torch-sparse==0.6.18
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements.in
 tqdm==4.66.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
 typing-extensions==4.12.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/geo_gnn/requirements-pre.rocm.txt
+    #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/geo_gnn/requirements.in
-yarl==1.9.4
+yarl==1.13.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
diff --git a/benchmarks/huggingface/requirements.cuda.txt b/benchmarks/huggingface/requirements.cuda.txt
index d4bcacca7..45e68e325 100644
--- a/benchmarks/huggingface/requirements.cuda.txt
+++ b/benchmarks/huggingface/requirements.cuda.txt
@@ -34,7 +34,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
@@ -51,7 +51,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tokenizers
@@ -60,19 +60,19 @@ idna==3.10
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -111,7 +111,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   transformers
     #   xformers
@@ -127,7 +126,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -174,7 +173,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -188,7 +187,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -236,7 +235,7 @@ requests==2.32.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -253,7 +252,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -274,6 +273,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/huggingface/requirements.in
 triton==3.0.0
     # via
@@ -284,6 +284,7 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
 urllib3==2.2.3
     # via
diff --git a/benchmarks/huggingface/requirements.hpu.txt b/benchmarks/huggingface/requirements.hpu.txt
index a504cba14..b5e21d99e 100644
--- a/benchmarks/huggingface/requirements.hpu.txt
+++ b/benchmarks/huggingface/requirements.hpu.txt
@@ -4,10 +4,6 @@
 #
 #    pip-compile --output-file=benchmarks/huggingface/requirements.hpu.txt .pin/tmp-constraints-hpu-hf.txt benchmarks/huggingface/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 antlr4-python3-runtime==4.9.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -16,7 +12,7 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-certifi==2024.6.2
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -24,37 +20,37 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   torch
     #   transformers
     #   triton
-fsspec==2024.5.0
+fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.0
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   tokenizers
     #   transformers
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -104,7 +100,7 @@ nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -125,11 +121,15 @@ nvidia-cusparse-cu12==12.1.0.106
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
@@ -142,7 +142,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -151,6 +151,10 @@ packaging==24.1
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   transformers
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/huggingface/requirements.in
 psutil==5.9.8
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -163,11 +167,7 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
@@ -177,7 +177,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-regex==2024.5.15
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
@@ -186,11 +186,11 @@ requests==2.32.3
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   transformers
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-safetensors==0.4.3
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
@@ -198,7 +198,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -206,20 +206,21 @@ tokenizers==0.19.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/huggingface/requirements.in
-tqdm==4.66.4
+tqdm==4.66.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   transformers
-transformers==4.42.4
+transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/huggingface/requirements.in
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -228,12 +229,13 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
-urllib3==1.26.19
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
diff --git a/benchmarks/huggingface/requirements.rocm.txt b/benchmarks/huggingface/requirements.rocm.txt
index 1f54d841a..653d2c59e 100644
--- a/benchmarks/huggingface/requirements.rocm.txt
+++ b/benchmarks/huggingface/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/huggingface/requirements.rocm.txt .pin/tmp-constraints-rocm-hf.txt benchmarks/huggingface/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -14,7 +14,7 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-certifi==2024.7.4
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -22,15 +22,15 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
@@ -42,17 +42,17 @@ fsspec==2024.6.1
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tokenizers
     #   transformers
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -84,11 +84,15 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -113,10 +117,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -131,7 +131,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-regex==2024.7.24
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
@@ -140,11 +140,11 @@ requests==2.32.3
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   transformers
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-safetensors==0.4.4
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
@@ -152,7 +152,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -160,7 +160,7 @@ tokenizers==0.19.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/huggingface/requirements.in
@@ -169,21 +169,23 @@ tqdm==4.66.5
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   transformers
-transformers==4.44.1
+transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/huggingface/requirements.in
 typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
diff --git a/benchmarks/lightning/main.py b/benchmarks/lightning/main.py
index b31f3880c..4c3d1206f 100644
--- a/benchmarks/lightning/main.py
+++ b/benchmarks/lightning/main.py
@@ -1,14 +1,17 @@
 #!/usr/bin/env python
 
+
 import argparse
 import os
 
+# FIXME: this is HPU-only; lazy mode is enabled only for single-process runs (WORLD_SIZE unset or <= 0)
+os.environ["PT_HPU_LAZY_MODE"] = str(int(int(os.getenv("WORLD_SIZE", -1)) <= 0))
+
 import torch
 import torch.nn.functional as F
 import lightning as L
 import torchvision.models as torchvision_models
 
-import torchcompat.core as accelerator
 from benchmate.dataloader import imagenet_dataloader, dataloader_arguments
 
 
@@ -37,10 +40,10 @@ def configure_optimizers(self):
 def prepare_voir():
     from benchmate.observer import BenchObserver
     from benchmate.monitor import bench_monitor
-
+    import torchcompat.core as accelerator
     observer = BenchObserver(
         accelerator.Event, 
-        earlystop=65,
+        earlystop=100,
         batch_size_fn=lambda x: len(x[0]),
         raise_stop_program=False,
         stdout=True,
@@ -49,6 +52,10 @@ def prepare_voir():
     return observer, bench_monitor
 
 def main():
+    rank = int(os.getenv("RANK", 0))
+    world_size = int(os.getenv("WORLD_SIZE", 1))
+    local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", 1))
+
     parser = argparse.ArgumentParser(description='simple distributed training job')
     parser.add_argument(
         "--epochs",
@@ -64,17 +71,14 @@ def main():
     args = parser.parse_args()
     model = getattr(torchvision_models, args.model)()
 
-    rank = int(os.getenv("RANK", 0))
-    world_size = int(os.getenv("WORLD_SIZE", 1))
-    local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", 1))
-
+    import torchcompat.core as accelerator
+
     n = accelerator.device_count()
+    n = local_world_size
     nnodes = world_size // local_world_size
 
     model = TorchvisionLightning(model)
 
-    
-   
     accelerator.set_enable_tf32(True)
 
     observer, monitor = prepare_voir()
@@ -85,16 +89,16 @@ def main():
         accelerator="auto", 
         devices=n, 
         num_nodes=nnodes, 
-        strategy="ddp",
+        strategy="auto",
         max_epochs=args.epochs,
-        precision="16-mixed",
+        precision="bf16-mixed",
         enable_checkpointing=False,
         enable_progress_bar=False,
         reload_dataloaders_every_n_epochs=1,
-        max_steps=100
+        max_steps=120
     )
 
-    with monitor():
+    with monitor(poll_interval=0.1):
         trainer.fit(model=model, train_dataloaders=loader)
     print("finished: ", rank)
 
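Note on the `PT_HPU_LAZY_MODE` line added above: the nested casts hide a simple rule — lazy mode is turned on only when the script is not launched as a distributed job. A minimal sketch of the equivalent logic (`hpu_lazy_mode_flag` is a hypothetical helper for illustration, not part of the patch):

```python
import os

def hpu_lazy_mode_flag() -> str:
    """Return "1" (HPU lazy mode on) for single-process runs, "0" otherwise."""
    # WORLD_SIZE is set by launchers such as torchrun; when it is absent we
    # default to -1, meaning "not a distributed launch".
    world_size = int(os.getenv("WORLD_SIZE", -1))
    return "1" if world_size <= 0 else "0"

# Equivalent to the patched one-liner:
#   os.environ["PT_HPU_LAZY_MODE"] = str(int(int(os.getenv("WORLD_SIZE", -1)) <= 0))
os.environ["PT_HPU_LAZY_MODE"] = hpu_lazy_mode_flag()
```

The other behavioural changes in this hunk follow the same HPU focus: `strategy="auto"` lets Lightning choose the appropriate parallel strategy rather than forcing DDP, and `precision="bf16-mixed"` replaces fp16 mixed precision, presumably because bf16 is the natively supported mixed-precision format on Gaudi hardware.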
diff --git a/benchmarks/lightning/requirements.cuda.txt b/benchmarks/lightning/requirements.cuda.txt
index d6823c252..04b4eb4b3 100644
--- a/benchmarks/lightning/requirements.cuda.txt
+++ b/benchmarks/lightning/requirements.cuda.txt
@@ -10,11 +10,11 @@
 --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
 --trusted-host pypi.ngc.nvidia.com
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   fsspec
@@ -46,7 +46,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -75,19 +75,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchcompat
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -141,7 +141,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchmetrics
     #   torchvision
@@ -158,7 +157,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -205,7 +204,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -219,7 +218,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -264,7 +263,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -277,7 +276,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -321,6 +320,7 @@ typing-extensions==4.12.2
     #   multidict
     #   pytorch-lightning
     #   reactivex
+    #   rich
     #   torch
 varname==0.13.3
     # via
@@ -335,7 +335,7 @@ xformers==0.0.27.post2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/lightning/requirements.hpu.txt b/benchmarks/lightning/requirements.hpu.txt
new file mode 100644
index 000000000..f86fb064d
--- /dev/null
+++ b/benchmarks/lightning/requirements.hpu.txt
@@ -0,0 +1,285 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/lightning/requirements.hpu.txt .pin/tmp-constraints-hpu-lightning-gpus.txt benchmarks/lightning/requirements.in
+#
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   pytorch-lightning
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   yarl
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchcompat
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+lightning==2.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/lightning/requirements.in
+lightning-utilities==0.11.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchmetrics
+    #   torchvision
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   lightning-utilities
+    #   pytorch-lightning
+    #   torchmetrics
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchvision
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pytorch-lightning==2.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   omegaconf
+    #   pytorch-lightning
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/lightning/requirements.in
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+    #   torchvision
+torchcompat==1.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/lightning/requirements.in
+torchmetrics==1.4.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   pytorch-lightning
+torchvision==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/lightning/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   pytorch-lightning
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   lightning
+    #   lightning-utilities
+    #   multidict
+    #   pytorch-lightning
+    #   reactivex
+    #   rich
+    #   torch
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/lightning/requirements.in
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/lightning/requirements.rocm.txt b/benchmarks/lightning/requirements.rocm.txt
index 26fdcedfa..aee2b1ba3 100644
--- a/benchmarks/lightning/requirements.rocm.txt
+++ b/benchmarks/lightning/requirements.rocm.txt
@@ -4,13 +4,13 @@
 #
 #    pip-compile --output-file=benchmarks/lightning/requirements.rocm.txt .pin/tmp-constraints-rocm-lightning-gpus.txt benchmarks/lightning/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   fsspec
@@ -34,15 +34,15 @@ attrs==24.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -58,16 +58,16 @@ fsspec[http]==2024.6.1
     #   lightning
     #   pytorch-lightning
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   yarl
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchcompat
@@ -79,7 +79,7 @@ lightning==2.4.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/lightning/requirements.in
-lightning-utilities==0.11.6
+lightning-utilities==0.11.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   lightning
@@ -101,7 +101,7 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
@@ -115,11 +115,15 @@ numpy==1.26.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchmetrics
     #   torchvision
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -146,10 +150,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-lightning==2.4.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -168,7 +168,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -176,11 +176,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/lightning/requirements.in
@@ -193,12 +193,12 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/lightning/requirements.in
-torchmetrics==1.4.1
+torchmetrics==1.4.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   lightning
     #   pytorch-lightning
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/lightning/requirements.in
@@ -212,19 +212,21 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   lightning
     #   lightning-utilities
+    #   multidict
     #   pytorch-lightning
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/lightning/requirements.in
-yarl==1.9.4
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
diff --git a/benchmarks/llama/requirements.cuda.txt b/benchmarks/llama/requirements.cuda.txt
index 7d972b40f..0b3188482 100644
--- a/benchmarks/llama/requirements.cuda.txt
+++ b/benchmarks/llama/requirements.cuda.txt
@@ -10,11 +10,11 @@
 --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
 --trusted-host pypi.ngc.nvidia.com
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -51,7 +51,7 @@ codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -68,7 +68,7 @@ fairscale==0.4.13
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llama/requirements.in
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -76,7 +76,7 @@ filelock==3.16.0
     #   torch
     #   transformers
     #   triton
-fire==0.6.0
+fire==0.7.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -96,7 +96,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -107,19 +107,19 @@ idna==3.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
     #   yarl
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -169,7 +169,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   scipy
@@ -187,7 +186,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -234,7 +233,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -248,7 +247,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -262,7 +261,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -311,7 +310,7 @@ requests==2.32.3
     #   datasets
     #   huggingface-hub
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -332,9 +331,8 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-    #   fire
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -361,6 +359,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/llama/requirements.in
 triton==3.0.0
     # via
@@ -372,8 +371,9 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -398,7 +398,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/llama/requirements.hpu.txt b/benchmarks/llama/requirements.hpu.txt
index 2368c1502..9c01a4dd6 100644
--- a/benchmarks/llama/requirements.hpu.txt
+++ b/benchmarks/llama/requirements.hpu.txt
@@ -4,11 +4,11 @@
 #
 #    pip-compile --output-file=benchmarks/llama/requirements.hpu.txt .pin/tmp-constraints-hpu-llm.txt benchmarks/llama/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
-aiohttp==3.9.5
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
@@ -29,11 +29,11 @@ async-timeout==4.0.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   aiohttp
-attrs==23.2.0
+attrs==24.2.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   aiohttp
-certifi==2024.6.2
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -41,11 +41,11 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
-datasets==2.20.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -54,7 +54,7 @@ dill==0.3.8
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
     #   multiprocess
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
@@ -62,7 +62,7 @@ fairscale==0.4.13
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/llama/requirements.in
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
@@ -70,7 +70,7 @@ filelock==3.15.4
     #   torch
     #   transformers
     #   triton
-fire==0.6.0
+fire==0.7.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -79,24 +79,24 @@ frozenlist==1.4.1
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   aiohttp
     #   aiosignal
-fsspec[http]==2024.5.0
+fsspec[http]==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.0
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
     #   tokenizers
     #   transformers
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -121,7 +121,7 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   aiohttp
@@ -160,7 +160,7 @@ nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -181,11 +181,15 @@ nvidia-cusparse-cu12==12.1.0.106
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
@@ -198,7 +202,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -208,7 +212,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
@@ -224,27 +228,19 @@ pyarrow==17.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
-pyarrow-hotfix==0.6
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   datasets
 pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
 python-dateutil==2.9.0.post0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   pandas
-pytz==2024.1
+pytz==2024.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   pandas
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
@@ -255,7 +251,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-regex==2024.5.15
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
@@ -265,11 +261,11 @@ requests==2.32.3
     #   datasets
     #   huggingface-hub
     #   transformers
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-safetensors==0.4.3
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
@@ -281,9 +277,8 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
-    #   fire
     #   python-dateutil
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -295,22 +290,23 @@ tokenizers==0.19.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   transformers
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/llama/requirements.in
     #   fairscale
-tqdm==4.66.4
+tqdm==4.66.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
     #   huggingface-hub
     #   transformers
-transformers==4.42.4
+transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/llama/requirements.in
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -318,17 +314,19 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
+    #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   pandas
-urllib3==1.26.19
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
@@ -337,11 +335,11 @@ voir==0.2.19
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/llama/requirements.in
-xxhash==3.4.1
+xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   datasets
-yarl==1.9.4
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   aiohttp
diff --git a/benchmarks/llama/requirements.rocm.txt b/benchmarks/llama/requirements.rocm.txt
index 97c44bb0c..41a93e559 100644
--- a/benchmarks/llama/requirements.rocm.txt
+++ b/benchmarks/llama/requirements.rocm.txt
@@ -4,13 +4,13 @@
 #
 #    pip-compile --output-file=benchmarks/llama/requirements.rocm.txt .pin/tmp-constraints-rocm-llm.txt benchmarks/llama/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -35,7 +35,7 @@ attrs==24.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-certifi==2024.7.4
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -43,11 +43,11 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-datasets==2.21.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -56,7 +56,7 @@ dill==0.3.8
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
     #   multiprocess
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
@@ -64,7 +64,7 @@ fairscale==0.4.13
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llama/requirements.in
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -72,7 +72,7 @@ filelock==3.15.4
     #   pytorch-triton-rocm
     #   torch
     #   transformers
-fire==0.6.0
+fire==0.7.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -87,18 +87,18 @@ fsspec[http]==2024.6.1
     #   datasets
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
     #   tokenizers
     #   transformers
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -123,7 +123,7 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
@@ -144,11 +144,15 @@ numpy==1.26.4
     #   pandas
     #   pyarrow
     #   transformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -158,7 +162,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -178,10 +182,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 python-dateutil==2.9.0.post0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -190,7 +190,7 @@ pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-pytz==2024.1
+pytz==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
@@ -205,7 +205,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-regex==2024.7.24
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
@@ -215,11 +215,11 @@ requests==2.32.3
     #   datasets
     #   huggingface-hub
     #   transformers
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-safetensors==0.4.4
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
@@ -231,9 +231,8 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-    #   fire
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -245,7 +244,7 @@ tokenizers==0.19.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   transformers
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llama/requirements.in
@@ -256,25 +255,28 @@ tqdm==4.66.5
     #   datasets
     #   huggingface-hub
     #   transformers
-transformers==4.44.1
+transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/llama/requirements.in
 typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
+    #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
@@ -287,7 +289,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
-yarl==1.9.4
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
diff --git a/benchmarks/llava/benchfile.py b/benchmarks/llava/benchfile.py
index 3bc06eaa7..d6d40d6e7 100644
--- a/benchmarks/llava/benchfile.py
+++ b/benchmarks/llava/benchfile.py
@@ -19,7 +19,9 @@ class Llava(Package):
     def make_env(self):
         # Return a dict of environment variables for prepare_script and
         # main_script.
-        return super().make_env()
+        env = super().make_env()
+        env["PT_HPU_LAZY_MODE"] = "0"
+        return env
 
     async def install(self):
         await super().install()  # super() call installs the requirements
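
Note: PT_HPU_LAZY_MODE controls the execution mode of the Habana (HPU) PyTorch bridge. Under "1", the default, operators accumulate into lazy graphs that only execute at mark_step boundaries; "0" runs them eagerly. A minimal runtime check of the same convention (a sketch; only the variable name and values come from the change above):

    import os

    # Lazy mode is in effect unless the variable is explicitly set to "0".
    lazy_mode = os.environ.get("PT_HPU_LAZY_MODE", "1") != "0"
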
diff --git a/benchmarks/llava/main.py b/benchmarks/llava/main.py
index 879baca01..6c49b04a6 100755
--- a/benchmarks/llava/main.py
+++ b/benchmarks/llava/main.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 
 from dataclasses import dataclass
-
 import torch
 from accelerate import Accelerator
 from accelerate.utils import set_seed
@@ -63,8 +62,12 @@ def main():
         "llava-hf/llava-1.5-7b-hf",
         torch_dtype=torch.bfloat16,
         device_map=compat.device_type,
+        revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb"
+    )
+    processor = AutoProcessor.from_pretrained(
+        "llava-hf/llava-1.5-7b-hf",
+        revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb"
     )
-    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
 
     # Load dataset and create DataLoader
     dataset = load_dataset("HuggingFaceM4/the_cauldron", "aokvqa")["train"]
@@ -90,8 +93,10 @@ def batch_size_fn(batch):
     optimizer = observer.optimizer(torch.optim.AdamW(model.parameters(), lr=5e-5))
     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
 
+    # model = torch.compile(model, backend="hpu_backend")
+
     for epoch in range(args.epochs):
         for i, batch in enumerate(observer.iterate(dataloader)):
             images = batch["images"][0]  # Access the first item in the list of images
             texts = batch["texts"]
             prompt = apply_chat_template(texts)
@@ -124,7 +129,9 @@
             if accelerator.sync_gradients:
                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
 
+            compat.mark_step()
             optimizer.step()
+            compat.mark_step()
             optimizer.zero_grad()
             observer.record_loss(loss)
 
diff --git a/benchmarks/llava/prepare.py b/benchmarks/llava/prepare.py
index afa480b86..5e8b018f3 100755
--- a/benchmarks/llava/prepare.py
+++ b/benchmarks/llava/prepare.py
@@ -11,8 +11,12 @@ def main():
         "llava-hf/llava-1.5-7b-hf",
         torch_dtype=torch.float32,  # Change to float32
         device_map="auto",
+        revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb"
+    )
+    _ = AutoProcessor.from_pretrained(
+        "llava-hf/llava-1.5-7b-hf",
+        revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb"
     )
-    _ = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
 
     # Load dataset and create DataLoader
     _ = load_dataset("HuggingFaceM4/the_cauldron", "aokvqa")["train"]
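
Both from_pretrained calls now pin revision to a fixed commit of the model repository, so upstream edits to the checkpoint or processor config cannot silently change what the benchmark downloads. The keyword accepts any commit sha, tag, or branch; a hedged sketch with placeholder identifiers:

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "some-org/some-model",  # hypothetical model id
        revision="a1b2c3d",     # commit sha, tag, or branch to pin
    )
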
diff --git a/benchmarks/llava/requirements.cuda.txt b/benchmarks/llava/requirements.cuda.txt
index 02cc24fbc..5c6f9f64b 100644
--- a/benchmarks/llava/requirements.cuda.txt
+++ b/benchmarks/llava/requirements.cuda.txt
@@ -14,11 +14,11 @@ accelerate==0.34.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llava/requirements.in
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -55,7 +55,7 @@ codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llava/requirements.in
@@ -68,7 +68,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -92,7 +92,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   accelerate
@@ -104,19 +104,19 @@ idna==3.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
     #   yarl
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -167,7 +167,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   scipy
@@ -185,7 +184,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -232,7 +231,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -246,7 +245,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -261,7 +260,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -316,7 +315,7 @@ requests==2.32.3
     #   datasets
     #   huggingface-hub
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -335,7 +334,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -358,6 +357,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/llava/requirements.in
 triton==3.0.0
     # via
@@ -369,8 +369,9 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -395,7 +396,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/llava/requirements.hpu.txt b/benchmarks/llava/requirements.hpu.txt
new file mode 100644
index 000000000..3bd40dff2
--- /dev/null
+++ b/benchmarks/llava/requirements.hpu.txt
@@ -0,0 +1,343 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/llava/requirements.hpu.txt .pin/tmp-constraints-hpu-llava-single.txt benchmarks/llava/requirements.in
+#
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llava/requirements.in
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llava/requirements.in
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   multiprocess
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llava/requirements.in
+    #   accelerate
+    #   datasets
+    #   pandas
+    #   pyarrow
+    #   transformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llava/requirements.in
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   transformers
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llava/requirements.in
+    #   accelerate
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llava/requirements.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llava/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
diff --git a/benchmarks/llava/requirements.rocm.txt b/benchmarks/llava/requirements.rocm.txt
new file mode 100644
index 000000000..fe11f280d
--- /dev/null
+++ b/benchmarks/llava/requirements.rocm.txt
@@ -0,0 +1,293 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/llava/requirements.rocm.txt .pin/tmp-constraints-rocm-llava-single.txt benchmarks/llava/requirements.in
+#
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
+
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llava/requirements.in
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llava/requirements.in
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   multiprocess
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   pytorch-triton-rocm
+    #   torch
+    #   transformers
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llava/requirements.in
+    #   accelerate
+    #   datasets
+    #   pandas
+    #   pyarrow
+    #   transformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llava/requirements.in
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pytorch-triton-rocm==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   transformers
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   transformers
+torch==2.4.1+rocm6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llava/requirements.in
+    #   accelerate
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llava/requirements.in
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llava/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml
index ae5cf2afb..703eb876e 100644
--- a/benchmarks/llm/configs/llama3_70B_full.yaml
+++ b/benchmarks/llm/configs/llama3_70B_full.yaml
@@ -82,7 +82,7 @@ optimizer:
   foreach: False
   # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
-  fused: True
+  fused: true
 
 loss:
   _component_: torch.nn.CrossEntropyLoss
@@ -94,9 +94,9 @@ gradient_accumulation_steps: 1
 device: cuda
 
 # Memory management
-enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: True
-fsdp_cpu_offload: True
+enable_activation_checkpointing: true
+memory_efficient_fsdp_wrap: true
+fsdp_cpu_offload: true
 
 # Reduced precision
 dtype: bf16
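
Note that True and true parse to the same boolean in YAML, so the casing change above is purely a normalization. The config's own comment about fused=True refers to torch.optim's fused AdamW path; a minimal sketch of what the flag selects (assuming a torch build where fused AdamW is available on the target device):

    import torch

    params = [torch.nn.Parameter(torch.randn(1024)) for _ in range(4)]
    # fused=True executes the AdamW update as one fused kernel rather than a
    # per-tensor loop, which is what speeds up the offloaded CPU optimizer step.
    optimizer = torch.optim.AdamW(params, lr=2e-5, fused=True, foreach=False)
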
diff --git a/benchmarks/llm/recipes/full_finetune_distributed.py b/benchmarks/llm/recipes/full_finetune_distributed.py
index 3a51842da..f8d58e2f4 100755
--- a/benchmarks/llm/recipes/full_finetune_distributed.py
+++ b/benchmarks/llm/recipes/full_finetune_distributed.py
@@ -16,6 +16,7 @@
 import torch
 from omegaconf import DictConfig, ListConfig
 
+import torchcompat.core as acc
 from torch import nn
 from torch.distributed import init_process_group
 from torch.distributed.fsdp import (
@@ -38,6 +39,8 @@
 
 log = utils.get_logger("DEBUG")
 
+HPU_UNSUPPORTED = False
+
 
 class FullFinetuneRecipeDistributed(FTRecipeInterface):
     """
@@ -97,8 +100,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-
-        self._device = utils.get_device(device=cfg.device)
+        import os
+        self._device = acc.fetch_device(int(os.getenv("LOCAL_RANK", "0")))
         self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
 
         if self._dtype == torch.float16:
@@ -131,7 +134,10 @@ def __init__(self, cfg: DictConfig) -> None:
 
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
-        self.seed = utils.set_seed(seed=cfg.seed)
+        if HPU_UNSUPPORTED:
+            self.seed = utils.set_seed(seed=cfg.seed)
+        else:
+            self.seed = 1
         self.epochs_run = 0
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
@@ -351,8 +357,10 @@ def _setup_model(
             )
 
         if self._is_rank_zero:
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+            if HPU_UNSUPPORTED:
+                pass
+                # memory_stats = utils.get_memory_stats(device=self._device)
+                # utils.log_memory_stats(memory_stats)
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -413,6 +421,7 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
+            # persistent_workers=True,
             collate_fn=partial(
                 utils.padded_collate,
                 padding_idx=self._tokenizer.pad_id,
@@ -543,31 +552,13 @@ def train(self) -> None:
                         f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
                     )
 
-                    # Log per-step metrics
-                    if (
-                        self.global_step % self._log_every_n_steps == 0
-                        and self._is_rank_zero
-                    ):
-                        time_per_step = time.perf_counter() - t0
-                        log_dict = {
-                            "loss": loss_to_log,
-                            "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
-                        }
-                        if self._log_peak_memory_stats:
-                            log_dict.update(utils.get_memory_stats(device=self._device))
-                        self._metric_logger.log_dict(
-                            log_dict,
-                            step=self.global_step,
-                        )
-
                     # Reset running stats for the next step
                     running_loss = 0
                     num_tokens = 0
                     t0 = time.perf_counter()
-
+
             self.epochs_run += 1
-            self.save_checkpoint(epoch=curr_epoch)
+            # self.save_checkpoint(epoch=curr_epoch)
 
     def cleanup(self) -> None:
         if self._is_rank_zero:
@@ -618,7 +609,8 @@ def recipe_main(cfg: DictConfig) -> None:
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
 
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    acc.init_process_group()
+
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
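
The recipe now goes through torchcompat for device selection and process-group setup so the same code path serves CUDA, ROCm, and HPU. A sketch of that pattern, under the assumption that acc.init_process_group() picks the matching torch.distributed backend (nccl, hccl, or gloo) and acc.fetch_device(i) returns the i-th local accelerator:

    import os
    import torch
    import torchcompat.core as acc

    # Run under torchrun so RANK/WORLD_SIZE/LOCAL_RANK are populated.
    acc.init_process_group()
    device = acc.fetch_device(int(os.getenv("LOCAL_RANK", "0")))
    model = torch.nn.Linear(8, 8).to(device)
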
diff --git a/benchmarks/llm/recipes/full_finetune_single_device.py b/benchmarks/llm/recipes/full_finetune_single_device.py
index 98322579f..629b0e9a9 100755
--- a/benchmarks/llm/recipes/full_finetune_single_device.py
+++ b/benchmarks/llm/recipes/full_finetune_single_device.py
@@ -97,7 +97,9 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-        self._device = utils.get_device(device=cfg.device)
+        import torchcompat.core as accelerator
+
+        self._device = accelerator.fetch_device(int(os.getenv("HABANA_VISIBLE_MODULES", "0").split(",")[0]))
         self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
         # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
         # enabled necessary features such as gradient scaling.
@@ -279,9 +281,9 @@ def _setup_model(
             log.info("Compiling model with torch.compile...")
             backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
             model.compile(backend=backend)
-        if self._device.type == "cuda":
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+        # if self._device.type == "cuda":
+        #     memory_stats = utils.get_memory_stats(device=self._device)
+        #     utils.log_memory_stats(memory_stats)
 
         return model
 
@@ -487,8 +489,8 @@ def train(self) -> None:
                             ),
                             "tokens_per_second_per_gpu": num_tokens / time_per_step,
                         }
-                        if self._device.type == "cuda" and self._log_peak_memory_stats:
-                            log_dict.update(utils.get_memory_stats(device=self._device))
+                        # if self._device.type == "cuda" and self._log_peak_memory_stats:
+                        #     log_dict.update(utils.get_memory_stats(device=self._device))
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
diff --git a/benchmarks/llm/recipes/lora_finetune_distributed.py b/benchmarks/llm/recipes/lora_finetune_distributed.py
index 18b736fbf..ae7c5b403 100755
--- a/benchmarks/llm/recipes/lora_finetune_distributed.py
+++ b/benchmarks/llm/recipes/lora_finetune_distributed.py
@@ -16,6 +16,7 @@
 
 import torch
 from omegaconf import DictConfig, ListConfig
+import torchcompat.core as acc
 
 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
@@ -44,6 +45,9 @@
 log = utils.get_logger("DEBUG")
 
 
+HPU_UNSUPPORTED = False
+
+
 class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
     """
     Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
@@ -108,7 +112,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-        self._device = utils.get_device(device=cfg.device)
+        self._device = acc.fetch_device(int(os.getenv("LOCAL_RANK", "0")))
         self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
 
         if self._dtype == torch.float16:
@@ -132,7 +136,11 @@ def __init__(self, cfg: DictConfig) -> None:
 
         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
-        self.seed = utils.set_seed(seed=cfg.seed)
+        if HPU_UNSUPPORTED:
+            self.seed = utils.set_seed(seed=cfg.seed)
+        else:
+            self.seed = 1
+
         self.epochs_run = 0
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
@@ -428,7 +436,7 @@ def _setup_model(
             # Initialize empty modules on all non-zero ranks
             param_init_fn=(
                 lambda module: module.to_empty(
-                    device=torch.device("cuda"), recurse=False
+                    device=self._device, recurse=False
                 )
                 if not self._is_rank_zero
                 else None
@@ -443,8 +451,10 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerDecoderLayer}
             )
         if self._is_rank_zero:
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+            if HPU_UNSUPPORTED:
+                pass
+                # memory_stats = utils.get_memory_stats(device=self._device)
+                # utils.log_memory_stats(memory_stats)
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -703,8 +713,9 @@ def train(self) -> None:
                             "lr": self._optimizer.param_groups[0]["lr"],
                             "tokens_per_second_per_gpu": num_tokens / time_per_step,
                         }
-                        if self._log_peak_memory_stats:
-                            log_dict.update(utils.get_memory_stats(device=self._device))
+                        # if self._log_peak_memory_stats:
+                        #     if HPU_UNSUPPORTED:
+                        #         log_dict.update(utils.get_memory_stats(device=self._device))
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
@@ -773,7 +784,7 @@ def recipe_main(cfg: DictConfig) -> None:
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    acc.init_process_group()
 
     config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)
 
diff --git a/benchmarks/llm/recipes/lora_finetune_single_device.py b/benchmarks/llm/recipes/lora_finetune_single_device.py
index cf5256ead..9060d2036 100755
--- a/benchmarks/llm/recipes/lora_finetune_single_device.py
+++ b/benchmarks/llm/recipes/lora_finetune_single_device.py
@@ -101,8 +101,9 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-
-        self._device = utils.get_device(device=cfg.device)
+        import torchcompat.core as accelerator
+
+        self._device = accelerator.fetch_device(int(os.getenv("HABANA_VISIBLE_MODULES", "0").split(",")[0]))
         # Reduced precision logic
         self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
         # fp16 precision is explicitly disabled as it is not supported in this
@@ -388,9 +389,9 @@ def _setup_model(
             log.info("Compiling model with torch.compile...")
             backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
             model.compile(backend=backend)
-        if self._device.type == "cuda":
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+        # if self._device.type == "cuda":
+        #     memory_stats = utils.get_memory_stats(device=self._device)
+        #     utils.log_memory_stats(memory_stats)
         return model
 
     def _setup_optimizer(
@@ -528,7 +529,8 @@ def train(self) -> None:
         """
         The core training loop.
         """
-
+        import torchcompat.core as accelerator
+
         if self._model_compile:
             log.info(
                 "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
@@ -579,10 +581,13 @@ def train(self) -> None:
                     loss = self._loss_fn(logits, labels) / self._gradient_accumulation_steps
                     running_loss += loss
                     loss.backward()
+                    accelerator.mark_step()
 
                     # Step with optimizer
                     if (idx + 1) % self._gradient_accumulation_steps == 0:
                         self._optimizer.step()
+                        accelerator.mark_step()
+
                         self._optimizer.zero_grad(set_to_none=True)
                         self._lr_scheduler.step()
                         # Update the number of steps when the weights are updated
@@ -603,13 +608,13 @@ def train(self) -> None:
                                 "lr": self._optimizer.param_groups[0]["lr"],
                                 "tokens_per_second_per_gpu": num_tokens / time_per_step,
                             }
-                            if (
-                                self._device.type == "cuda"
-                                and self._log_peak_memory_stats
-                            ):
-                                log_dict.update(
-                                    utils.get_memory_stats(device=self._device)
-                                )
+                            # if (
+                            #     self._device.type == "cuda"
+                            #     and self._log_peak_memory_stats
+                            # ):
+                            #     log_dict.update(
+                            #         utils.get_memory_stats(device=self._device)
+                            #     )
                             self._metric_logger.log_dict(
                                 log_dict,
                                 step=self.global_step,
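
The single-device recipes take their device index from HABANA_VISIBLE_MODULES, which, like CUDA_VISIBLE_DEVICES, holds a comma-separated list of Gaudi modules visible to the process; only the first entry is used here. The parsing reduces to:

    import os

    # e.g. HABANA_VISIBLE_MODULES="2,3" selects module 2; absent, default to 0.
    module_index = int(os.getenv("HABANA_VISIBLE_MODULES", "0").split(",")[0])
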
diff --git a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py
index 8ee77c06a..fbf8630a2 100644
--- a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py
+++ b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py
@@ -496,9 +496,9 @@ def _setup_model(
             ref_policy_model.compile(backend=backend)
             value_model.compile(backend=backend)
 
-        if self._device.type == "cuda":
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+        # if self._device.type == "cuda":
+        #     memory_stats = utils.get_memory_stats(device=self._device)
+        #     utils.log_memory_stats(memory_stats)
 
         return policy_model, value_model, reward_model, ref_policy_model
 
@@ -1031,8 +1031,8 @@ def log_metrics(
             "approx_policy_kl": ppo_stats.approx_policy_kls.mean(),
             "response_lengths": trajectory.seq_lens.float().mean(),
         }
-        if self._device.type == "cuda" and self._log_peak_memory_stats:
-            log_dict.update(utils.get_memory_stats(device=self._device))
+        # if self._device.type == "cuda" and self._log_peak_memory_stats:
+        #     log_dict.update(utils.get_memory_stats(device=self._device))
 
         self._metric_logger.log_dict(log_dict, step=self.global_step)
 
diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt
index 0e1e0010a..db34901fd 100644
--- a/benchmarks/llm/requirements.cuda.txt
+++ b/benchmarks/llm/requirements.cuda.txt
@@ -14,11 +14,11 @@ accelerate==0.34.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llm/requirements.in
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -64,7 +64,7 @@ codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchtune
@@ -82,7 +82,7 @@ fairscale==0.4.13
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llm/requirements.in
     #   -r benchmarks/llm/requirements.txt
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   blobfile
@@ -91,7 +91,7 @@ filelock==3.16.0
     #   torch
     #   transformers
     #   triton
-fire==0.6.0
+fire==0.7.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/llm/requirements.txt
@@ -115,7 +115,7 @@ hjson==3.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   argklass
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   accelerate
@@ -132,19 +132,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   argklass
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -199,7 +199,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   scipy
@@ -218,7 +217,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -265,7 +264,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -280,7 +279,7 @@ omegaconf==2.3.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchtune
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -295,7 +294,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -312,7 +311,7 @@ pyarrow==17.0.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-pycryptodomex==3.20.0
+pycryptodomex==3.21.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   blobfile
@@ -353,7 +352,7 @@ requests==2.32.3
     #   huggingface-hub
     #   tiktoken
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -376,9 +375,8 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-    #   fire
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -405,10 +403,13 @@ torch==2.4.0+cu121
 torchao==0.3.1+cu121
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
+    #   -r benchmarks/llm/requirements.in
     #   torchtune
 torchtune==0.2.1+cu121
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/llm/requirements.in
 tqdm==4.66.5
     # via
@@ -420,6 +421,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/llm/requirements.in
 triton==3.0.0
     # via
@@ -431,8 +433,9 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -458,7 +461,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/llm/requirements.hpu.txt b/benchmarks/llm/requirements.hpu.txt
new file mode 100644
index 000000000..9b88be532
--- /dev/null
+++ b/benchmarks/llm/requirements.hpu.txt
@@ -0,0 +1,408 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/llm/requirements.hpu.txt .pin/tmp-constraints-hpu-llm-full-mp-nodes.txt benchmarks/llm/requirements.in
+#
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.in
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+argklass==1.4.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.in
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+blobfile==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.txt
+    #   torchtune
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchtune
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   multiprocess
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+fairscale==0.4.13
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blobfile
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+fire==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.txt
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+hjson==3.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   torchtune
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+lxml==5.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blobfile
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   fairscale
+    #   pandas
+    #   pyarrow
+    #   torchtune
+    #   transformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchtune
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pycryptodomex==3.21.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blobfile
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.in
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tiktoken
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   tiktoken
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   torchtune
+    #   transformers
+sentencepiece==0.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchtune
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+termcolor==2.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   fire
+tiktoken==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchtune
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+    #   accelerate
+    #   fairscale
+torchao==0.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+    #   torchtune
+torchtune==0.2.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torchtune
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blobfile
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/llm/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in
index 91b62c073..a3ab63c07 100644
--- a/benchmarks/llm/requirements.in
+++ b/benchmarks/llm/requirements.in
@@ -1,9 +1,10 @@
 voir>=0.2.19,<0.3
-torchtune
+torchtune<0.3.0
 torch
 PyYAML
 argklass
 fairscale
+torchao
 
 # Prepare
 accelerate
diff --git a/benchmarks/llm/requirements.rocm.txt b/benchmarks/llm/requirements.rocm.txt
index ab5098d08..055089f04 100644
--- a/benchmarks/llm/requirements.rocm.txt
+++ b/benchmarks/llm/requirements.rocm.txt
@@ -4,13 +4,17 @@
 #
 #    pip-compile --output-file=benchmarks/llm/requirements.rocm.txt .pin/tmp-constraints-rocm-llm-full-mp-nodes.txt benchmarks/llm/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
-aiohappyeyeballs==2.4.0
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llm/requirements.in
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
@@ -39,11 +43,12 @@ attrs==24.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-blobfile==2.1.1
+blobfile==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llm/requirements.txt
     #   torchtune
-certifi==2024.7.4
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -51,11 +56,11 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-datasets==2.21.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchtune
@@ -64,11 +69,16 @@ dill==0.3.8
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
     #   multiprocess
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+fairscale==0.4.13
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blobfile
@@ -76,6 +86,11 @@ filelock==3.15.4
     #   huggingface-hub
     #   pytorch-triton-rocm
     #   torch
+    #   transformers
+fire==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/llm/requirements.txt
 frozenlist==1.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -87,7 +102,7 @@ fsspec[http]==2024.6.1
     #   datasets
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -96,17 +111,20 @@ hjson==3.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   argklass
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
     #   datasets
+    #   tokenizers
     #   torchtune
-idna==3.7
+    #   transformers
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   yarl
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   argklass
@@ -114,7 +132,7 @@ jinja2==3.1.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-lxml==4.9.4
+lxml==5.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blobfile
@@ -134,7 +152,7 @@ mpmath==1.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   sympy
-multidict==6.0.5
+multidict==6.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
@@ -150,31 +168,41 @@ networkx==3.3
 numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
     #   datasets
+    #   fairscale
     #   pandas
     #   pyarrow
     #   torchtune
+    #   transformers
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchtune
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
 packaging==24.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
     #   datasets
     #   huggingface-hub
-pandas==2.2.2
+    #   transformers
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
 psutil==5.9.8
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
     #   voir
 ptera==1.4.1
     # via
@@ -184,7 +212,7 @@ pyarrow==17.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
-pycryptodomex==3.20.0
+pycryptodomex==3.21.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blobfile
@@ -192,10 +220,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 python-dateutil==2.9.0.post0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -204,7 +228,7 @@ pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-pytz==2024.1
+pytz==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
@@ -212,31 +236,37 @@ pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llm/requirements.in
+    #   accelerate
     #   datasets
     #   huggingface-hub
     #   omegaconf
+    #   transformers
 reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-regex==2024.7.24
+regex==2024.9.11
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tiktoken
+    #   transformers
 requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
     #   huggingface-hub
     #   tiktoken
-rich==13.7.1
+    #   transformers
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-safetensors==0.4.4
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
     #   torchtune
+    #   transformers
 sentencepiece==0.2.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -246,25 +276,39 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
+termcolor==2.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   fire
 tiktoken==0.7.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchtune
-torch==2.4.0+rocm6.0
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   transformers
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llm/requirements.in
+    #   -r benchmarks/llm/requirements.txt
+    #   accelerate
+    #   fairscale
 torchao==0.3.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llm/requirements.in
     #   torchtune
 torchtune==0.2.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/llm/requirements.in
 tqdm==4.66.5
     # via
@@ -272,26 +316,34 @@ tqdm==4.66.5
     #   datasets
     #   huggingface-hub
     #   torchtune
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/llm/requirements.in
 typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
+    #   multidict
     #   reactivex
+    #   rich
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blobfile
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
@@ -300,7 +352,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   datasets
-yarl==1.9.4
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
diff --git a/benchmarks/purejaxrl/benchfile.py b/benchmarks/purejaxrl/benchfile.py
index 08a51cef0..ab1c0ee73 100644
--- a/benchmarks/purejaxrl/benchfile.py
+++ b/benchmarks/purejaxrl/benchfile.py
@@ -18,7 +18,9 @@ class Template(Package):
     def make_env(self):
         # Return a dict of environment variables for prepare_script and
         # main_script.
-        return super().make_env()
+        env = super().make_env()
+        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
+        return env
 
     async def install(self):
         await super().install()  # super() call installs the requirements
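
The XLA_PYTHON_CLIENT_PREALLOCATE variable has to be set before jax initializes its backend, which is why it is exported through make_env() rather than set inside the benchmark script. By default jax preallocates most of the visible GPU memory (roughly 75%), so an external monitor only ever sees that fixed pool; with preallocation disabled, allocations grow on demand and the real footprint becomes observable. A minimal sketch of the effect, not part of the patch (memory_stats() returns None on backends that do not expose allocator statistics):

import os

# Must happen before jax creates its backend: hence make_env() above.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))                     # ~4 MiB of float32 actually in use
stats = jax.local_devices()[0].memory_stats()  # None on unsupported backends
if stats is not None:
    print(stats["bytes_in_use"], stats["peak_bytes_in_use"])
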
diff --git a/benchmarks/purejaxrl/dqn.py b/benchmarks/purejaxrl/dqn.py
index 17c839147..fc0a97b8d 100644
--- a/benchmarks/purejaxrl/dqn.py
+++ b/benchmarks/purejaxrl/dqn.py
@@ -50,7 +50,9 @@ def make_train(config):
     config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"]
 
     from benchmate.timings import StepTimer
+    from benchmate.jaxmem import memory_peak_fetcher
     step_timer = StepTimer(give_push())
+    fetch_memory_peak = memory_peak_fetcher()
 
     basic_env, env_params = gymnax.make(config["ENV_NAME"])
     env = FlattenObservationWrapper(basic_env)
@@ -238,6 +240,7 @@ def callback(metrics):
                     
                     step_timer.step(delta.item())
                     step_timer.log(returns=returns, loss=loss)
+                    step_timer.log(memory_peak=fetch_memory_peak(), units="MiB")
                     step_timer.end()
 
             jax.debug.callback(callback, metrics)
@@ -258,12 +261,49 @@ def callback(metrics):
     return train
 
 
+# When using nvidia-smi to monitor memory
+# dqn:
+#   arg: --buffer_size
+#   model:
+#     256: 61900.25 MiB
+#     1000: 61900.25 MiB
+#     10000: 61900.25 MiB
+
+# dqn:
+#   arg: --num_envs
+#   model:
+#     2: 61900.25 MiB
+#     4: 61900.25 MiB
+#     16: 61900.25 MiB
+#     32: 61900.25 MiB
+#     64: 61900.25 MiB
+#     128: 61900.25 MiB
+
+#   arg: --total_timesteps
+#   model:
+#     32768: 61900.25 MiB
+#     65536: 61900.25 MiB
+
+# When using Jax to monitor memory
+
+# dqn.D0 [stdout] Device: cuda:0
+# dqn.D0 [stdout]   num_allocs: 0.0006799697875976562 MiB
+# dqn.D0 [stdout]   bytes_in_use: 0.915771484375 MiB
+# dqn.D0 [stdout]   peak_bytes_in_use: 80.41552734375 MiB
+# dqn.D0 [stdout]   largest_alloc_size: 16.07958984375 MiB
+# dqn.D0 [stdout]   bytes_limit: 60832.359375 MiB
+# dqn.D0 [stdout]   bytes_reserved: 0.0 MiB
+# dqn.D0 [stdout]   peak_bytes_reserved: 0.0 MiB
+# dqn.D0 [stdout]   largest_free_block_bytes: 0.0 MiB
+# dqn.D0 [stdout]   pool_bytes: 60832.359375 MiB
+# dqn.D0 [stdout]   peak_pool_bytes: 60832.359375 MiB
+
+
 @dataclass
 class Arguments:
-    num_envs: int = 10
-    buffer_size: int = 10000
+    num_envs: int = 10                  # No impact on memory
+    buffer_size: int = 10000            # No impact on memory
     buffer_batch_size: int = 128
-    total_timesteps: int = 100_000
+    total_timesteps: int = 100_000      # No impact on memory
     epsilon_start: float =  1.0
     epsilon_finish: float = 0.05
     epsilon_anneal_time: int = 25e4
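
The measurements above are the motivation for fetching memory through jax instead of nvidia-smi: with preallocation, nvidia-smi reports the same 61900.25 MiB pool no matter the workload, while jax's allocator statistics expose the actual peak (about 80 MiB for this run). The patch imports memory_peak_fetcher from benchmate.jaxmem without showing it; a hedged sketch of what such a fetcher could look like on top of jax's per-device memory_stats():

import jax

MiB = 1024 ** 2

def memory_peak_fetcher():
    """Build a closure that reports peak allocator usage across devices, in MiB."""
    def fetch_memory_peak():
        peak = -1.0  # sentinel when no device exposes statistics
        for device in jax.local_devices():
            stats = device.memory_stats()  # may be None, e.g. on CPU
            if stats is not None:
                peak = max(peak, stats.get("peak_bytes_in_use", 0) / MiB)
        return peak
    return fetch_memory_peak
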
diff --git a/benchmarks/purejaxrl/main.py b/benchmarks/purejaxrl/main.py
index f37c45e0d..c3a3630dd 100644
--- a/benchmarks/purejaxrl/main.py
+++ b/benchmarks/purejaxrl/main.py
@@ -6,6 +6,7 @@
 
 import argklass
 
+import torch  # Imported first so that jax picks up the CUDA libraries packaged with torch
 
 from dqn import add_dqn_command, main as dqn_main
 from ppo import add_ppo_command, main as ppo_main
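
The early torch import is load-bearing: torch depends on the nvidia-*-cu12 wheels that bundle the CUDA shared libraries, so importing it first loads those libraries into the process, and jax can then resolve CUDA symbols without a separate system-wide CUDA install. That reading is an assumption based on the comment in the patch; a sketch of the ordering requirement:

# Assumption: torch must be imported before jax so the CUDA libraries
# bundled with torch's nvidia-* wheel dependencies are already loaded
# when jax initializes its GPU backend.
import torch  # noqa: F401  (imported only for this side effect)
import jax

print(jax.devices())  # expected to list cuda devices once the libraries resolve
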
diff --git a/benchmarks/purejaxrl/ppo.py b/benchmarks/purejaxrl/ppo.py
index a053373f3..0cc8896cc 100644
--- a/benchmarks/purejaxrl/ppo.py
+++ b/benchmarks/purejaxrl/ppo.py
@@ -75,7 +75,10 @@ class Transition(NamedTuple):
 
 def make_train(config):
     from benchmate.timings import StepTimer
+    from benchmate.jaxmem import memory_peak_fetcher
+
     step_timer = StepTimer(give_push())
+    fetch_memory_peak = memory_peak_fetcher()
 
     config["NUM_UPDATES"] = (
         config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
@@ -280,6 +283,7 @@ def callback(info):
 
                 step_timer.step(config["NUM_ENVS"] * config["NUM_STEPS"])
                 step_timer.log(loss=loss)
+                step_timer.log(memory_peak=fetch_memory_peak(), units="MiB")
                 step_timer.end()
                 
             jax.debug.callback(callback, metrics)
diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt
index a59468762..3f09e47f7 100644
--- a/benchmarks/purejaxrl/requirements.cuda.txt
+++ b/benchmarks/purejaxrl/requirements.cuda.txt
@@ -32,7 +32,7 @@ argklass==1.4.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/purejaxrl/requirements.in
-astroid==3.2.4
+astroid==3.3.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pylint
@@ -61,7 +61,7 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
-chex==0.1.86
+chex==0.1.87
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   distrax
@@ -157,7 +157,7 @@ farama-notifications==0.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   gymnasium
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -188,7 +188,7 @@ flax==0.9.0
     #   flashbax
     #   gymnax
     #   navix
-fonttools==4.53.1
+fonttools==4.54.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   matplotlib
@@ -218,7 +218,7 @@ glfw==2.7.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   mujoco
-grpcio==1.66.1
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -269,7 +269,7 @@ itsdangerous==2.2.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flask
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
@@ -286,15 +286,15 @@ jax[cuda12]==0.4.31
     #   optax
     #   orbax-checkpoint
     #   rlax
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -366,12 +366,12 @@ msgpack==1.1.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flax
     #   orbax-checkpoint
-mujoco==3.2.2
+mujoco==3.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
     #   mujoco-mjx
-mujoco-mjx==3.2.2
+mujoco-mjx==3.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -411,7 +411,6 @@ numpy==1.26.4
     #   ml-dtypes
     #   mujoco
     #   navix
-    #   opt-einsum
     #   optax
     #   orbax-checkpoint
     #   pandas
@@ -435,7 +434,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -482,7 +481,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -496,7 +495,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -506,7 +505,7 @@ optax==0.2.3
     #   -r benchmarks/purejaxrl/requirements.in
     #   brax
     #   flax
-orbax-checkpoint==0.6.3
+orbax-checkpoint==0.6.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   brax
@@ -523,7 +522,7 @@ packaging==24.1
     #   pytest
     #   setuptools-scm
     #   tensorboardx
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   seaborn
@@ -537,7 +536,7 @@ pillow==10.4.0
     #   brax
     #   matplotlib
     #   navix
-platformdirs==4.3.3
+platformdirs==4.3.6
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   black
@@ -547,7 +546,7 @@ pluggy==1.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pytest
-protobuf==5.28.1
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   orbax-checkpoint
@@ -574,7 +573,7 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   rich
-pylint==3.2.7
+pylint==3.3.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   navix
@@ -621,7 +620,7 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   wandb
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flax
@@ -643,7 +642,7 @@ seaborn==0.13.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   gymnax
-sentry-sdk==2.14.0
+sentry-sdk==2.15.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   wandb
@@ -671,7 +670,7 @@ smmap==5.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   gitdb
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -683,13 +682,13 @@ tensorflow-probability==0.24.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   distrax
-tensorstore==0.1.65
+tensorstore==0.1.66
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   flashbax
     #   flax
     #   orbax-checkpoint
-tomli==2.0.1
+tomli==2.0.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   black
@@ -732,13 +731,14 @@ typing-extensions==4.12.2
     #   navix
     #   orbax-checkpoint
     #   reactivex
+    #   rich
     #   torch
     #   tyro
-tyro==0.8.10
+tyro==0.8.11
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   navix
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -756,7 +756,7 @@ voir==0.2.19
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/purejaxrl/requirements.in
-wandb==0.18.0
+wandb==0.18.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   navix
diff --git a/benchmarks/purejaxrl/requirements.hpu.txt b/benchmarks/purejaxrl/requirements.hpu.txt
new file mode 100644
index 000000000..aeb2b1101
--- /dev/null
+++ b/benchmarks/purejaxrl/requirements.hpu.txt
@@ -0,0 +1,743 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/purejaxrl/requirements.hpu.txt .pin/tmp-constraints-hpu-ppo.txt benchmarks/purejaxrl/requirements.in
+#
+absl-py==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   chex
+    #   distrax
+    #   dm-env
+    #   ml-collections
+    #   mujoco
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+    #   tensorflow-probability
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+argklass==1.4.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+astroid==3.3.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pylint
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+black==24.8.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+blinker==1.8.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flask
+brax==0.10.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   sentry-sdk
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+chex==0.1.87
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   optax
+    #   rlax
+click==8.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+    #   flask
+    #   wandb
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gym
+    #   gymnasium
+    #   tensorflow-probability
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+contextlib2==21.6.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ml-collections
+contourpy==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+cycler==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+decorator==5.1.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorflow-probability
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pylint
+distrax==0.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   rlax
+dm-env==1.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   rlax
+dm-tree==0.1.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   dm-env
+    #   tensorflow-probability
+docker-pycreds==0.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+docstring-parser==0.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+dotmap==1.3.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   evosax
+etils[epath,epy]==1.9.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   mujoco
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+evosax==0.1.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+exceptiongroup==1.2.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pytest
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+farama-notifications==0.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gymnasium
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+flake8==7.1.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+flashbax==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+flask==3.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   flask-cors
+flask-cors==5.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+flax==0.9.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   navix
+fonttools==4.54.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   etils
+    #   torch
+gast==0.6.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorflow-probability
+gitdb==4.0.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gitpython
+gitpython==3.1.43
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+glfw==2.7.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   mujoco
+grpcio==1.66.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+gym==0.26.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   gymnax
+gym-notices==0.0.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gym
+gymnasium==0.29.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gymnax
+gymnax==0.0.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+hjson==3.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+humanize==4.10.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   orbax-checkpoint
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   argklass
+    #   etils
+iniconfig==2.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pytest
+isort==5.13.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pylint
+itsdangerous==2.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flask
+jax==0.4.33
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   flax
+    #   gymnax
+    #   jaxopt
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+jaxlib==0.4.33
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   jax
+    #   jaxopt
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+jaxopt==0.8.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   flask
+    #   torch
+kiwisolver==1.4.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+    #   werkzeug
+matplotlib==3.9.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   evosax
+    #   gymnax
+    #   seaborn
+mccabe==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flake8
+    #   pylint
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+ml-collections==0.1.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+ml-dtypes==0.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jax
+    #   jaxlib
+    #   tensorstore
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+msgpack==1.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flax
+    #   orbax-checkpoint
+mujoco==3.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   mujoco-mjx
+mujoco-mjx==3.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+mypy-extensions==1.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+navix==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+nest-asyncio==1.6.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   orbax-checkpoint
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   chex
+    #   contourpy
+    #   distrax
+    #   dm-env
+    #   evosax
+    #   flashbax
+    #   gym
+    #   gymnasium
+    #   jax
+    #   jaxlib
+    #   jaxopt
+    #   matplotlib
+    #   ml-dtypes
+    #   mujoco
+    #   navix
+    #   optax
+    #   orbax-checkpoint
+    #   pandas
+    #   rlax
+    #   scipy
+    #   seaborn
+    #   tensorboardx
+    #   tensorflow-probability
+    #   tensorstore
+    #   trimesh
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+opt-einsum==3.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jax
+optax==0.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   flax
+orbax-checkpoint==0.6.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   flax
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+    #   matplotlib
+    #   pytest
+    #   setuptools-scm
+    #   tensorboardx
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   seaborn
+pathspec==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   matplotlib
+    #   navix
+platformdirs==4.3.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+    #   pylint
+    #   wandb
+pluggy==1.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pytest
+protobuf==5.28.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   orbax-checkpoint
+    #   tensorboardx
+    #   wandb
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+    #   wandb
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pycodestyle==2.12.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flake8
+pyflakes==3.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flake8
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pylint==3.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+pyopengl==3.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   mujoco
+pyparsing==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+pytest==8.3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   matplotlib
+    #   pandas
+pytinyrenderer==0.0.14
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   evosax
+    #   flax
+    #   gymnax
+    #   ml-collections
+    #   omegaconf
+    #   orbax-checkpoint
+    #   wandb
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flax
+    #   tyro
+    #   voir
+rlax==0.1.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+scipy==1.14.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   jax
+    #   jaxlib
+    #   jaxopt
+    #   mujoco-mjx
+seaborn==0.13.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gymnax
+sentry-sdk==2.15.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+setproctitle==1.3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+setuptools-scm==8.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+shtab==1.7.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   docker-pycreds
+    #   ml-collections
+    #   python-dateutil
+    #   tensorflow-probability
+smmap==5.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gitdb
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tensorboardx==2.6.2.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+tensorflow-probability==0.24.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   distrax
+tensorstore==0.1.66
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flashbax
+    #   flax
+    #   orbax-checkpoint
+tomli==2.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   black
+    #   pylint
+    #   pytest
+    #   setuptools-scm
+tomlkit==0.13.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pylint
+toolz==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   chex
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+trimesh==4.4.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   brax
+    #   mujoco-mjx
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   astroid
+    #   black
+    #   brax
+    #   chex
+    #   etils
+    #   flashbax
+    #   flax
+    #   gymnasium
+    #   navix
+    #   orbax-checkpoint
+    #   reactivex
+    #   rich
+    #   torch
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   sentry-sdk
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+wandb==0.18.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   navix
+werkzeug==3.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   flask
+zipp==3.20.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   etils
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/purejaxrl/requirements.rocm.txt b/benchmarks/purejaxrl/requirements.rocm.txt
new file mode 100644
index 000000000..226415e04
--- /dev/null
+++ b/benchmarks/purejaxrl/requirements.rocm.txt
@@ -0,0 +1,693 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/purejaxrl/requirements.rocm.txt .pin/tmp-constraints-rocm-ppo.txt benchmarks/purejaxrl/requirements.in
+#
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
+
+absl-py==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   chex
+    #   distrax
+    #   dm-env
+    #   ml-collections
+    #   mujoco
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+    #   tensorflow-probability
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   omegaconf
+argklass==1.4.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+astroid==3.3.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pylint
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+black==24.8.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+blinker==1.8.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flask
+brax==0.10.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+    #   sentry-sdk
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+chex==0.1.87
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   optax
+    #   rlax
+click==8.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+    #   flask
+    #   wandb
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gym
+    #   gymnasium
+    #   tensorflow-probability
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+contextlib2==21.6.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ml-collections
+contourpy==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+cycler==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+decorator==5.1.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tensorflow-probability
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pylint
+distrax==0.1.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   rlax
+dm-env==1.6
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   rlax
+dm-tree==0.1.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   dm-env
+    #   tensorflow-probability
+docker-pycreds==0.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   wandb
+docstring-parser==0.16
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tyro
+dotmap==1.3.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   evosax
+etils[epath,epy]==1.9.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   mujoco
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+evosax==0.1.6
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+exceptiongroup==1.2.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pytest
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   varname
+farama-notifications==0.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gymnasium
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pytorch-triton-rocm
+    #   torch
+flake8==7.1.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+flashbax==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+flask==3.0.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   flask-cors
+flask-cors==5.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+flax==0.9.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   navix
+fonttools==4.54.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   etils
+    #   torch
+gast==0.6.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tensorflow-probability
+gitdb==4.0.11
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gitpython
+gitpython==3.1.43
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   wandb
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+    #   voir
+glfw==2.7.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   mujoco
+grpcio==1.66.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+gym==0.26.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   gymnax
+gym-notices==0.0.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gym
+gymnasium==0.29.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gymnax
+gymnax==0.0.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+hjson==3.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   argklass
+humanize==4.10.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   orbax-checkpoint
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   argklass
+    #   etils
+iniconfig==2.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pytest
+isort==5.13.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pylint
+itsdangerous==2.2.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flask
+jax==0.4.33
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   flax
+    #   gymnax
+    #   jaxopt
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+jaxlib==0.4.33
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   chex
+    #   distrax
+    #   evosax
+    #   flashbax
+    #   gymnax
+    #   jax
+    #   jaxopt
+    #   mujoco-mjx
+    #   optax
+    #   orbax-checkpoint
+    #   rlax
+jaxopt==0.8.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   flask
+    #   torch
+kiwisolver==1.4.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jinja2
+    #   werkzeug
+matplotlib==3.9.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   evosax
+    #   gymnax
+    #   seaborn
+mccabe==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flake8
+    #   pylint
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   markdown-it-py
+ml-collections==0.1.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+ml-dtypes==0.5.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jax
+    #   jaxlib
+    #   tensorstore
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   sympy
+msgpack==1.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flax
+    #   orbax-checkpoint
+mujoco==3.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   mujoco-mjx
+mujoco-mjx==3.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+mypy-extensions==1.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+navix==0.7.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+nest-asyncio==1.6.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   orbax-checkpoint
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   chex
+    #   contourpy
+    #   distrax
+    #   dm-env
+    #   evosax
+    #   flashbax
+    #   gym
+    #   gymnasium
+    #   jax
+    #   jaxlib
+    #   jaxopt
+    #   matplotlib
+    #   ml-dtypes
+    #   mujoco
+    #   navix
+    #   optax
+    #   orbax-checkpoint
+    #   pandas
+    #   rlax
+    #   scipy
+    #   seaborn
+    #   tensorboardx
+    #   tensorflow-probability
+    #   tensorstore
+    #   trimesh
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+opt-einsum==3.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jax
+optax==0.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+    #   brax
+    #   flax
+orbax-checkpoint==0.6.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   flax
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+    #   matplotlib
+    #   pytest
+    #   setuptools-scm
+    #   tensorboardx
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   seaborn
+pathspec==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   matplotlib
+    #   navix
+platformdirs==4.3.6
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+    #   pylint
+    #   wandb
+pluggy==1.5.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pytest
+protobuf==5.28.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   orbax-checkpoint
+    #   tensorboardx
+    #   wandb
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+    #   wandb
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+pycodestyle==2.12.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flake8
+pyflakes==3.2.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flake8
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+pylint==3.3.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+pyopengl==3.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   mujoco
+pyparsing==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+pytest==8.3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   matplotlib
+    #   pandas
+pytinyrenderer==0.0.14
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+pytorch-triton-rocm==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   evosax
+    #   flax
+    #   gymnax
+    #   ml-collections
+    #   omegaconf
+    #   orbax-checkpoint
+    #   wandb
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   wandb
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flax
+    #   tyro
+    #   voir
+rlax==0.1.6
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+scipy==1.14.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   jax
+    #   jaxlib
+    #   jaxopt
+    #   mujoco-mjx
+seaborn==0.13.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gymnax
+sentry-sdk==2.15.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   wandb
+setproctitle==1.3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   wandb
+setuptools-scm==8.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+shtab==1.7.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tyro
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   asttokens
+    #   docker-pycreds
+    #   ml-collections
+    #   python-dateutil
+    #   tensorflow-probability
+smmap==5.0.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gitdb
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+tensorboardx==2.6.2.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+tensorflow-probability==0.24.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   distrax
+tensorstore==0.1.66
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flashbax
+    #   flax
+    #   orbax-checkpoint
+tomli==2.0.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   black
+    #   pylint
+    #   pytest
+    #   setuptools-scm
+tomlkit==0.13.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pylint
+toolz==0.12.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   chex
+torch==2.4.1+rocm6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+trimesh==4.4.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   brax
+    #   mujoco-mjx
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   astroid
+    #   black
+    #   brax
+    #   chex
+    #   etils
+    #   flashbax
+    #   flax
+    #   gymnasium
+    #   navix
+    #   orbax-checkpoint
+    #   reactivex
+    #   rich
+    #   torch
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+    #   sentry-sdk
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/purejaxrl/requirements.in
+wandb==0.18.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   navix
+werkzeug==3.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   flask
+zipp==3.20.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   etils
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/purejaxrl/voirfile.py b/benchmarks/purejaxrl/voirfile.py
index 5305be3f4..a94eb7646 100644
--- a/benchmarks/purejaxrl/voirfile.py
+++ b/benchmarks/purejaxrl/voirfile.py
@@ -32,7 +32,7 @@ def instrument_main(ov, options: Config):
         ov.require(dash)
 
     ov.require(
-        log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
+        log("value", "progress", "rate", "units", "loss", "gpudata", "memory_peak", "cpudata", context="task"),
         # early_stop(n=options.stop, key="rate", task="train"),
         monitor_monogpu(poll_interval=options.gpu_poll),
     )
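
Note on the voirfile.py change above: the `log` instrument now forwards
`memory_peak` and `cpudata` in addition to the existing throughput, loss and
GPU fields, so peak-memory and CPU samples reach the benchmark's event stream.
A minimal annotated sketch of the resulting instrument body, assuming the
`log` and `monitor_monogpu` imports already present in voirfile.py:

    def instrument_main(ov, options: Config):
        ...
        ov.require(
            # Fields to forward from the event stream; "memory_peak" and
            # "cpudata" are the two newly added ones.
            log("value", "progress", "rate", "units", "loss",
                "gpudata", "memory_peak", "cpudata", context="task"),
            # Sample single-GPU stats every options.gpu_poll seconds.
            monitor_monogpu(poll_interval=options.gpu_poll),
        )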
diff --git a/benchmarks/recursiongfn/requirements.cuda.txt b/benchmarks/recursiongfn/requirements.cuda.txt
index 89c02624f..497f573ab 100644
--- a/benchmarks/recursiongfn/requirements.cuda.txt
+++ b/benchmarks/recursiongfn/requirements.cuda.txt
@@ -14,11 +14,11 @@ absl-py==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tensorboard
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch-geometric
@@ -46,7 +46,7 @@ blosc2==2.7.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tables
-botorch==0.11.3
+botorch==0.12.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
@@ -79,7 +79,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -108,12 +108,12 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-gpytorch==1.12
+gpytorch==1.13
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-grpcio==1.66.1
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tensorboard
@@ -122,25 +122,26 @@ idna==3.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
     #   yarl
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxtyping==0.2.34
+jaxtyping==0.2.19
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   gpytorch
     #   linear-operator
 jinja2==3.1.4
     # via
@@ -151,7 +152,7 @@ joblib==1.4.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   scikit-learn
-linear-operator==0.5.2
+linear-operator==0.5.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   botorch
@@ -183,6 +184,7 @@ mpmath==1.3.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   botorch
     #   gpytorch
+    #   linear-operator
     #   sympy
 msgpack==1.1.0
     # via
@@ -197,7 +199,7 @@ multipledispatch==1.0.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   botorch
-ndindex==1.8
+ndindex==1.9.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   blosc2
@@ -215,12 +217,11 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   blosc2
-    #   botorch
     #   jax
     #   jaxlib
+    #   jaxtyping
     #   ml-dtypes
     #   numexpr
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   pyro-ppl
@@ -243,7 +244,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -290,7 +291,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -305,7 +306,7 @@ omegaconf==2.3.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -319,7 +320,7 @@ packaging==24.1
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tables
     #   tensorboard
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
@@ -327,11 +328,11 @@ pillow==10.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   rdkit
-platformdirs==4.3.3
+platformdirs==4.3.6
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   wandb
-protobuf==5.28.1
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tensorboard
@@ -398,7 +399,7 @@ requests==2.32.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch-geometric
     #   wandb
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -418,7 +419,7 @@ scipy==1.14.1
     #   scikit-learn
     #   torch-cluster
     #   torch-sparse
-sentry-sdk==2.14.0
+sentry-sdk==2.15.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   wandb
@@ -437,7 +438,7 @@ smmap==5.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   gitdb
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -445,7 +446,7 @@ tables==3.10.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
-tensorboard==2.17.1
+tensorboard==2.18.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
@@ -469,7 +470,7 @@ torch-cluster==1.6.3+pt24cu121
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
-torch-geometric==2.6.0
+torch-geometric==2.6.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
@@ -490,19 +491,22 @@ triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
-typeguard==2.13.3
+typeguard==4.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jaxtyping
-    #   linear-operator
 typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   botorch
+    #   jaxtyping
     #   multidict
     #   reactivex
+    #   rich
     #   tables
     #   torch
-tzdata==2024.1
+    #   typeguard
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -520,7 +524,7 @@ voir==0.2.19
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/recursiongfn/requirements.in
-wandb==0.18.0
+wandb==0.18.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
@@ -532,7 +536,7 @@ xformers==0.0.27.post2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/recursiongfn/requirements.hpu.txt b/benchmarks/recursiongfn/requirements.hpu.txt
new file mode 100644
index 000000000..4e362ae63
--- /dev/null
+++ b/benchmarks/recursiongfn/requirements.hpu.txt
@@ -0,0 +1,493 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/recursiongfn/requirements.hpu.txt .pin/tmp-constraints-hpu-recursiongfn.txt benchmarks/recursiongfn/requirements.in
+#
+absl-py==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+blosc2==2.7.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tables
+botorch==0.12.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   sentry-sdk
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+click==8.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+cvxopt==1.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+docker-pycreds==0.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   torch-geometric
+gitdb==4.0.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gitpython
+gitpython==3.1.43
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   wandb
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+gpytorch==1.13
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+grpcio==1.66.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+jaxtyping==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gpytorch
+    #   linear-operator
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   torch-geometric
+joblib==1.4.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   scikit-learn
+linear-operator==0.5.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   botorch
+    #   gpytorch
+markdown==3.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+    #   werkzeug
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   botorch
+    #   gpytorch
+    #   linear-operator
+    #   sympy
+msgpack==1.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blosc2
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+multipledispatch==1.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   botorch
+ndindex==1.9.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blosc2
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   torch
+numexpr==2.10.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blosc2
+    #   tables
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blosc2
+    #   jaxtyping
+    #   numexpr
+    #   pandas
+    #   pyarrow
+    #   pyro-ppl
+    #   rdkit
+    #   scikit-learn
+    #   scipy
+    #   tables
+    #   tensorboard
+    #   torch-geometric
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   voir
+opt-einsum==3.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pyro-ppl
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tables
+    #   tensorboard
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rdkit
+platformdirs==4.3.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+protobuf==5.28.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+    #   wandb
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+    #   voir
+    #   wandb
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+py-cpuinfo==9.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   blosc2
+    #   tables
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pyparsing==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+pyro-api==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pyro-ppl
+pyro-ppl==1.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+    #   wandb
+rdkit==2024.3.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch-geometric
+    #   wandb
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+scikit-learn==1.5.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gpytorch
+scipy==1.14.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+    #   gpytorch
+    #   linear-operator
+    #   scikit-learn
+    #   torch-cluster
+    #   torch-sparse
+sentry-sdk==2.15.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+setproctitle==1.3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   wandb
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   docker-pycreds
+    #   python-dateutil
+    #   tensorboard
+smmap==5.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gitdb
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tables==3.10.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+tensorboard==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+tensorboard-data-server==0.7.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+threadpoolctl==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   scikit-learn
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+    #   botorch
+    #   linear-operator
+    #   pyro-ppl
+torch-cluster==1.6.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-geometric==2.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-scatter==2.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-sparse==0.6.18
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pyro-ppl
+    #   torch-geometric
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typeguard==4.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jaxtyping
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   botorch
+    #   jaxtyping
+    #   multidict
+    #   reactivex
+    #   rich
+    #   tables
+    #   torch
+    #   typeguard
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   sentry-sdk
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+wandb==0.18.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+werkzeug==3.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/recursiongfn/requirements.rocm.txt b/benchmarks/recursiongfn/requirements.rocm.txt
index 1bc73f14e..bcb64cdb2 100644
--- a/benchmarks/recursiongfn/requirements.rocm.txt
+++ b/benchmarks/recursiongfn/requirements.rocm.txt
@@ -2,201 +2,198 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=benchmarks/recursiongfn/requirements.rocm.txt .pin/tmp-constraints-rocm-recursiongfn_gnn.txt benchmarks/recursiongfn/requirements.in
+#    pip-compile --output-file=benchmarks/recursiongfn/requirements.rocm.txt .pin/tmp-constraints-rocm-recursiongfn.txt benchmarks/recursiongfn/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 absl-py==2.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
 aiosignal==1.3.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 antlr4-python3-runtime==4.9.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   omegaconf
 asttokens==2.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
 async-timeout==4.0.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 attrs==24.2.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 blosc2==2.7.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tables
-botorch==0.11.3
+botorch==0.12.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
-certifi==2024.7.4
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+certifi==2024.8.30
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   sentry-sdk
 charset-normalizer==3.3.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
 click==8.1.7
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   wandb
-codefind==0.1.6
+codefind==0.1.7
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
 cvxopt==1.3.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 docker-pycreds==0.4.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   wandb
-executing==1.2.0
+executing==2.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
     #   torch
 frozenlist==1.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
     #   aiosignal
 fsspec==2024.6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
     #   torch-geometric
-gflownet @ git+https://github.com/Delaunay/gflownet@milabench
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   -r benchmarks/recursiongfn/requirements.in
 gitdb==4.0.11
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gitpython
 gitpython==3.1.43
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   wandb
-giving==0.4.2
+giving==0.4.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-gpytorch==1.12
+gpytorch==1.13
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-    #   gflownet
-grpcio==1.65.5
+grpcio==1.66.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
-idna==3.7
+idna==3.10
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   yarl
-jaxtyping==0.2.33
+jaxtyping==0.2.19
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   gpytorch
     #   linear-operator
 jinja2==3.1.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
     #   torch-geometric
 joblib==1.4.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   scikit-learn
-linear-operator==0.5.2
+linear-operator==0.5.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   botorch
     #   gpytorch
 markdown==3.7
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
 markdown-it-py==3.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
 markupsafe==2.1.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   jinja2
     #   werkzeug
 mdurl==0.1.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   markdown-it-py
 mpmath==1.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   botorch
     #   gpytorch
+    #   linear-operator
     #   sympy
-msgpack==1.0.8
+msgpack==1.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blosc2
-multidict==6.0.5
+multidict==6.1.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
     #   yarl
 multipledispatch==1.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   botorch
-ndindex==1.8
+ndindex==1.9.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blosc2
 networkx==3.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   torch
 numexpr==2.10.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blosc2
     #   tables
 numpy==1.26.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blosc2
-    #   botorch
+    #   jaxtyping
     #   numexpr
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   pyro-ppl
@@ -206,239 +203,240 @@ numpy==1.26.4
     #   tables
     #   tensorboard
     #   torch-geometric
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pyro-ppl
-ovld==0.3.8
+ovld==0.3.9
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
 packaging==24.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tables
     #   tensorboard
-pandas==2.2.2
+pandas==2.2.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 pillow==10.4.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rdkit
-platformdirs==4.2.2
+platformdirs==4.3.6
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   wandb
-protobuf==5.27.3
+protobuf==5.28.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
     #   wandb
 psutil==5.9.8
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
     #   voir
     #   wandb
 ptera==1.4.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
 py-cpuinfo==9.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   blosc2
     #   tables
 pyarrow==17.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 pygments==2.18.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   voir
-pyparsing==3.1.2
+pyparsing==3.1.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
 pyro-api==0.1.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pyro-ppl
 pyro-ppl==1.9.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-    #   gflownet
 python-dateutil==2.9.0.post0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
 pytorch-triton-rocm==3.0.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-pytz==2024.1
+pytz==2024.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
 pyyaml==6.0.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   omegaconf
     #   wandb
 rdkit==2024.3.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 reactivex==4.0.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
 requests==2.32.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch-geometric
     #   wandb
-rich==13.7.1
+rich==13.9.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-scikit-learn==1.5.1
+scikit-learn==1.5.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gpytorch
-    #   torch-geometric
-scipy==1.14.0
+scipy==1.14.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-    #   gflownet
     #   gpytorch
     #   linear-operator
     #   scikit-learn
     #   torch-cluster
-    #   torch-geometric
     #   torch-sparse
-sentry-sdk==2.13.0
+sentry-sdk==2.15.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   wandb
 setproctitle==1.3.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   wandb
 six==1.16.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   docker-pycreds
     #   python-dateutil
     #   tensorboard
 smmap==5.0.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gitdb
-sympy==1.13.2
+sympy==1.13.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
 tables==3.10.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
-tensorboard==2.17.1
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+tensorboard==2.18.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 tensorboard-data-server==0.7.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
 threadpoolctl==3.5.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   scikit-learn
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/recursiongfn/requirements.in
     #   botorch
-    #   gflownet
     #   linear-operator
     #   pyro-ppl
 torch-cluster==1.6.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
-torch-geometric==2.5.3
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+torch-geometric==2.6.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 torch-scatter==2.1.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 torch-sparse==0.6.18
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
 tqdm==4.66.5
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pyro-ppl
     #   torch-geometric
-typeguard==2.13.3
+typeguard==4.3.0
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   jaxtyping
-    #   linear-operator
 typing-extensions==4.12.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   botorch
+    #   jaxtyping
+    #   multidict
     #   reactivex
+    #   rich
     #   tables
     #   torch
-tzdata==2024.1
+    #   typeguard
+tzdata==2024.2
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
     #   sentry-sdk
-varname==0.10.0
+varname==0.13.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/recursiongfn/requirements.in
-wandb==0.17.7
+wandb==0.18.3
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
-    #   gflownet
-werkzeug==3.0.3
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/recursiongfn/requirements.in
+werkzeug==3.0.4
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
-yarl==1.9.4
+yarl==1.13.1
     # via
-    #   -c .pin/../.pin/constraints-rocm-gnn.txt
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   aiohttp
 
 # The following packages are considered to be unsafe in a requirements file:
diff --git a/benchmarks/rlhf/main.py b/benchmarks/rlhf/main.py
index 0be12d282..3a5f1ddab 100755
--- a/benchmarks/rlhf/main.py
+++ b/benchmarks/rlhf/main.py
@@ -2,6 +2,7 @@
 
 import shutil
 
+import accelerate
 from accelerate import PartialState
 from datasets import load_dataset
 from transformers import (
@@ -15,10 +16,16 @@
 from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
 from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
 
+import torchcompat.core as compat
+
 
 class PPOv2TrainerIntrumented(PPOv2Trainer):
     def __init__(self, config: PPOv2Config, *args, **kwargs):
         config.report_to = []
+
+        # FIXME: is there a better way to monkeypatch this?
+        # Swap in torchcompat's backend-aware Accelerator class.
+        accelerate.Accelerator = compat.accelerate.Accelerator
         super().__init__(config, *args, **kwargs)
 
         def batch_size_fn(batch):
@@ -46,9 +53,13 @@ def save_model(self, *args, **kwargs):
 
 
 def main():

     parser = HfArgumentParser((PPOv2Config, ModelConfig))
     config, model_config = parser.parse_args_into_dataclasses()
+
+    # Likely redundant with the module-level import of torchcompat.core above.
+    import torchcompat.core
+
     # remove output_dir if exists
     shutil.rmtree(config.output_dir, ignore_errors=True)
 
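Note on the main.py change above: `accelerate.Accelerator` is rebound before
`super().__init__()` runs because trl's `PPOv2Trainer` builds its own
`Accelerator` during initialization, so the patch has to land first. The
caveat the FIXME likely alludes to: rebinding a module attribute only affects
call sites that resolve the class through the module object; names bound
earlier via `from accelerate import Accelerator` keep the original. A small
self-contained sketch of the pattern — `CompatAccelerator` is a hypothetical
stand-in for `compat.accelerate.Accelerator`:

    import accelerate
    from accelerate import Accelerator  # bound early, like an importer would

    class CompatAccelerator(accelerate.Accelerator):
        """Hypothetical stand-in for torchcompat's backend-aware class."""

    accelerate.Accelerator = CompatAccelerator

    # Lookups through the module object now yield the replacement...
    assert accelerate.Accelerator is CompatAccelerator
    # ...but the name bound before the patch still points at the original.
    assert Accelerator is not CompatAccelerator
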
diff --git a/benchmarks/rlhf/requirements.cuda.txt b/benchmarks/rlhf/requirements.cuda.txt
index 12a24c6c4..dee2ae27c 100644
--- a/benchmarks/rlhf/requirements.cuda.txt
+++ b/benchmarks/rlhf/requirements.cuda.txt
@@ -15,11 +15,11 @@ accelerate==0.34.2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/rlhf/requirements.in
     #   trl
-aiohappyeyeballs==2.4.0
+aiohappyeyeballs==2.4.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
-aiohttp==3.10.5
+aiohttp==3.10.8
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -56,7 +56,7 @@ codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
-datasets==3.0.0
+datasets==3.0.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/rlhf/requirements.in
@@ -74,7 +74,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -98,7 +98,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   accelerate
@@ -110,19 +110,19 @@ idna==3.10
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
     #   yarl
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -172,7 +172,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   pandas
     #   pyarrow
     #   scipy
@@ -191,7 +190,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -238,7 +237,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -252,7 +251,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -267,7 +266,7 @@ packaging==24.1
     #   datasets
     #   huggingface-hub
     #   transformers
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
@@ -318,7 +317,7 @@ requests==2.32.3
     #   datasets
     #   huggingface-hub
     #   transformers
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tyro
@@ -342,7 +341,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -366,6 +365,7 @@ tqdm==4.66.5
 transformers==4.44.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/rlhf/requirements.in
     #   trl
 triton==3.0.0
@@ -375,6 +375,7 @@ triton==3.0.0
 trl==0.10.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
+    #   -c .pin/../constraints/cuda.txt
     #   -r benchmarks/rlhf/requirements.in
 typing-extensions==4.12.2
     # via
@@ -382,13 +383,14 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   multidict
     #   reactivex
+    #   rich
     #   torch
     #   tyro
-tyro==0.8.10
+tyro==0.8.11
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   trl
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
@@ -413,7 +415,7 @@ xxhash==3.5.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   datasets
-yarl==1.11.1
+yarl==1.13.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   aiohttp
diff --git a/benchmarks/rlhf/requirements.hpu.txt b/benchmarks/rlhf/requirements.hpu.txt
new file mode 100644
index 000000000..a6c127653
--- /dev/null
+++ b/benchmarks/rlhf/requirements.hpu.txt
@@ -0,0 +1,362 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/rlhf/requirements.hpu.txt .pin/tmp-constraints-hpu-rlhf-gpus.txt benchmarks/rlhf/requirements.in
+#
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   multiprocess
+docstring-parser==0.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   pandas
+    #   pyarrow
+    #   transformers
+    #   trl
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   accelerate
+    #   transformers
+shtab==1.7.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   transformers
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   accelerate
+    #   trl
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+trl==0.10.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/rlhf/requirements.in
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   trl
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/rlhf/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   aiohttp
diff --git a/benchmarks/rlhf/requirements.in b/benchmarks/rlhf/requirements.in
index 045bca09c..1cb6cd247 100644
--- a/benchmarks/rlhf/requirements.in
+++ b/benchmarks/rlhf/requirements.in
@@ -4,3 +4,4 @@ trl
 accelerate
 transformers
 datasets
+einops
\ No newline at end of file
diff --git a/benchmarks/rlhf/requirements.rocm.txt b/benchmarks/rlhf/requirements.rocm.txt
new file mode 100644
index 000000000..5b7f2726b
--- /dev/null
+++ b/benchmarks/rlhf/requirements.rocm.txt
@@ -0,0 +1,313 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/rlhf/requirements.rocm.txt .pin/tmp-constraints-rocm-rlhf-gpus.txt benchmarks/rlhf/requirements.in
+#
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
+
+accelerate==0.34.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+aiohappyeyeballs==2.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+aiohttp==3.10.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+async-timeout==4.0.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+attrs==24.2.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+datasets==3.0.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+dill==0.3.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   multiprocess
+docstring-parser==0.16
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tyro
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   pytorch-triton-rocm
+    #   torch
+    #   transformers
+frozenlist==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   sympy
+multidict==6.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   pandas
+    #   pyarrow
+    #   transformers
+    #   trl
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+pyarrow==17.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pytorch-triton-rocm==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+regex==2024.9.11
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   transformers
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tyro
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   accelerate
+    #   transformers
+shtab==1.7.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   tyro
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+tokenizers==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   transformers
+torch==2.4.1+rocm6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   accelerate
+    #   trl
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+trl==0.10.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/rlhf/requirements.in
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   rich
+    #   torch
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   trl
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/rlhf/requirements.in
+xxhash==3.5.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   datasets
+yarl==1.13.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   aiohttp
+einops
\ No newline at end of file
diff --git a/benchmarks/timm/requirements.cuda.txt b/benchmarks/timm/requirements.cuda.txt
index 4554f91ec..b55428950 100644
--- a/benchmarks/timm/requirements.cuda.txt
+++ b/benchmarks/timm/requirements.cuda.txt
@@ -34,7 +34,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
@@ -50,7 +50,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/timm/requirements.in
@@ -58,19 +58,19 @@ idna==3.10
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -109,7 +109,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchvision
     #   xformers
@@ -125,7 +124,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -172,7 +171,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -186,7 +185,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -228,7 +227,7 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -245,7 +244,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -272,6 +271,7 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
 urllib3==2.2.3
     # via
diff --git a/benchmarks/timm/requirements.hpu.txt b/benchmarks/timm/requirements.hpu.txt
index 432c91bc4..e626bd1f0 100644
--- a/benchmarks/timm/requirements.hpu.txt
+++ b/benchmarks/timm/requirements.hpu.txt
@@ -4,10 +4,6 @@
 #
 #    pip-compile --output-file=benchmarks/timm/requirements.hpu.txt .pin/tmp-constraints-hpu-timm.txt benchmarks/timm/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 antlr4-python3-runtime==4.9.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -16,7 +12,7 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-certifi==2024.6.2
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -24,35 +20,35 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   torch
     #   triton
-fsspec==2024.5.0
+fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.0
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/timm/requirements.in
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
@@ -102,7 +98,7 @@ nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -123,11 +119,15 @@ nvidia-cusparse-cu12==12.1.0.106
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
@@ -140,7 +140,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -164,11 +164,7 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/timm/requirements.in
@@ -182,11 +178,11 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-safetensors==0.4.3
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/timm/requirements.in
@@ -194,24 +190,24 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/timm/requirements.in
     #   torchvision
-torchvision==0.18.1
+torchvision==0.19.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/timm/requirements.in
-tqdm==4.66.4
+tqdm==4.66.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -220,12 +216,13 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
-urllib3==1.26.19
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
diff --git a/benchmarks/timm/requirements.rocm.txt b/benchmarks/timm/requirements.rocm.txt
index 8383f9e6b..18e83d953 100644
--- a/benchmarks/timm/requirements.rocm.txt
+++ b/benchmarks/timm/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/timm/requirements.rocm.txt .pin/tmp-constraints-rocm-timm.txt benchmarks/timm/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -14,7 +14,7 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-certifi==2024.7.4
+certifi==2024.8.30
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -22,15 +22,15 @@ charset-normalizer==3.3.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
@@ -41,16 +41,16 @@ fsspec==2024.6.1
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/timm/requirements.in
-idna==3.7
+idna==3.10
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
@@ -82,11 +82,15 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchvision
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -110,10 +114,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -132,11 +132,11 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-safetensors==0.4.4
+safetensors==0.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/timm/requirements.in
@@ -144,16 +144,16 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/timm/requirements.in
     #   torchvision
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/timm/requirements.in
@@ -166,12 +166,13 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   torch
-urllib3==2.2.2
+urllib3==2.2.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   requests
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
diff --git a/benchmarks/torchatari/requirements.cuda.txt b/benchmarks/torchatari/requirements.cuda.txt
index 2b0aa99d6..1be36a969 100644
--- a/benchmarks/torchatari/requirements.cuda.txt
+++ b/benchmarks/torchatari/requirements.cuda.txt
@@ -64,7 +64,7 @@ farama-notifications==0.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   gymnasium
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -78,7 +78,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-grpcio==1.66.1
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tensorboard
@@ -100,19 +100,19 @@ importlib-resources==6.4.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   cantilever
     #   torchcompat
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -161,7 +161,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   tensorboard
     #   xformers
@@ -177,7 +176,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -224,7 +223,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -238,7 +237,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -255,7 +254,7 @@ packaging==24.1
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   envpool
     #   tensorboard
-protobuf==5.28.1
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tensorboard
@@ -279,7 +278,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   tyro
@@ -298,11 +297,11 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   tensorboard
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
-tensorboard==2.17.1
+tensorboard==2.18.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/torchatari/requirements.in
@@ -324,7 +323,7 @@ triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
-types-protobuf==5.27.0.20240907
+types-protobuf==5.28.0.20240924
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   envpool
@@ -335,9 +334,10 @@ typing-extensions==4.12.2
     #   gymnasium
     #   optree
     #   reactivex
+    #   rich
     #   torch
     #   tyro
-tyro==0.8.10
+tyro==0.8.11
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/torchatari/requirements.in
diff --git a/benchmarks/torchatari/requirements.hpu.txt b/benchmarks/torchatari/requirements.hpu.txt
new file mode 100644
index 000000000..6d7369dfc
--- /dev/null
+++ b/benchmarks/torchatari/requirements.hpu.txt
@@ -0,0 +1,304 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/torchatari/requirements.hpu.txt .pin/tmp-constraints-hpu-torchatari.txt benchmarks/torchatari/requirements.in
+#
+absl-py==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   dm-env
+    #   tensorboard
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+appdirs==1.4.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   cantilever
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+cantilever==0.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gym
+    #   gymnasium
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+dm-env==1.6
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+dm-tree==0.1.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   dm-env
+docstring-parser==0.16
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+envpool==0.8.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+farama-notifications==0.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gymnasium
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+grpcio==1.66.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+gym==0.26.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+    #   envpool
+gym-notices==0.0.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   gym
+gymnasium==0.29.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   cantilever
+    #   torchcompat
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown==3.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+    #   werkzeug
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+    #   dm-env
+    #   envpool
+    #   gym
+    #   gymnasium
+    #   tensorboard
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+optree==0.13.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+    #   tensorboard
+protobuf==5.28.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+    #   voir
+shtab==1.7.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tyro
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   tensorboard
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+tensorboard==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+tensorboard-data-server==0.7.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+torchcompat==1.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/torchatari/requirements.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+types-protobuf==5.28.0.20240924
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   envpool
+    #   gymnasium
+    #   optree
+    #   reactivex
+    #   rich
+    #   torch
+    #   tyro
+tyro==0.8.11
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchatari/requirements.in
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/torchatari/requirements.in
+werkzeug==3.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   tensorboard
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/benchmarks/torchatari/requirements.rocm.txt b/benchmarks/torchatari/requirements.rocm.txt
index 71fd92e51..76fa829c5 100644
--- a/benchmarks/torchatari/requirements.rocm.txt
+++ b/benchmarks/torchatari/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/torchatari/requirements.rocm.txt .pin/tmp-constraints-rocm-torchatari.txt benchmarks/torchatari/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 absl-py==2.1.0
     # via
@@ -32,7 +32,7 @@ cloudpickle==3.0.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gym
     #   gymnasium
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
@@ -52,7 +52,7 @@ envpool==0.8.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchatari/requirements.in
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
@@ -60,7 +60,7 @@ farama-notifications==0.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   gymnasium
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -69,16 +69,16 @@ fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-grpcio==1.65.5
+grpcio==1.66.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
-gym==0.23.1
+gym==0.26.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchatari/requirements.in
@@ -91,7 +91,7 @@ gymnasium==0.29.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   envpool
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   cantilever
@@ -134,15 +134,19 @@ numpy==1.26.4
     #   gym
     #   gymnasium
     #   tensorboard
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-optree==0.12.1
+optree==0.13.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   envpool
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -151,7 +155,7 @@ packaging==24.1
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   envpool
     #   tensorboard
-protobuf==5.27.3
+protobuf==5.28.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
@@ -167,10 +171,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -183,7 +183,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tyro
@@ -197,11 +197,11 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   tensorboard
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-tensorboard==2.17.1
+tensorboard==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchatari/requirements.in
@@ -209,7 +209,7 @@ tensorboard-data-server==0.7.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchatari/requirements.in
@@ -218,7 +218,7 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/torchatari/requirements.in
-types-protobuf==5.27.0.20240626
+types-protobuf==5.28.0.20240924
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   envpool
@@ -229,22 +229,23 @@ typing-extensions==4.12.2
     #   gymnasium
     #   optree
     #   reactivex
+    #   rich
     #   torch
     #   tyro
-tyro==0.8.8
+tyro==0.8.11
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchatari/requirements.in
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-voir==0.2.17
+voir==0.2.19
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/torchatari/requirements.in
-werkzeug==3.0.3
+werkzeug==3.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   tensorboard
diff --git a/benchmarks/torchvision/requirements.cuda.txt b/benchmarks/torchvision/requirements.cuda.txt
index 6b1a837f0..108cc0e69 100644
--- a/benchmarks/torchvision/requirements.cuda.txt
+++ b/benchmarks/torchvision/requirements.cuda.txt
@@ -26,7 +26,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -44,19 +44,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchcompat
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -95,7 +95,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchvision
     #   xformers
@@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -172,7 +171,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -204,7 +203,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -217,7 +216,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -248,6 +247,7 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   reactivex
+    #   rich
     #   torch
 varname==0.13.3
     # via
diff --git a/benchmarks/torchvision/requirements.hpu.txt b/benchmarks/torchvision/requirements.hpu.txt
index 369a1753e..f0b47e914 100644
--- a/benchmarks/torchvision/requirements.hpu.txt
+++ b/benchmarks/torchvision/requirements.hpu.txt
@@ -4,10 +4,6 @@
 #
 #    pip-compile --output-file=benchmarks/torchvision/requirements.hpu.txt .pin/tmp-constraints-hpu-torchvision.txt benchmarks/torchvision/requirements.in
 #
---extra-index-url https://pypi.ngc.nvidia.com
---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
---trusted-host pypi.ngc.nvidia.com
-
 antlr4-python3-runtime==4.9.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
@@ -16,29 +12,29 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
     #   triton
-fsspec==2024.5.0
+fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   ptera
     #   voir
-importlib-resources==6.4.0
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torchcompat
@@ -88,7 +84,7 @@ nvidia-cuda-runtime-cu12==12.1.105
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -109,11 +105,15 @@ nvidia-cusparse-cu12==12.1.0.106
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
     #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
 nvidia-nccl-cu12==2.20.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-nvidia-nvjitlink-cu12==12.5.82
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   nvidia-cusolver-cu12
@@ -126,7 +126,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
-ovld==0.3.5
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -146,11 +146,7 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-hpu-torch.txt
-    #   voir
-pyyaml==6.0.1
+pyyaml==6.0.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   omegaconf
@@ -158,7 +154,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   voir
@@ -166,11 +162,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   asttokens
-sympy==1.13.0
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
-torch==2.3.1
+torch==2.4.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/torchvision/requirements.in
@@ -180,15 +176,15 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -c .pin/../constraints/hpu.txt
     #   -r benchmarks/torchvision/requirements.in
-torchvision==0.18.1
+torchvision==0.19.1
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/torchvision/requirements.in
-tqdm==4.66.4
+tqdm==4.66.5
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   -r benchmarks/torchvision/requirements.in
-triton==2.3.1
+triton==3.0.0
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   torch
@@ -196,8 +192,9 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-hpu-torch.txt
     #   giving
diff --git a/benchmarks/torchvision/requirements.rocm.txt b/benchmarks/torchvision/requirements.rocm.txt
index 094eb29b6..08dfdebf1 100644
--- a/benchmarks/torchvision/requirements.rocm.txt
+++ b/benchmarks/torchvision/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/torchvision/requirements.rocm.txt .pin/tmp-constraints-rocm-torchvision.txt benchmarks/torchvision/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -14,15 +14,15 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -31,12 +31,12 @@ fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchcompat
@@ -68,11 +68,15 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchvision
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -92,10 +96,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -108,7 +108,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -116,11 +116,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchvision/requirements.in
@@ -130,7 +130,7 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/torchvision/requirements.in
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchvision/requirements.in
@@ -142,8 +142,9 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
diff --git a/benchmarks/torchvision/voirfile.py b/benchmarks/torchvision/voirfile.py
index ed3f0af7c..a05c99774 100644
--- a/benchmarks/torchvision/voirfile.py
+++ b/benchmarks/torchvision/voirfile.py
@@ -24,7 +24,7 @@ class Config:
     stop: int = 20
 
     # Number of seconds between each gpu poll
-    gpu_poll: int = 3
+    gpu_poll: float = 1.0
 
 
 @configurable
diff --git a/benchmarks/torchvision_ddp/requirements.cuda.txt b/benchmarks/torchvision_ddp/requirements.cuda.txt
index 28c6198b2..8572482df 100644
--- a/benchmarks/torchvision_ddp/requirements.cuda.txt
+++ b/benchmarks/torchvision_ddp/requirements.cuda.txt
@@ -26,7 +26,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -44,19 +44,19 @@ importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torchcompat
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -95,7 +95,6 @@ numpy==1.26.4
     #   jax
     #   jaxlib
     #   ml-dtypes
-    #   opt-einsum
     #   scipy
     #   torchvision
     #   xformers
@@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -172,7 +171,7 @@ omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -204,7 +203,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -217,7 +216,7 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -248,6 +247,7 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   reactivex
+    #   rich
     #   torch
 varname==0.13.3
     # via
diff --git a/benchmarks/torchvision_ddp/requirements.hpu.txt b/benchmarks/torchvision_ddp/requirements.hpu.txt
index e69de29bb..a4174e7bc 100644
--- a/benchmarks/torchvision_ddp/requirements.hpu.txt
+++ b/benchmarks/torchvision_ddp/requirements.hpu.txt
@@ -0,0 +1,205 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/torchvision_ddp/requirements.hpu.txt .pin/tmp-constraints-hpu-torchvision.txt benchmarks/torchvision_ddp/requirements.in
+#
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+    #   triton
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+importlib-resources==6.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchcompat
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchvision
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchvision
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchvision_ddp/requirements.in
+    #   torchvision
+torchcompat==1.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/torchvision_ddp/requirements.in
+torchvision==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchvision_ddp/requirements.in
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/torchvision_ddp/requirements.in
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   reactivex
+    #   rich
+    #   torch
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/torchvision_ddp/requirements.in
diff --git a/benchmarks/torchvision_ddp/requirements.rocm.txt b/benchmarks/torchvision_ddp/requirements.rocm.txt
index d1241db8b..9eed94421 100644
--- a/benchmarks/torchvision_ddp/requirements.rocm.txt
+++ b/benchmarks/torchvision_ddp/requirements.rocm.txt
@@ -4,7 +4,7 @@
 #
 #    pip-compile --output-file=benchmarks/torchvision_ddp/requirements.rocm.txt .pin/tmp-constraints-rocm-torchvision.txt benchmarks/torchvision_ddp/requirements.in
 #
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 antlr4-python3-runtime==4.9.3
     # via
@@ -14,15 +14,15 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-codefind==0.1.6
+codefind==0.1.7
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
-executing==1.2.0
+executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   varname
-filelock==3.15.4
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pytorch-triton-rocm
@@ -31,12 +31,12 @@ fsspec==2024.6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-giving==0.4.2
+giving==0.4.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   ptera
     #   voir
-importlib-resources==6.4.3
+importlib-resources==6.4.5
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchcompat
@@ -68,11 +68,15 @@ numpy==1.26.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchvision
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
 omegaconf==2.3.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
-ovld==0.3.8
+ovld==0.3.9
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -92,10 +96,6 @@ pygments==2.18.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   rich
-pynvml==11.5.3
-    # via
-    #   -c .pin/../.pin/constraints-rocm-torch.txt
-    #   voir
 pytorch-triton-rocm==3.0.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
@@ -108,7 +108,7 @@ reactivex==4.0.4
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
-rich==13.7.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   voir
@@ -116,11 +116,11 @@ six==1.16.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
-torch==2.4.0+rocm6.0
+torch==2.4.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchvision_ddp/requirements.in
@@ -130,7 +130,7 @@ torchcompat==1.1.4
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -c .pin/../constraints/rocm.txt
     #   -r benchmarks/torchvision_ddp/requirements.in
-torchvision==0.19.0+rocm6.0
+torchvision==0.19.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/torchvision_ddp/requirements.in
@@ -142,8 +142,9 @@ typing-extensions==4.12.2
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   reactivex
+    #   rich
     #   torch
-varname==0.10.0
+varname==0.13.3
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   giving
diff --git a/benchmarks/vjepa/benchfile.py b/benchmarks/vjepa/benchfile.py
index d25b47b53..228023ced 100644
--- a/benchmarks/vjepa/benchfile.py
+++ b/benchmarks/vjepa/benchfile.py
@@ -23,7 +23,10 @@ class Vjepa(Package):
     def make_env(self):
         # Return a dict of environment variables for prepare_script and
         # main_script.
-        return super().make_env()
+        env = super().make_env()
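+        # Assumption: PT_HPU_LAZY_MODE=0 disables HPU lazy (graph) mode so Gaudi executes ops eagerly.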
+        env["PT_HPU_LAZY_MODE"] = "0"
+        return env
 
     async def install(self):
         vjepa = self.dirs.code / "jepa"
diff --git a/benchmarks/vjepa/main.py b/benchmarks/vjepa/main.py
index 74ca606f7..55981859c 100644
--- a/benchmarks/vjepa/main.py
+++ b/benchmarks/vjepa/main.py
@@ -475,14 +475,20 @@ def reg_fn(z):
                     scaler.unscale_(optimizer)
                 else:
                     loss.backward()
+
                 if (epoch > warmup) and (clip_grad is not None):
                     _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad)
                     _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad)
+
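+                # flush pending lazy-mode ops before the optimizer step
+                # (assumed to be a no-op on backends without lazy execution)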
+                acc.mark_step()
                 if mixed_precision:
                     scaler.step(optimizer)
                     scaler.update()
                 else:
                     optimizer.step()
+                acc.mark_step()
+
                 grad_stats = grad_logger(encoder.named_parameters())
                 grad_stats.global_norm = float(_enc_norm)
                 grad_stats_pred = grad_logger(predictor.named_parameters())
@@ -506,7 +511,8 @@ def reg_fn(z):
                     grad_stats_pred,
                     optim_stats,
                 )
-            (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step)
+            loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats = train_step()
+            
             iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.
             loss_meter.update(loss)
             input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0)))
@@ -515,7 +521,7 @@ def reg_fn(z):
             input_var_min_meter.update(input_var_min)
             jepa_loss_meter.update(loss_jepa)
             reg_loss_meter.update(loss_reg)
-            gpu_time_meter.update(gpu_etime_ms)
+            # gpu_time_meter.update(gpu_etime_ms)
             wall_time_meter.update(iter_elapsed_time_ms)
 
             observer.record_loss(loss)
@@ -530,7 +536,6 @@ def log_stats():
                     loss_reg,
                     grad_stats.global_norm,
                     grad_stats_pred.global_norm,
-                    gpu_etime_ms,
                     iter_elapsed_time_ms)
                 if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                     logger.info(
@@ -637,7 +642,11 @@ def main():
     params["nodes"] = nnodes
     params["tasks_per_node"] = gpu_per_nodes
 
+    print("HERE", os.getenv("RANK", -1) )
     if os.getenv("RANK", -1) != -1:
+        print("INIT PROCESS GROUP HERE")
+        print(acc)
+        print(acc.init_process_group)
         acc.init_process_group()
 
     try:
@@ -650,7 +659,7 @@ def main():
         if os.getenv("RANK", -1) != -1:
             acc.destroy_process_group()
     
-    sys.exit(0)
+    # sys.exit(0)
 
 if __name__ == "__main__":
     main()
diff --git a/benchmarks/vjepa/requirements.cuda.txt b/benchmarks/vjepa/requirements.cuda.txt
index c6e6ebb0e..867c50b53 100644
--- a/benchmarks/vjepa/requirements.cuda.txt
+++ b/benchmarks/vjepa/requirements.cuda.txt
@@ -18,7 +18,7 @@ asttokens==2.4.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   giving
-beartype==0.18.5
+beartype==0.19.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/vjepa/requirements.in
@@ -55,7 +55,7 @@ executing==2.1.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   varname
-filelock==3.16.0
+filelock==3.16.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
@@ -71,7 +71,7 @@ giving==0.4.3
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   ptera
     #   voir
-huggingface-hub==0.24.7
+huggingface-hub==0.25.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   timm
@@ -79,19 +79,19 @@ idna==3.10
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   requests
-jax[cuda12]==0.4.31
+jax[cuda12]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r .pin/../constraints/extra/torch.cuda.txt
-jax-cuda12-pjrt==0.4.31
+jax-cuda12-pjrt==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
-jax-cuda12-plugin[with-cuda]==0.4.31
+jax-cuda12-plugin[with-cuda]==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
-jaxlib==0.4.31
+jaxlib==0.4.33
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -133,7 +133,6 @@ numpy==1.26.4
     #   jaxlib
     #   ml-dtypes
     #   opencv-python
-    #   opt-einsum
     #   pandas
     #   scipy
     #   torchvision
@@ -151,7 +150,7 @@ nvidia-cuda-cupti-cu12==12.1.105
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-cuda-nvcc-cu12==12.6.68
+nvidia-cuda-nvcc-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -198,7 +197,7 @@ nvidia-nccl-cu12==2.20.5
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
     #   torch
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax-cuda12-plugin
@@ -216,7 +215,7 @@ opencv-python==4.10.0.84
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/vjepa/requirements.in
-opt-einsum==3.3.0
+opt-einsum==3.4.0
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   jax
@@ -228,7 +227,7 @@ packaging==24.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
-pandas==2.2.2
+pandas==2.2.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/vjepa/requirements.in
@@ -272,7 +271,7 @@ requests==2.32.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
-rich==13.8.1
+rich==13.9.1
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   voir
@@ -290,11 +289,11 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   asttokens
     #   python-dateutil
-submitit==1.5.1
+submitit==1.5.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   -r benchmarks/vjepa/requirements.in
-sympy==1.13.2
+sympy==1.13.3
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   torch
@@ -327,9 +326,10 @@ typing-extensions==4.12.2
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   huggingface-hub
     #   reactivex
+    #   rich
     #   submitit
     #   torch
-tzdata==2024.1
+tzdata==2024.2
     # via
     #   -c .pin/../.pin/constraints-cuda-torch.txt
     #   pandas
diff --git a/benchmarks/vjepa/requirements.hpu.txt b/benchmarks/vjepa/requirements.hpu.txt
new file mode 100644
index 000000000..b1c986ecb
--- /dev/null
+++ b/benchmarks/vjepa/requirements.hpu.txt
@@ -0,0 +1,297 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/vjepa/requirements.hpu.txt .pin/tmp-constraints-hpu-vjepa-gpus.txt benchmarks/vjepa/requirements.in
+#
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+beartype==0.19.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+braceexpand==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   webdataset
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   submitit
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+decord==0.6.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+einops==0.8.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   torch
+    #   triton
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   timm
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   decord
+    #   opencv-python
+    #   pandas
+    #   torchvision
+    #   webdataset
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+nvidia-nvjitlink-cu12==12.6.77
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+opencv-python==4.10.0.84
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torchvision
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   huggingface-hub
+    #   omegaconf
+    #   timm
+    #   webdataset
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   timm
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   asttokens
+    #   python-dateutil
+submitit==1.5.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+timm==1.0.9
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+torch==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
+    #   torchvision
+torchvision==0.19.1
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+triton==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   torch
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   huggingface-hub
+    #   reactivex
+    #   rich
+    #   submitit
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -c .pin/../constraints/hpu.txt
+    #   -r benchmarks/vjepa/requirements.in
+webdataset==0.2.100
+    # via
+    #   -c .pin/../.pin/constraints-hpu-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
diff --git a/benchmarks/vjepa/requirements.rocm.txt b/benchmarks/vjepa/requirements.rocm.txt
new file mode 100644
index 000000000..a473fac77
--- /dev/null
+++ b/benchmarks/vjepa/requirements.rocm.txt
@@ -0,0 +1,247 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/vjepa/requirements.rocm.txt .pin/tmp-constraints-rocm-vjepa-gpus.txt benchmarks/vjepa/requirements.in
+#
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
+
+antlr4-python3-runtime==4.9.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   omegaconf
+asttokens==2.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+beartype==0.19.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+braceexpand==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   webdataset
+certifi==2024.8.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+charset-normalizer==3.3.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+cloudpickle==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   submitit
+codefind==0.1.7
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+decord==0.6.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+einops==0.8.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+executing==2.1.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   varname
+filelock==3.16.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+    #   pytorch-triton-rocm
+    #   torch
+fsspec==2024.6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+    #   torch
+giving==0.4.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   ptera
+    #   voir
+huggingface-hub==0.25.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   timm
+idna==3.10
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+jinja2==3.1.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+markdown-it-py==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+markupsafe==2.1.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   jinja2
+mdurl==0.1.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   markdown-it-py
+mpmath==1.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   sympy
+networkx==3.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+numpy==1.26.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   decord
+    #   opencv-python
+    #   pandas
+    #   torchvision
+    #   webdataset
+nvidia-ml-py==12.560.30
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+omegaconf==2.3.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+opencv-python==4.10.0.84
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+ovld==0.3.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+packaging==24.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+pandas==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+pillow==10.4.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torchvision
+psutil==5.9.8
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+ptera==1.4.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+pygments==2.18.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   rich
+python-dateutil==2.9.0.post0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pytorch-triton-rocm==3.0.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+pytz==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+pyyaml==6.0.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   huggingface-hub
+    #   omegaconf
+    #   timm
+    #   webdataset
+reactivex==4.0.4
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+requests==2.32.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+rich==13.9.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   voir
+safetensors==0.4.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   timm
+six==1.16.0
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   asttokens
+    #   python-dateutil
+submitit==1.5.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+sympy==1.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   torch
+timm==1.0.9
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+torch==2.4.1+rocm6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
+    #   torchvision
+torchvision==0.19.1+rocm6.1
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
+    #   timm
+tqdm==4.66.5
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+typing-extensions==4.12.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   huggingface-hub
+    #   reactivex
+    #   rich
+    #   submitit
+    #   torch
+tzdata==2024.2
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   pandas
+urllib3==2.2.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   requests
+varname==0.13.3
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   giving
+voir==0.2.19
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -c .pin/../constraints/rocm.txt
+    #   -r benchmarks/vjepa/requirements.in
+webdataset==0.2.100
+    # via
+    #   -c .pin/../.pin/constraints-rocm-torch.txt
+    #   -r benchmarks/vjepa/requirements.in
diff --git a/benchmate/benchmate/jaxmem.py b/benchmate/benchmate/jaxmem.py
new file mode 100644
index 000000000..1ead3bff5
--- /dev/null
+++ b/benchmate/benchmate/jaxmem.py
@@ -0,0 +1,30 @@
+
+def memory_peak_fetcher():
+    import jax
+
+    def fetch_memory_peak():
+        # 'memory', 'memory_stats'
+        devices = jax.devices()
+        max_mem = -1
+        for device in devices:
+            # dqn.D0 [stdout] Device: cuda:0
+            # dqn.D0 [stdout]   num_allocs: 0.0006799697875976562 MiB
+            # dqn.D0 [stdout]   bytes_in_use: 0.915771484375 MiB
+            # dqn.D0 [stdout]   peak_bytes_in_use: 80.41552734375 MiB
+            # dqn.D0 [stdout]   largest_alloc_size: 16.07958984375 MiB
+            # dqn.D0 [stdout]   bytes_limit: 60832.359375 MiB
+            # dqn.D0 [stdout]   bytes_reserved: 0.0 MiB
+            # dqn.D0 [stdout]   peak_bytes_reserved: 0.0 MiB
+            # dqn.D0 [stdout]   largest_free_block_bytes: 0.0 MiB
+            # dqn.D0 [stdout]   pool_bytes: 60832.359375 MiB
+            # dqn.D0 [stdout]   peak_pool_bytes: 60832.359375 MiB
+
+            # device_name = str(device)
+            mem = device.memory_stats().get("peak_bytes_in_use", 0) / (1024 ** 2)
+            max_mem = max(mem, max_mem)
+
+        return max_mem
+
+    return fetch_memory_peak
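+
+
+# A minimal usage sketch (assumes a JAX backend whose device.memory_stats()
+# reports "peak_bytes_in_use", e.g. the GPU allocator):
+#
+#     fetcher = memory_peak_fetcher()
+#     ...run some JAX computation...
+#     print(f"peak device memory: {fetcher():.1f} MiB")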
diff --git a/benchmate/benchmate/monitor.py b/benchmate/benchmate/monitor.py
index 5d2624201..ee8d19596 100644
--- a/benchmate/benchmate/monitor.py
+++ b/benchmate/benchmate/monitor.py
@@ -15,9 +15,32 @@
 from voir.instruments.monitor import monitor
 
 
+from .metrics import sumggle_push, give_push, file_push
+
+
+def auto_push():
+    # use_stdout = int(os.getenv("MILABENCH_USE_STDOUT", 0))
+    mb_managed = int(os.getenv("MILABENCH_MANAGED", 0))
+
+    # Milabench managed: we need to push metrics to it
+    if mb_managed == 1:
+        # Using voir, DATA_FD is defined as well
+        ov = current_overseer.get()
+        if ov is not None:
+            return ov.give
+
+        # Not using Voir, using structured stdout
+        if int(os.getenv("MILABENCH_USE_STDOUT", 0)) == 1:
+            return sumggle_push()
+
+        raise RuntimeError("Could not find something to push to")
+
+    # Not managed by milabench: fall back to pushing metrics to a file
+    return file_push()
+
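+# A minimal sketch of the intended call site (hypothetical metric names):
+#
+#     push = auto_push()   # returns a callable that accepts metric dicts
+#     push({"task": "train", "rate": items_per_sec, "units": "items/s"})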
 
 @instrument_definition
-def monitor_monogpu(ov, poll_interval=10, arch=None):
+def monitor_monogpu(ov, poll_interval=1, arch=None):
     return monitor(
         ov,
         poll_interval=poll_interval,
@@ -28,7 +51,7 @@ def monitor_monogpu(ov, poll_interval=10, arch=None):
 
 
 @instrument_definition
-def monitor_node(ov, poll_interval=10, arch=None):
+def monitor_node(ov, poll_interval=1, arch=None):
     return monitor(
         ov,
         poll_interval=poll_interval,
@@ -41,6 +64,7 @@ def monitor_node(ov, poll_interval=10, arch=None):
 
 
 def _smuggle_monitor(poll_interval=10, worker_init=None, **monitors):
+    # TODO: route these metrics through auto_push() instead of using SmuggleWriter directly
     data_file = SmuggleWriter(sys.stdout)
     def mblog(data):
         nonlocal data_file
@@ -49,7 +73,8 @@ def mblog(data):
             try:
                 print(json.dumps(data), file=data_file)
             except ValueError:
-                print("Is bench ending?, ignoring ValueError")
+                pass
+                # print("Is bench ending?, ignoring ValueError")
     
     def get():
         t = time.time()
diff --git a/config/base.yaml b/config/base.yaml
index 28a72afb7..38dfc4d38 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -27,7 +27,6 @@ _torchvision:
     --loader: pytorch
     --data: "{milabench_data}/FakeImageNet"
 
-
 _torchvision_ddp:
   inherits: _defaults
   definition: ../benchmarks/torchvision_ddp
@@ -67,7 +66,7 @@ llama:
   definition: ../benchmarks/llama
   group: llm
   install_group: torch
-  max_duration: 800
+  max_duration: 3600
   tags:
     - nlp
     - llm
@@ -119,7 +118,6 @@ _timm:
     --dataset: "FakeImageNet"
     --workers: "auto({n_worker}, 8)"
 
-
 _accelerate_opt:
   inherits: _defaults
   tags:
@@ -156,7 +154,6 @@ _accelerate_opt:
   use_deepspeed: true
   num_machines: 1
 
-
 fp16:
   inherits: _flops
 
@@ -209,6 +206,11 @@ resnet50:
   
 resnet50-noio:
   inherits: _torchvision
+  voir:
+    options:
+      stop: 500
+      interval: "1s"
+
   tags:
     - vision
     - classification
@@ -346,7 +348,7 @@ reformer:
     - monogpu
   argv:
     --model: "Reformer"
-    --batch-size: 64
+    --batch-size: 32
 
 whisper:
   inherits: _hf
@@ -372,12 +374,15 @@ focalnet:
     --model: focalnet_base_lrf
 
 brax:
+  # Brax requires very specific sizes to work,
+  # so the resizer cannot handle resizing this bench
   inherits: _defaults
   tags:
     - rl
     - jax
     - multigpu
     - gym
+    - nobatch
   definition: ../benchmarks/brax
   group: brax
   install_group: torch
@@ -390,7 +395,6 @@ brax:
     --num-minibatches: 32
     --num-envs: 8192
 
-
 _diffusion:
   inherits: _defaults
   definition: ../benchmarks/diffusion
@@ -537,19 +541,19 @@ _llm:
   tags:
     - nlp
     - llm
-  max_duration: 1200
+  max_duration: 3600
   num_machines: 1
   inherits: _defaults
   definition: ../benchmarks/llm
   install_group: torch
 
-
 llm-lora-single:
   inherits: _llm
   tags:
     - monogpu
   plan:
     method: per_gpu
+
   argv:
     "{milabench_code}/recipes/lora_finetune_single_device.py": true
     --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml"
@@ -562,6 +566,7 @@ llm-lora-single:
     repo_id="meta-llama/Meta-Llama-3.1-8B": true
     batch_size=8: true
     gradient_accumulation_steps=8: true
+    device={device_name}: true
 
 
 llm-lora-ddp-gpus:
@@ -583,7 +588,7 @@ llm-lora-ddp-gpus:
     repo_id="meta-llama/Meta-Llama-3.1-8B": true
     batch_size=8: true
     gradient_accumulation_steps=8: true
-
+    device={device_name}: true
 
 llm-lora-ddp-nodes:
   tags:
@@ -606,12 +611,11 @@ llm-lora-ddp-nodes:
     repo_id="meta-llama/Meta-Llama-3.1-8B": true
     batch_size=8: true
     gradient_accumulation_steps=8: true
-
+    device={device_name}: true
   num_machines: 2
   requires_capabilities:
     - "len(nodes) >= ${num_machines}"
 
-
 llm-lora-mp-gpus:
   inherits: _llm
   tags:
@@ -633,8 +637,12 @@ llm-lora-mp-gpus:
     repo_id="meta-llama/Meta-Llama-3.1-70B": true
     batch_size=8: true
     gradient_accumulation_steps=1: true
-
+    device={device_name}: true
+
 llm-full-mp-gpus:
+  voir:
+    options:
+      stop: 30
   inherits: _llm
   tags:
     - multigpu
@@ -655,7 +663,8 @@ llm-full-mp-gpus:
     safetensors=true: true
     batch_size=2: true
     gradient_accumulation_steps=1: true
-
+    device={device_name}: true
+
 llm-full-mp-nodes:
   tags:
     - multinode
@@ -678,7 +687,8 @@ llm-full-mp-nodes:
     safetensors=true: true
     batch_size=2: true
     gradient_accumulation_steps=1: true
-
+    device={device_name}: true
+
   num_machines: 2
   requires_capabilities:
     - "len(nodes) >= ${num_machines}"
@@ -690,6 +700,7 @@ _purejaxrl:
     - monogpu
     - gym
     - rl
+    - jax
   definition: ../benchmarks/purejaxrl
   plan:
     method: per_gpu
@@ -699,7 +710,8 @@ dqn:
   argv:
     dqn: true
     --num_envs: auto({cpu_per_gpu}, 128)
-    --buffer_batch_size: 128
+    --buffer_size: 131072
+    --buffer_batch_size: 65536
     --env_name: CartPole-v1
     --training_interval: 10
 
@@ -712,7 +724,7 @@ ppo:
     --num_minibatches: 32
     --update_epochs: 4
     --env_name: hopper
-    --total_timesteps: 200000
+    --total_timesteps: 2000000
 
 _geo_gnn:
   inherits: _defaults
@@ -724,14 +736,22 @@ _geo_gnn:
   plan:
     method: per_gpu
 
+pna:
+  inherits: _geo_gnn
+  argv:
+    --model: 'PNA'
+    --num-samples: 100000
+    --batch-size: 4096
+    --num-workers: "auto({n_worker}, 0)"
+
 dimenet:
   inherits: _geo_gnn
-  tags:
-    - monogpu
   argv:
     --model: 'DimeNet'
-    --num-samples: 10000
+    --num-samples: 100000
     --use3d: True
+    --batch-size: 16
+    --num-workers: "auto({n_worker}, 0)"
 
 recursiongfn:
   inherits: _defaults
@@ -745,7 +765,7 @@ recursiongfn:
 
   argv:
     --batch_size: 128
-    --num_workers: 8
+    --num_workers: "auto({n_worker}, 8)"
     --num_steps: 100
     --layer_width: 128
     --num_layers: 4
@@ -769,6 +789,7 @@ torchatari:
     --env-id: Breakout-v5
 
 _llava:
+  max_duration: 3600
   inherits: _defaults
   definition: ../benchmarks/llava
   install_group: torch
@@ -776,19 +797,20 @@ _llava:
     method: per_gpu
   tags:
     - llm
-    - monogpu
   argv:
     --batch_size: 1
-    --num_workers: 4
+    --num_workers: "auto({n_worker}, 4)"
     --gradient_accumulation_steps: 1
 
 llava-single:
   inherits: _llava
+  tags:
+    - monogpu
   plan:
     method: per_gpu
   argv:
     --batch_size: 1
-    --num_workers: 4
+    --num_workers: "auto({n_worker}, 4)"
     --gradient_accumulation_steps: 1
 
 llava-gpus:
@@ -800,7 +822,7 @@ llava-gpus:
     n: 1
   argv:
     --batch_size: 1
-    --num_workers: 4
+    --num_workers: "auto({n_worker}, 4)"
     --gradient_accumulation_steps: 1
 
 
@@ -811,7 +833,6 @@ _rlhf:
   plan:
     method: per_gpu
   tags:
-    - monogpu
     - rl
     - rlhf
     - llm
@@ -825,6 +846,8 @@ _rlhf:
 
 rlhf-single:
   inherits: _rlhf
+  tags:
+    - monogpu
   plan:
     method: per_gpu
 
@@ -862,3 +885,26 @@ vjepa-gpus:
   plan:
     method: njobs
     n: 1
+
+cleanrljax:
+  inherits: _defaults
+  install_group: torch
+  definition: ../benchmarks/cleanrl_jax
+  tags:
+    - monogpu
+    - jax
+  plan:
+    method: per_gpu
+
+  # args.batch_size     = int(args.num_envs * args.num_steps)
+  # args.minibatch_size = int(args.batch_size // args.num_minibatches)
+  # args.num_iterations = args.total_timesteps // args.batch_size
+  # --total_timesteps
+  # --num_steps
+  # --num_minibatches
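+  #
+  # e.g. assuming num_envs resolves to its default of 128:
+  #   batch_size     = 128 * 128         = 16384
+  #   minibatch_size = 16384 // 4        = 4096
+  #   num_iterations = 10000000 // 16384 = 610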
+
+  argv:
+    --num_envs: auto({cpu_per_gpu}, 128)
+    --num_steps: 128
+    --num_minibatches: 4
+    --total_timesteps: 10000000
\ No newline at end of file
diff --git a/config/examples/system.yaml b/config/examples/system.yaml
index 7b84c48d1..78cf39571 100644
--- a/config/examples/system.yaml
+++ b/config/examples/system.yaml
@@ -26,3 +26,33 @@ system:
       ip: 192.168.11.13
       main: false 
       user: username
+
+
+multirun:
+  runs:
+    # Force batch size to populate the sizing model
+    - name: "bs{sizer.batch_size}"
+      matrix:
+        sizer.auto: 1
+        sizer.batch_size: [1, 2, 4, 8, 16, 32, 64, 128]
+        sizer.save: ["scaling.yaml"]
+
+    # Matrix run
+    - name: "c{sizer.capacity}_m{sizer.multiple}_w{cpu.n_workers}"
+      matrix:
+        cpu.auto: 1
+        cpu.n_workers: [2, 4, 8, 16, 32]
+        sizer.auto: 1
+        sizer.capacity: [4Go, 8Go, 16Go, 32Go, 64Go, All]
+        sizer.multiple: 8
+        sizer.save: ["scaling.yaml"]
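+        # this matrix expands to 5 * 6 = 30 runs
+        # (five n_workers values times six capacity values)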
+
+    # Auto run
+    - name: "auto"
+      matrix:
+        cpu.auto: 1
+        sizer.auto: 1
+        sizer.multiple: 8
+        sizer.save: ["scaling.yaml"]
diff --git a/config/scaling.yaml b/config/scaling.yaml
index 09f3f9ae5..00a37bd8e 100644
--- a/config/scaling.yaml
+++ b/config/scaling.yaml
@@ -55,7 +55,13 @@ bert-tf32-fp16:
     112: 81140.75 MiB
   optimized: 128
 bf16: {}
-brax: {}
+brax:
+  arg: --batch-size
+  model:
+    1024: 4912.25 MiB
+cleanrljax:
+  arg: --num_steps
+  optimized: 128
 convnext_large-fp16:
   arg: --batch-size
   model:
@@ -178,21 +184,59 @@ diffusion-nodes:
     1: 21686.75 MiB
     2: 21930.75 MiB
     4: 23510.75 MiB
+    16: 40054.25 MiB
+    32: 61512.25 MiB
 diffusion-single:
   arg: --batch_size
   model:
     1: 21654.75 MiB
     2: 21818.75 MiB
     4: 23478.75 MiB
-dimenet: {}
+    16: 33850.25 MiB
+    32: 55354.25 MiB
+dimenet:
+  arg: --batch-size
+  model:
+    2: 452.6875 MiB
+    4: 1604.25 MiB
+    24: 4776.25 MiB
+    56: 6330.25 MiB
+    64: 12274.25 MiB
+    112: 15294.25 MiB
+    128: 13002.25 MiB
+    240: 67506.25 MiB
+    280: 56556.25 MiB
+    488: 80406.25 MiB
 dinov2-giant-gpus:
   arg: train.batch_size_per_gpu={batch_size}
   model:
-    32: 69614 MiB
+    1: 32240.25 MiB
+    2: 32252.25 MiB
+    4: 32404.25 MiB
+    16: 38350.25 MiB
+    24: 48856.25 MiB
+    32: 72102.25 MiB
   optimized: 32
+dinov2-giant-nodes:
+  arg: train.batch_size_per_gpu={batch_size}
 dinov2-giant-single:
   arg: train.batch_size_per_gpu={batch_size}
+  model:
+    1: 20682.25 MiB
+    2: 20682.25 MiB
+    4: 20682.25 MiB
+    16: 52748.25 MiB
+    24: 60792.25 MiB
+    32: 74544.25 MiB
 dlrm: {}
+dqn:
+  arg: --buffer_batch_size
+  model:
+    1024: 81.81005859375 MiB
+    2048: 83.40380859375 MiB
+    32768: 131.21630859375 MiB
+    65536: 182.21630859375 MiB
+  optimized: 128
 focalnet:
   arg: --batch-size
   model:
@@ -216,6 +260,20 @@ fp16: {}
 fp32: {}
 lightning:
   arg: --batch-size
+  model:
+    1: 1054.25 MiB
+    2: 1054.25 MiB
+    4: 1856.25 MiB
+    16: 4728.25 MiB
+    24: 5482.25 MiB
+    32: 6352.25 MiB
+    56: 1054.25 MiB
+    64: 1856.25 MiB
+    120: 14522.25 MiB
+    128: 14818.25 MiB
+    240: 25480.25 MiB
+    488: 49042.25 MiB
+    664: 65914.25 MiB
 lightning-gpus:
   arg: --batch-size
   model:
@@ -224,21 +282,61 @@ lightning-gpus:
     4: 1156.75 MiB
     8: 1260.75 MiB
     16: 4150.75 MiB
+    48: 11056.25 MiB
+    112: 16776.25 MiB
     128: 15858 MiB
+    240: 28942.25 MiB
+    256: 77822 MiB
+    504: 54100.25 MiB
+    616: 93571 MiB
+    624: 65386.25 MiB
   optimized: 16
 llama: {}
+llava-gpus:
+  arg: --batch_size
+  optimized: 1
+llava-single:
+  arg: --batch_size
+  model:
+    1: 72614.25 MiB
+    2: 15168.25 MiB
+    4: 72362.25 MiB
+  optimized: 1
 llm-full-mp-gpus:
   arg: batch_size={batch_size}
+  model:
+    1: 48964.25 MiB
+    2: 49214.25 MiB
+    4: 51310.25 MiB
+    16: 81536.25 MiB
 llm-full-mp-nodes:
   arg: batch_size={batch_size}
+  model:
+    1: 37340.25 MiB
+    2: 38112.25 MiB
+    4: 39110.25 MiB
+    16: 80638.25 MiB
 llm-lora-ddp-gpus:
   arg: batch_size={batch_size}
   model:
     1: 12418.75 MiB
+    2: 19026.25 MiB
+    4: 25464.25 MiB
+    16: 55834.25 MiB
+    32: 80268.25 MiB
 llm-lora-ddp-nodes:
   arg: batch_size={batch_size}
+  model:
+    2: 17202.25 MiB
+    4: 23956.25 MiB
+    16: 59730.25 MiB
+    32: 68932.25 MiB
 llm-lora-mp-gpus:
   arg: batch_size={batch_size}
+  model:
+    2: 38166.25 MiB
+    4: 43464.25 MiB
+    16: 77116.25 MiB
 llm-lora-single:
   arg: batch_size={batch_size}
   model:
@@ -262,11 +360,32 @@ opt-6_7b-multinode:
   model:
     1: 55380 MiB
   optimized: 1
+pna:
+  arg: --batch-size
+  model:
+    4096: 39554.25 MiB
+ppo:
+  arg: --num_steps
+  model:
+    8: 80.791748046875 MiB
+    16: 80.916748046875 MiB
+    32: 81.166748046875 MiB
+    64: 81.666748046875 MiB
+    128: 82.666748046875 MiB
+    1024: 96.666748046875 MiB
+    2048: 132.484619140625 MiB
+    4096: 205.328369140625 MiB
+    2517448: 62094.25 MiB
+  optimized: 32
 recursiongfn:
   arg: --batch_size
   model:
     2: 1134.75 MiB
     4: 1140.75 MiB
+    16: 1830.25 MiB
+    32: 1342.25 MiB
+    64: 4410.25 MiB
+    128: 9160.25 MiB
 reformer:
   arg: --batch-size
   model:
@@ -376,6 +495,46 @@ resnet50:
   optimized: 64
 resnet50-noio:
   arg: --batch-size
+  model:
+    1: 1594.25 MiB
+    2: 1652.25 MiB
+    4: 1854.25 MiB
+    16: 3052.25 MiB
+    32: 4690.25 MiB
+    56: 7114.25 MiB
+    136: 15194.25 MiB
+    288: 30632.25 MiB
+    592: 64483.8125 MiB
+    736: 76050.25 MiB
+rlhf-gpus:
+  arg: --per_device_train_batch_size
+  model:
+    1: 13448.25 MiB
+    2: 13594.25 MiB
+    4: 13686.25 MiB
+    16: 14606.25 MiB
+    32: 17918.25 MiB
+    64: 24374.25 MiB
+    128: 25830.25 MiB
+    136: 29442.25 MiB
+    392: 15372.25 MiB
+    520: 15808.25 MiB
+  optimized: 64
+rlhf-single:
+  arg: --per_device_train_batch_size
+  model:
+    1: 8590.25 MiB
+    2: 8650.25 MiB
+    4: 8822.25 MiB
+    16: 9694.25 MiB
+    32: 12952.25 MiB
+    40: 14638.25 MiB
+    64: 19422.25 MiB
+    120: 31048.25 MiB
+    128: 32442.25 MiB
+    280: 63262.25 MiB
+    352: 77536.25 MiB
+  optimized: 64
 rwkv:
   arg: --micro_bsz
   model:
@@ -424,8 +583,29 @@ torchatari:
   arg: --num-steps
   model:
     1: 1124.75 MiB
-    2: 1138.75 MiB
-    4: 1166.75 MiB
+    1024: 20176.25 MiB
+    2048: 39020.25 MiB
+    4096: 76708.25 MiB
+vjepa-gpus:
+  arg: --batch_size
+  model:
+    1: 27196.25 MiB
+    2: 28896.25 MiB
+    4: 30784.25 MiB
+    16: 52722.25 MiB
+    32: 77124.25 MiB
+  optimized: 24
+vjepa-single:
+  arg: --batch_size
+  model:
+    1: 6644.25 MiB
+    2: 18984.25 MiB
+    4: 11860.25 MiB
+    8: 30764.25 MiB
+    16: 45516.25 MiB
+    24: 57574.25 MiB
+    32: 67122.25 MiB
+  optimized: 24
 whisper:
   arg: --batch-size
   model:
@@ -442,36 +622,3 @@ whisper:
     128: 71634.375 MiB
     144: 80412.75 MiB
   optimized: 128
-
-
-llava-single:
-  arg: --batch_size
-  optimized: 1
-
-llava-gpus:
-  arg: --batch_size
-  optimized: 1
-
-rlhf-single:
-  arg: --per_device_train_batch_size
-  optimized: 64
-
-rlhf-gpus:
-  arg: --per_device_train_batch_size
-  optimized: 64
-
-vjepa-single:
-  arg: --batch_size
-  optimized: 24
-
-vjepa-gpus:
-  arg: --batch_size
-  optimized: 24
-
-ppo:
-  arg: --num_minibatches
-  optimized: 32
-
-dqn:
-  arg: --buffer_batch_size
-  optimized: 128
\ No newline at end of file
diff --git a/config/standard.yaml b/config/standard.yaml
index 588e35e9a..f32685cbc 100644
--- a/config/standard.yaml
+++ b/config/standard.yaml
@@ -161,12 +161,20 @@ dqn:
 ppo:
   enabled: true
   weight: 1.0
-  
+
+cleanrljax:
+  enabled: false
+  weight: 1.0
+
 # Geo
 dimenet:
   enabled: true
   weight: 1.0
 
+pna:
+  enabled: true
+  weight: 1.0
+
 recursiongfn:
   enabled: true
   weight: 1.0
diff --git a/constraints/cuda.txt b/constraints/cuda.txt
index eb6bbcedf..49675b577 100644
--- a/constraints/cuda.txt
+++ b/constraints/cuda.txt
@@ -5,3 +5,14 @@
 voir >= 0.2.19
 torchcompat >= 1.0.0
 gymnax >= 0.0.8
+trl<0.11.0
+
+# the latest torchtune is slower than before and causes failures;
+# the next version of pytorch seems to work better,
+# so we pin torchtune until a new pytorch release
+torchtune<0.3.0
+
+# transformers added torchao support recently, but only in its
+# most recent version, which we do not support yet
+transformers<4.45.0
+torchao
\ No newline at end of file
diff --git a/constraints/extra/torch.hpu.txt b/constraints/extra/torch.hpu.txt
index 1d21c1779..e69de29bb 100644
--- a/constraints/extra/torch.hpu.txt
+++ b/constraints/extra/torch.hpu.txt
@@ -1,5 +0,0 @@
-
-#
-#
-voir >= 0.2.15
-torchcompat >= 1.0.0
diff --git a/constraints/hpu.txt b/constraints/hpu.txt
index 23a110bd2..9f6fe957d 100644
--- a/constraints/hpu.txt
+++ b/constraints/hpu.txt
@@ -1,8 +1,16 @@
-# FIXME
-# Add
-
 #
 #
 voir >= 0.2.19
 torchcompat >= 1.0.0
-gymnax >= 0.0.8
\ No newline at end of file
+gymnax >= 0.0.8
+trl<0.11.0
+
+# the latest torchtune is slower than before and causes failures;
+# the next version of pytorch seems to work better,
+# so we pin torchtune until a new pytorch release
+torchtune<0.3.0
+
+# transformers added torchao support recently, but only in its
+# most recent version, which we do not support yet
+transformers<4.45.0
+torchvision
\ No newline at end of file
diff --git a/constraints/rocm.txt b/constraints/rocm.txt
index b86ce00d3..cc1585575 100644
--- a/constraints/rocm.txt
+++ b/constraints/rocm.txt
@@ -1,7 +1,20 @@
---extra-index-url https://download.pytorch.org/whl/rocm6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
 
 #
 #
 voir >= 0.2.19
 torchcompat >= 1.0.0
 gymnax >= 0.0.8
+
+
+trl<0.11.0
+
+# the latest torchtune is slower than before and causes failures;
+# the next version of pytorch seems to work better,
+# so we pin torchtune until a new pytorch release
+torchtune<0.3.0
+
+# transformers added torchao support recently, but only in its
+# most recent version, which we do not support yet
+transformers<4.45.0
+torchao
\ No newline at end of file
diff --git a/docker/Dockerfile-hpu b/docker/Dockerfile-hpu
new file mode 100644
index 000000000..932959cd6
--- /dev/null
+++ b/docker/Dockerfile-hpu
@@ -0,0 +1,42 @@
+# FROM artifactory-kfs.habana-labs.com/docker-local/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:1.17.0-462
+
+FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
+
+ENV MILABENCH_GPU_ARCH=hpu
+
+WORKDIR /workspace
+
+ENV MILABENCH_CONFIG="/workspace/milabench/config/standard.yaml"
+
+ENV MILABENCH_WORDIR="/workspace/${MILABENCH_GPU_ARCH}"
+ENV MILABENCH_BASE="${MILABENCH_WORDIR}/results"
+ENV MILABENCH_VENV="${MILABENCH_WORDIR}/env"
+ENV BENCHMARK_VENV="${MILABENCH_WORDIR}/results/venv/torch"
+
+ARG BENCH=lightning
+
+RUN mkdir -p ${MILABENCH_WORDIR}
+RUN pip install virtualenv
+RUN virtualenv --system-site-packages $MILABENCH_VENV
+
+ARG CACHEBUST=1
+RUN echo "$CACHEBUST"
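+# Passing a new CACHEBUST value (e.g. the current commit hash, as in the
+# build command below) invalidates the layer cache from this point on,
+# forcing a fresh clone of milabench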
+RUN git clone https://github.com/mila-iqia/milabench.git -b $MILABENCH_GPU_ARCH
+RUN $MILABENCH_VENV/bin/pip install -e milabench
+
+RUN . $MILABENCH_VENV/bin/activate && milabench install --use-current-env --select "${BENCH}"
+
+RUN $MILABENCH_VENV/bin/pip uninstall torch torchvision torchaudio -y 
+RUN sed -i 's/pic.numpy(force=True)/pic.numpy()/' /usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py
+
+# This does not work
+# RUN . $MILABENCH_VENV/bin/activate && milabench prepare --use-current-env --select "${BENCH}"
+
+
+
+# RUN . $MILABENCH_VENV/bin/activate && milabench run --use-current-env --select $BENCH
+# RUN huggingface-cli login --token $MILABENCH_HF_TOKEN
+
+# docker build --build-arg CACHEBUST=`git rev-parse hpu` -f Dockerfile-hpu -t dockerfile-hpu . 
+# docker run   -it   --runtime=habana   -e HABANA_VISIBLE_DEVICES=all   -e OMPI_MCA_btl_vader_single_copy_mechanism=none   --shm-size 50G   --cap-add=sys_nice   --net=host   dockerfile-hpu:latest   bash
+# . $MILABENCH_VENV/bin/activate && milabench prepare --use-current-env --select lightning && milabench run --use-current-env --select lightning
diff --git a/docker/Makefile b/docker/Makefile
new file mode 100644
index 000000000..93a402704
--- /dev/null
+++ b/docker/Makefile
@@ -0,0 +1,17 @@
+
+bench = rlhf-gpus
+# bench = "lightning"
+lazy = 0
+
+hpu:
+	git add --all
+	git commit -m "-" || true
+	git push origin hpu
+	docker rmi -f $$(docker images --filter "dangling=true" -q --no-trunc) || true
+	# docker system prune -a -f
+	# docker image prune -a -f
+	docker build --build-arg BENCH=$(bench) --build-arg CACHEBUST=`git rev-parse hpu` -f Dockerfile-hpu -t dockerfile-hpu . 
+	docker run --rm -it   --runtime=habana  -e PT_HPU_LAZY_MODE=$(lazy) -e HABANA_VISIBLE_DEVICES=all   -e OMPI_MCA_btl_vader_single_copy_mechanism=none --shm-size 50G   --cap-add=sys_nice   --net=host   dockerfile-hpu:latest bash -c '. $$MILABENCH_VENV/bin/activate && milabench install --use-current-env --select $(bench) && pip uninstall torch torchvision torchaudio -y  &&  milabench prepare --use-current-env --select $(bench) && milabench run --use-current-env $(args) --select $(bench)'
+
diff --git a/docs/recipes.rst b/docs/recipes.rst
index 4adfac1cc..f647ab452 100644
--- a/docs/recipes.rst
+++ b/docs/recipes.rst
@@ -117,6 +117,45 @@ It holds all the benchmark specific logs and metrics gathered by milabench.
 
   zip -r results.zip results
 
+
+Containers
+----------
+
+When using containers where some dependencies are already installed, we need to use a dummy virtualenv
+so that milabench installs its dependencies there; the duplicate dependencies can then be removed.
+
+.. code-block:: bash
+
+    podman run --rm --device nvidia.com/gpu=all --storage-opt ignore_chown_errors=true --security-opt=label=disable --ipc=host -it -e HOME=$HOME -e USER=$USER -v $HOME:$HOME  nvcr.io/nvidia/pytorch:24.02-py3
+
+    cd $HOME 
+    rm -rf env
+    pip install virtualenv
+    
+    # Create a virtual env with system packages to get the container's pytorch
+    virtualenv --system-site-packages env
+    source ./env/bin/activate
+    git clone https://github.com/mila-iqia/milabench.git
+    pip install -e milabench
+    
+    export MILABENCH_BASE="$HOME/results"
+    export MILABENCH_CONFIG="$HOME/milabench/config/standard.yaml"
+    export MILABENCH_GPU_ARCH=cuda
+
+    # This updates the requirements for cuda
+    # milabench pin --from-scratch --variant cuda -c constraints/cuda.txt
+    
+    # Install the new requirements (note: this will still install a new pytorch)
+    milabench install --use-current-env
+    
+    # uninstall pytorch that was installed in the venv
+    # so we use the system packages instead
+    pip uninstall torch torchvision torchaudio
+    
+    milabench prepare --use-current-env
+    milabench run --use-current-env
+
+
 Example Reports
 ---------------
 
diff --git a/milabench/_version.py b/milabench/_version.py
index 4b49d0506..281e1d0af 100644
--- a/milabench/_version.py
+++ b/milabench/_version.py
@@ -1,5 +1,5 @@
 """This file is generated, do not modify"""
 
-__tag__ = "v0.1.0-113-g9a5dfe3e"
-__commit__ = "9a5dfe3ef36e6baab6584faa3fa939e63ba2aed5"
-__date__ = "2024-09-16 09:08:28 -0400"
+__tag__ = "v0.1.0-146-ga8415d3"
+__commit__ = "a8415d3da9f91aa1ac23d932dff2c70fe580e556"
+__date__ = "2024-11-21 14:35:55 -0500"
diff --git a/milabench/alt_async.py b/milabench/alt_async.py
index 8608196d3..6fc9f64c8 100644
--- a/milabench/alt_async.py
+++ b/milabench/alt_async.py
@@ -190,6 +190,8 @@ def run(argv, setsid=None, process_accumulator=None, info={}, **kwargs):
             destroy(*mx.processes)
         yield entry
         
+    # mx.close()
+
 
 def proceed(coro):
     loop = FeedbackEventLoop()
diff --git a/milabench/cli/compare.py b/milabench/cli/compare.py
index b2992857c..83f0c59ce 100644
--- a/milabench/cli/compare.py
+++ b/milabench/cli/compare.py
@@ -15,6 +15,7 @@ class Arguments:
     last    : int = None
     metric  : str = "train_rate"
     stat    : str = "median"
+    filter  : str = None
 # fmt: on
 
 
@@ -23,13 +24,15 @@ def arguments():
     # [positional: ?]
     folder: Option = None
 
+    filter: Option & str = None
+
     last: Option & int = None
 
     metric: Option & str = "train_rate"
 
     stat: Option & str = "median"
 
-    return Arguments(folder, last, metric, stat)
+    return Arguments(folder, last, metric, stat, filter)
 
 
 @tooled
@@ -66,7 +69,7 @@ def cli_compare(args=None):
         if base is not None:
             args.folder = os.path.join(base, "runs")
 
-    runs = fetch_runs(args.folder)
+    runs = fetch_runs(args.folder, args.filter)
 
     for run in runs:
         all_data = _read_reports(run.path)
diff --git a/milabench/cli/gather.py b/milabench/cli/gather.py
index d3058d65c..316b6bfb4 100644
--- a/milabench/cli/gather.py
+++ b/milabench/cli/gather.py
@@ -39,6 +39,7 @@ def arguments():
         "--tags",
         type=str,
         help="Tags defined in run names",
+        nargs="+",
         default=default_tags(),
     )
     return parser.parse_args()  # Arguments()
diff --git a/milabench/cli/list.py b/milabench/cli/list.py
new file mode 100644
index 000000000..fda73bdf5
--- /dev/null
+++ b/milabench/cli/list.py
@@ -0,0 +1,56 @@
+import os
+import yaml
+
+from milabench.config import build_config
+
+
+this = os.path.dirname(__file__)
+config = os.path.join(this, "..", "..", "config")
+
+
+def list_missing_batch_resizer():
+    standard = os.path.join(config, "standard.yaml")
+    scaling = os.path.join(config, "scaling.yaml")
+
+    base_conf = build_config(standard)
+
+    with open(scaling, "r") as fp:
+        scaling = yaml.safe_load(fp)
+
+    missing_benches = []
+    def add_bench(k, tags):
+        print(k, tags)
+        missing_benches.append(k)
+
+    for k, v in base_conf.items():
+        if k[0] == "_":
+            continue
+
+        if not v.get("enabled", False):
+            continue 
+
+        tags = set(v.get("tags", []))
+
+        if "nobatch" in tags:
+            continue
+
+        if k in scaling:
+            s = scaling[k].get("model", {})
+
+            if len(s) <= 1:
+                add_bench(k, tags)
+        else:
+            add_bench(k, tags)
+
+    quoted = [f'"{bench}"' for bench in missing_benches]
+    print(" ".join(quoted))
+
+
+if __name__ == "__main__":
+    list_missing_batch_resizer()
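+
+# Run directly to list the enabled benches that still lack scaling data:
+#     python milabench/cli/list.py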
diff --git a/milabench/cli/run.py b/milabench/cli/run.py
index f5e75b702..f04427af1 100644
--- a/milabench/cli/run.py
+++ b/milabench/cli/run.py
@@ -23,6 +23,7 @@
 from ..report import make_report
 from ..sizer import MemoryUsageExtractor
 from ..summary import make_summary
+from ..system import multirun, apply_system, SizerOptions, option
 
 
 # fmt: off
@@ -72,12 +73,7 @@ def _fetch_arch(mp):
         return None
     
 
-@tooled
-def cli_run(args=None):
-    """Run the benchmarks."""
-    if args is None:
-        args = arguments()
-
+def run(mp, args, name):
     layers = validation_names(args.validations)
 
     dash_class = {
@@ -85,13 +81,7 @@ def cli_run(args=None):
         "long": LongDashFormatter,
         "no": None,
     }.get(args.dash, None)
-
-    mp = get_multipack(run_name=args.run_name)
-    arch = _fetch_arch(mp)
-
-    # Initialize the backend here so we can retrieve GPU stats
-    init_arch(arch)
-
+
     success = run_with_loggers(
         mp.do_run(repeat=args.repeat),
         loggers=[
@@ -136,3 +126,29 @@ def cli_run(args=None):
             )
 
     return success
+
+
+@tooled
+def cli_run(args=None):
+    """Run the benchmarks."""
+    if args is None:
+        args = arguments()
+
+    # Load the configuration and system
+    mp = get_multipack(run_name=args.run_name)
+    arch = _fetch_arch(mp)
+
+    # Initialize the backend here so we can retrieve GPU stats
+    init_arch(arch)
+
+    success = 0
+    for name, conf in multirun():
+        run_name = name or args.run_name
+
+        # Note that this function overrides the system config
+        mp = get_multipack(run_name=run_name)
+
+        with apply_system(conf):
+            success += run(mp, args, run_name)
+
+    return success
diff --git a/milabench/commands/__init__.py b/milabench/commands/__init__.py
index e97ac4e58..4a8f1e90a 100644
--- a/milabench/commands/__init__.py
+++ b/milabench/commands/__init__.py
@@ -451,6 +451,11 @@ def _find_node_config(self) -> Dict:
         return {}
 
     def is_local(self):
+        local = self._is_local()
+        print("is_local", self.host, local)
+        return local
+
+    def _is_local(self):
         localnode = self.pack.config["system"]["self"]
 
         if localnode is not None:
@@ -581,7 +586,7 @@ def node_address(node):
     """Favour Hostname as it is the most consistent name across machines"""
     host = node.get("hostname")
     ip = node.get("ip")
-    return host or ip
+    return ip or host
 
 
 class ForeachNode(ListCommand):
@@ -637,6 +642,7 @@ def executors(self):
                     **self.options
                 )
 
+            print(rank, node, node_address(node))
             worker = SSHCommand(
                 host=node_address(node),
                 user=node["user"],
diff --git a/milabench/commands/executors.py b/milabench/commands/executors.py
index f0402d29b..807a261e2 100644
--- a/milabench/commands/executors.py
+++ b/milabench/commands/executors.py
@@ -32,6 +32,9 @@ async def execute(pack, *args, cwd=None, env={}, external=False, use_stdout=Fals
     sized_args = scale_argv(pack, args)
     final_args = resolve_argv(pack, sized_args)
 
+    if use_stdout:
+        exec_env["MILABENCH_USE_STDOUT"] = "1"
+
     return await run(
         final_args,
         **kwargs,
diff --git a/milabench/compare.py b/milabench/compare.py
index e3b88b10c..32f95c64c 100644
--- a/milabench/compare.py
+++ b/milabench/compare.py
@@ -21,17 +21,30 @@ def retrieve_datetime_from_name(date):
             pass
 
 
-def fetch_runs(folder):
+def fetch_runs(folder, filter):
+    import fnmatch
+
     runs = []
+    ignored = 0
+    
     for run in os.listdir(folder):
+        if run.startswith("install") or run.startswith("prepare"):
+            continue
+    
+        if filter is not None and (not fnmatch.fnmatch(run, filter)):
+            ignored += 1
+            continue
+
         pth = os.path.join(folder, run)
         if not os.path.isdir(pth):
             continue
         if "." in run:
-            name, date = run.split(".", maxsplit=1)
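+            # run folders are assumed to be named
+            # "<name>.<date>.<fractional_seconds>"; strip the fractional
+            # part first, then split off the date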
+            name, fractional_seconds = run.rsplit(".", maxsplit=1)
+            name, date = name.rsplit(".", maxsplit=1)
             date = retrieve_datetime_from_name(date)
         else:
             name = run
+            date = None
 
         if date is None:
             date = datetime.fromtimestamp(os.path.getmtime(pth))
@@ -39,6 +52,8 @@ def fetch_runs(folder):
         out = _Output(pth, name, date)
         runs.append(out)
 
+    if ignored > 0:
+        print(f"Ignoring run {ignored} runs because of filter {filter}")
     runs.sort(key=lambda out: out.date)
     return runs
 
diff --git a/milabench/config.py b/milabench/config.py
index ebc041060..9a2d519c9 100644
--- a/milabench/config.py
+++ b/milabench/config.py
@@ -11,6 +11,8 @@
 config_global = contextvars.ContextVar("config", default=None)
 execution_count = (0, 0)
 
+_MONITOR_TAGS = {"monogpu", "multigpu", "multinode"}
+
 
 def set_run_count(total_run, total_bench):
     global execution_count
@@ -80,6 +82,13 @@ def finalize_config(name, bench_config):
             pack = (XPath(bench_config["config_base"]) / pack).resolve()
             bench_config["definition"] = str(pack)
 
+    if not name.startswith("_") and name != "*":
+        _tags = set(bench_config["tags"])
+        _monitor_tags = _tags & _MONITOR_TAGS
+        assert len(_monitor_tags) == 1, (
+            f"Bench {name} should have exactly one monitor tag. Found {_monitor_tags}"
+        )
+
     bench_config["tag"] = [bench_config["name"]]
 
     bench_config = OmegaConf.to_object(OmegaConf.create(bench_config))
@@ -91,11 +100,15 @@ def combine_args(args, kwargs):
         yield kwargs
     else:
         key, values = args.popitem()
-        for value in values:
-            kwargs[key] = value
+
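+        # values is either a list (recurse once per entry) or a scalar
+        # (recurse once); e.g. a matrix entry like {"capacity": ["4Go", "8Go"]}
+        # yields one expanded kwargs dict per capacity value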
+        try:
+            for value in values:
+                kwargs[key] = value
+                yield from combine_args(deepcopy(args), kwargs)
+        except TypeError:
+            kwargs[key] = values
             yield from combine_args(deepcopy(args), kwargs)
 
-
 def expand_matrix(name, bench_config):
     if "matrix" not in bench_config:
         return [(name, bench_config)]
diff --git a/milabench/remote.py b/milabench/remote.py
index 7e1eef85c..27660f75a 100644
--- a/milabench/remote.py
+++ b/milabench/remote.py
@@ -100,7 +100,7 @@ def worker_commands(pack, worker_plan, setup_for="worker"):
 def sshnode(node, cmd):
     host = node["ip"]
     user = node["user"]
-    port = node["sshport"]
+    port = node.get("sshport", 22)
     return SSHCommand(cmd, user=user, host=host, port=port)
 
 
@@ -124,7 +124,6 @@ def milabench_remote_setup_plan(pack, setup_for="worker") -> SequenceCommand:
 
     nodes = pack.config["system"]["nodes"]
     copy = []
-    node_packs = []
 
     copy_source = copy_folder(pack, INSTALL_FOLDER, setup_for)
 
@@ -132,7 +131,8 @@ def milabench_remote_setup_plan(pack, setup_for="worker") -> SequenceCommand:
 
     for i, node in enumerate(nodes):
         if should_run_for(node, setup_for):
-            install.append(pip_install_milabench(node_packs[i], node, INSTALL_FOLDER))
+            node_pack = worker_pack(pack, node)
+            install.append(pip_install_milabench(node_pack, node, INSTALL_FOLDER))
 
     return SequenceCommand(
         copy_source,
@@ -192,7 +192,7 @@ def is_remote(pack):
 def is_main_local(pack):
     """Only the local main can send remote commands to remote"""
     self = pack.config["system"]["self"]
-    return self is not None and self["local"] and self.get("main", False)
+    return self is not None and self.get("local", True) and self.get("main", False)
 
 
 def is_worker(pack):
diff --git a/milabench/report.py b/milabench/report.py
index aebcaf093..491cf8bab 100644
--- a/milabench/report.py
+++ b/milabench/report.py
@@ -342,6 +342,35 @@ def short_meta(out, meta):
     out.print(Table(stats))
 
 
+def to_latex(df):
+    from dataclasses import dataclass
+    from .system import option
+
+    default_columns = [
+        "ngpu",
+        "perf",
+        "sem%",
+        "std%"
+    ]
+
+    @dataclass
+    class LatexTable:
+        output: str = option("latex.output", str, None)
+        columns: str = option("latex.columns", str, ",".join(default_columns))
+
+    options = LatexTable()
+
+    columns = options.columns.split(",")
+
+    df = df[columns]
+
+    if options.output is not None:
+        with open(options.output, "w") as fp:
+            txt = df.to_latex(formatters=_formatters, escape=False)
+            txt = txt.replace("%", r"\%").replace("_", r"\_")
+            fp.write(txt)
+
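+# Both settings come from the option system; assuming option() honors the
+# usual MILABENCH_* environment-variable mapping (see as_environment_variable),
+# that would be, e.g. (sketch):
+#
+#   MILABENCH_LATEX_OUTPUT=report.tex MILABENCH_LATEX_COLUMNS=ngpu,perf \
+#       milabench report --runs ...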
+
 @error_guard({})
 def make_report(
     summary: dict[str, Summary],
@@ -376,7 +405,10 @@ def make_report(
     out.section("Breakdown")
 
     # Reorder columns
-    out.print(normalize_dataframe(df))
+    normalized = normalize_dataframe(df)
+    out.print(normalized)
+
+    to_latex(normalized)
 
     out.section("Scores")
 
@@ -385,12 +417,17 @@ def _score(column):
             # This computes a weighted geometric mean
 
             # perf can be object np.float64 !?
-            perf = df[column].astype(float)
+            # TODO: fold the failure rate into the score, e.g.
+            #   success_ratio = 1 - row["fail"] / max(row["n"], 1)
+            #   score = (acc if acc > 0 else row["perf"]) * success_ratio
+            score = df[column].astype(float)
 
             weights = df["weight"] * df["enabled"].astype(int)
-            weight_total = np.sum(weights)
+            # FIXME: guard against a total weight of 0 (every bench disabled)
+            weight_total = np.sum(weights)
 
-            logscore = np.sum(np.log(perf) * weights) / weight_total
+            # +1 keeps the log finite when a score is 0; this computes the
+            # weighted geometric mean exp(sum(w_i * log(s_i + 1)) / sum(w_i))
+            logscore = np.sum(np.log(score + 1) * weights) / weight_total
             return np.exp(logscore)
         except ZeroDivisionError:
             return 0
@@ -493,12 +530,12 @@ def pandas_to_string(df, formatters=_formatters):
     # Compute column size
     col_size = defaultdict(int)
     for index, row in df.iterrows():
-        col_size["bench"] = max(col_size["bench"], len(index))
+        col_size["bench"] = max(col_size["bench"], len(index), len("bench"))
         for col, val in zip(columns, row):
             fmt = formatters.get(col)
             if fmt is not None:
                 val = fmt(val)
-                col_size[col] = max(col_size[col], len(val))
+                col_size[col] = max(col_size[col], len(val), len(col))
 
     # Generate report
     sep = " | "
diff --git a/milabench/sizer.py b/milabench/sizer.py
index b1f717247..4bd62bc7f 100644
--- a/milabench/sizer.py
+++ b/milabench/sizer.py
@@ -53,15 +53,21 @@ def to_octet(value: str) -> float:
 class Sizer:
     """Automatically scale the batch size to match GPU spec"""
 
-    def __init__(self, options=SizerOptions(), scaling_config=None):
-        self.options = options
+    def __init__(self, sizer=None, scaling_config=None):
+        if scaling_config is None:
+            # resolved at call time: a default argument would be evaluated
+            # once at import, before the system configuration is loaded
+            scaling_config = option("sizer.config", etype=str)
+
         self.path = scaling_config
-
+        self.sizer_override = sizer
+
         if scaling_config is None:
             scaling_config = default_scaling_config
 
         with open(scaling_config, "r") as sconf:
             self.scaling_config = yaml.safe_load(sconf)
+
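+    # `sizer` pins an explicit SizerOptions (useful in tests); otherwise the
+    # options are rebuilt on every access so option overrides can take effect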
+    @property
+    def options(self):
+        if self.sizer_override:
+            return self.sizer_override
+        return SizerOptions()
 
     def benchscaling(self, benchmark):
         # key
@@ -165,6 +171,10 @@ def find_batch_size(self, benchmark, event):
         return -1
 
     def argv(self, benchmark, capacity, argv):
+        return self._argv(benchmark, capacity, argv)
+
+    def _argv(self, benchmark, capacity, argv):
         """Find the batch size and override it with a new value"""
 
         config = self.benchscaling(benchmark)
@@ -214,11 +224,12 @@ def argv(self, benchmark, capacity, argv):
 
 
 def batch_sizer() -> Sizer:
-    sizer = sizer_global.get()
-    if sizer is None:
-        sizer_global.set(Sizer())
-        return batch_sizer()
-    return sizer
+    # a fresh Sizer is built on every call so option overrides (e.g. those
+    # applied by apply_system during a multirun) take effect; the previous
+    # implementation cached a single instance in a contextvar
+    return Sizer()
 
 
 def get_batch_size(config, start_event):
@@ -242,13 +253,15 @@ class MemoryUsageExtractor(ValidationLayer):
     """Extract max memory usage per benchmark to populate the memory model"""
 
     def __init__(self):
-        sizer = batch_sizer()
-        self.filepath = sizer.options.save
+        self.filepath = option("sizer.save", str, None)
+        sizer = Sizer()
         self.memory = deepcopy(sizer.scaling_config)
         self.scaling = None
         self.benchname = None
         self.batch_size = 0
-        self.max_usage = float("-inf")
+        self.max_usage = float("-inf")  # Usage from the gpu monitor
+        self.peak_usage = float("-inf") # Usage provided by the bench itself (for jax)
         self.early_stopped = False
 
     def on_start(self, entry):
@@ -259,6 +272,7 @@ def on_start(self, entry):
         self.benchname = entry.pack.config["name"]
         self.batch_size = None
         self.max_usage = float("-inf")
+        self.peak_usage = float("-inf")
 
         config = self.memory.setdefault(self.benchname, dict())
         template = config.get("arg", None)
@@ -300,6 +314,11 @@ def on_data(self, entry):
         if entry.data is None:
             return
 
+        # a bench may report its own peak memory (e.g. jax, whose allocations
+        # the GPU monitor cannot observe); prefer that metric when present
+        memorypeak = entry.data.get("memory_peak")
+        if memorypeak is not None:
+            self.peak_usage = max(memorypeak, self.peak_usage)
+            return
+
         gpudata = entry.data.get("gpudata")
         if gpudata is not None:
             current_usage = []
@@ -312,6 +331,11 @@ def on_data(self, entry):
     def on_stop(self, entry):
         self.early_stopped = True
 
+    def max_memory_usage(self):
+        if self.peak_usage != float("-inf"):
+            return self.peak_usage
+        return self.max_usage
+
     def on_end(self, entry):
         if self.filepath is None:
             return
@@ -319,7 +343,7 @@ def on_end(self, entry):
         if (
             self.benchname is None
             or self.batch_size is None
-            or self.max_usage == float("-inf")
+            or self.max_memory_usage() == float("-inf")
         ):
             return
 
@@ -328,12 +352,13 @@ def on_end(self, entry):
         if rc == 0 or self.early_stopped:
             config = self.memory.setdefault(self.benchname, dict())
             model = config.setdefault("model", dict())
-            model[self.batch_size] = f"{self.max_usage} MiB"
+            model[self.batch_size] = f"{self.max_memory_usage()} MiB"
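+            # the scaling file thus maps batch size -> memory, e.g. (sketch):
+            #   benchname: { model: { 64: "12216 MiB", 128: "24432 MiB" } }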
             config["model"] = dict(sorted(model.items(), key=lambda x: x[0]))
 
         self.benchname = None
         self.batch_size = None
         self.max_usage = float("-inf")
+        self.peak_usage = float("-inf")
 
     def report(self, *args):
         if self.filepath is not None:
diff --git a/milabench/system.py b/milabench/system.py
index c237baf2c..691d06bd9 100644
--- a/milabench/system.py
+++ b/milabench/system.py
@@ -1,4 +1,5 @@
 import contextvars
+from copy import deepcopy
 import ipaddress
 import os
 import socket
@@ -15,7 +16,7 @@
 from .merge import merge
 
 system_global = contextvars.ContextVar("system", default=None)
-
+multirun_global = contextvars.ContextVar("multirun", default=None)
 
 def get_gpu_capacity(strict=False):
     try:
@@ -79,6 +80,65 @@ def as_environment_variable(name):
     return "MILABENCH_" + "_".join(map(str.upper, frags))
 
 
+def multirun():
+    multirun = multirun_global.get()
+
+    if multirun is None or len(multirun) == 0:
+        yield None, dict()
+        return
+
+    runs = multirun.get("runs", [])
+
+    from .config import combine_args
+    import time
+    from types import SimpleNamespace
+
+    def unflatten(dct):
+        # turn {"a.b": 1} into {"a": SimpleNamespace(b=1)} so the run-name
+        # template can use attribute access, e.g. "{a.b}".format(**ctx)
+        result = {}
+        for k, v in dct.items():
+            *parents, leaf = k.split(".")
+            if not parents:
+                result[leaf] = v
+                continue
+            node = result.setdefault(parents[0], SimpleNamespace())
+            for frag in parents[1:]:
+                if not hasattr(node, frag):
+                    setattr(node, frag, SimpleNamespace())
+                node = getattr(node, frag)
+            setattr(node, leaf, v)
+
+        return result
+
+    for run_matrix in runs:
+        arguments = run_matrix["matrix"]
+
+        for run in combine_args(arguments, dict()):
+            template_name = run_matrix["name"]
+
+            ctx = unflatten(run)
+            ctx["time"] = int(time.time())
+            run_name = template_name.format(**ctx)
+
+            yield run_name, run
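+# Assumed shape of the multirun section driving the generator above (sketch):
+#
+#   multirun:
+#     runs:
+#       - name: "bs{sizer.batch_size}-{time}"
+#         matrix:
+#           sizer.batch_size: [32, 64, 128]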
+
+
+@contextmanager
+def apply_system(config: dict):
+    system = system_global.get()
+    old = deepcopy(system)
+
+    if system is None:
+        system = dict()
+        system_global.set(system)
+
+    # write dotted keys such as "sizer.batch_size" under system["options"]
+    for k, v in config.items():
+        frags = k.split(".")
+
+        lookup = system.setdefault("options", {})
+        for f in frags[:-1]:
+            lookup = lookup.setdefault(f, {})
+        lookup[frags[-1]] = v
+
+    try:
+        yield
+    finally:
+        # restore the previous system even if the body raised
+        system_global.set(old)
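+# Typical use (sketch): temporarily override options for a single run
+#
+#   with apply_system({"sizer.batch_size": 64}):
+#       ...  # option("sizer.batch_size", int) now resolves to 64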
+
+
 def option(name, etype, default=None):
     options = dict()
     system = system_global.get()
@@ -268,6 +328,7 @@ def get_remote_ip():
 
     for interface, address_list in addresses.items():
         for address in address_list:
+            # TODO: also filter by address family (socket.AF_INET / AF_INET6)?
             if interface in stats and getattr(stats[interface], "isup"):
                 result.append(address.address)
 
@@ -286,49 +347,9 @@ def is_loopback(address: str) -> bool:
 
 
 
-def _resolve_ip(ip):
-    hostname = ip
-    aliaslist = []
-    ipaddrlist = [ip]
-    lazy_raise = None
-
-    if not offline:
-        # Resolve the IP
-        try:
-            hostname, aliaslist, ipaddrlist = socket.gethostbyaddr(ip)
-            lazy_raise = None
-        
-        except socket.herror as err:
-            lazy_raise = err
-
-        except socket.gaierror as err:
-            # Get Addr Info (GAI) Error
-            #
-            # When we are connecting to a node through a ssh proxy jump
-            # the node IPs/Hostnames are not available until we reach
-            # the first node inside the cluster
-            #
-            lazy_raise = err
-
-    return hostname, aliaslist, ipaddrlist, lazy_raise
-
-
-def _fix_weird(hostname):
-    if hostname.endswith(".server.mila.quebec.server.mila.quebec"):
-        print()
-        print("Hostname was extra long for no reason")
-        print(hostname, socket.gethostname())
-        print()
-
-        # why is this happening
-        hostname = hostname[: -len(".server.mila.quebec")]
-    
-    return hostname
-
-
 # If true that means we cannot resolve the ip addresses
 # so we ignore errors
-offline = False
+offline = True
 
 
 @contextmanager
@@ -351,29 +372,21 @@ def _resolve_addresses(nodes):
     ip_list = get_remote_ip()
 
     for node in nodes:
-        hostname, aliaslist, ipaddrlist, lazy_raise = _resolve_ip(node["ip"])
-
-        hostname = _fix_weird(hostname)
-
-        node["hostname"] = hostname
-        node["aliaslist"] = aliaslist
-        node["ipaddrlist"] = ipaddrlist
-
-        is_local = (
-            ("127.0.0.1" in ipaddrlist)
-            or (hostname in ("localhost", socket.gethostname(), "127.0.0.1"))
-            or (socket.gethostname().startswith(hostname))
-            or len(ip_list.intersection(ipaddrlist)) > 0
-            or any([is_loopback(ip) for ip in ipaddrlist])
-        )
-
-        # cn-g005 cn-g005.server.mila.quebec
-        # print(hostname, socket.gethostname())
+        ip = node["ip"]
+        
+        is_local = is_loopback(ip)
+        
+        if ip in ip_list:
+            is_local = True            
+        
         node["local"] = is_local
+
+        if is_local:
+            node["hostname"] = socket.gethostname()
 
         if is_local and self is None:
             self = node
-            node["ipaddrlist"] = list(set(list(ip_list) + list(ipaddrlist)))
+            node["ipaddrlist"] = list(set(list(ip_list)))
 
     # if self is node we might be outside the cluster
     # which explains why we could not resolve the IP of the nodes
@@ -401,11 +414,13 @@ def gethostname(host):
 def resolve_hostname(ip):
     try:
         hostname, _, iplist = socket.gethostbyaddr(ip)
         for ip in iplist:
             if is_loopback(ip):
                 return hostname, True
 
+        # FIXME: gethostbyaddr may return a FQDN while gethostname() is the
+        # short name; compare on the short-name prefix for now
+        return socket.gethostname(), hostname.startswith(socket.gethostname())
         return hostname, hostname == socket.gethostname()
 
     except:
@@ -464,6 +479,9 @@ def build_system_config(config_file, defaults=None, gpu=True):
         config = merge(defaults, config)
 
     system = config.get("system", {})
+    multirun = config.get("multirun", {})
+
+    multirun_global.set(multirun)
     system_global.set(system)
 
     # capacity is only required if batch resizer is enabled
diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh
index b7b31eed3..9ef13b7d3 100644
--- a/scripts/article/run_cuda.sh
+++ b/scripts/article/run_cuda.sh
@@ -9,7 +9,7 @@ export MILABENCH_BASE="$MILABENCH_WORDIR/results"
 export MILABENCH_VENV="$MILABENCH_WORDIR/env"
 export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch"
 export MILABENCH_SIZER_SAVE="$MILABENCH_WORDIR/scaling.yaml"
 
 if [ -z "${MILABENCH_PREPARE}" ]; then
     export MILABENCH_PREPARE=0
@@ -70,9 +70,9 @@ install_prepare() {
     milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS
 }
 
-module load cuda/12.3.2
+# module load cuda/12.3.2
 
-if [ ! -d "$MILABENCH_WORDIR/results" ]; then
+if [ ! -d "$MILABENCH_WORDIR/env" ]; then
     install_prepare 
 else
     echo "Reusing previous install"
@@ -84,16 +84,28 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then
 
     . $MILABENCH_WORDIR/env/bin/activate
 
-
     # pip install torch
-    # milabench pin --variant cuda --from-scratch $ARGS 
-    # milabench install --system $MILABENCH_WORDIR/system.yaml --force $ARGS
+    # milabench pin --variant cuda --from-scratch 
+    # rm -rf $MILABENCH_WORDIR/results/venv/
+    # rm -rf $MILABENCH_WORDIR/results/extra
+    # milabench install --system $MILABENCH_WORDIR/system.yaml
+    # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS
+
+    (
+        . $BENCHMARK_VENV/bin/activate
+        which pip
+        # pip uninstall torchao -y
+        # pip install torchao --no-input
+    )
 
-    #
-    #   Run the benchmakrs
     milabench run --system $MILABENCH_WORDIR/system.yaml $ARGS
 
     #
     #   Display report
     milabench report --runs $MILABENCH_WORDIR/results/runs
-fi
\ No newline at end of file
+fi
+
+
+# rsync -av mila@172.29.171.42:~/rocm/results/cache ~/cuda/results/cache
+# rsync -av mila@172.29.171.42:~/rocm/results/data ~/cuda/results/data
\ No newline at end of file
diff --git a/scripts/article/run_hpu.sh b/scripts/article/run_hpu.sh
index 5d875ca14..8f6126d29 100644
--- a/scripts/article/run_hpu.sh
+++ b/scripts/article/run_hpu.sh
@@ -9,68 +9,84 @@ set -ex
 export MILABENCH_GPU_ARCH=hpu
 export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH"
 export MILABENCH_BASE="$MILABENCH_WORDIR/results"
-export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml"
 export MILABENCH_VENV="$MILABENCH_WORDIR/env"
 export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch"
+export PT_HPU_LAZY_MODE=0
+
+if [ -z "${MILABENCH_SOURCE}" ]; then
+    export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml"
+else
+    export MILABENCH_CONFIG="$MILABENCH_SOURCE/config/standard.yaml"
+fi
 
 if [ -z "${MILABENCH_PREPARE}" ]; then
     export MILABENCH_PREPARE=0
 fi
 
+ARGS="$@"
+
 install_prepare() {
     mkdir -p $MILABENCH_WORDIR
     cd $MILABENCH_WORDIR
 
     virtualenv $MILABENCH_WORDIR/env
 
-    git clone https://github.com/mila-iqia/milabench.git
-    git clone https://github.com/huggingface/optimum-habana.git
+    if [ -z "${MILABENCH_SOURCE}" ]; then
+        if [ ! -d "$MILABENCH_WORDIR/milabench" ]; then
+            git clone https://github.com/mila-iqia/milabench.git
+        fi
+        export MILABENCH_SOURCE="$MILABENCH_WORDIR/milabench"
+    fi
+
+    git clone https://github.com/huggingface/optimum-habana.git -b v1.13.2
 
     # wget -nv https://vault.habana.ai/artifactory/gaudi-installer/1.15.1/habanalabs-installer.sh
-    wget -nv https://vault.habana.ai/artifactory/gaudi-installer/1.16.1/habanalabs-installer.sh
+    # wget -nv https://vault.habana.ai/artifactory/gaudi-installer/1.16.1/habanalabs-installer.sh
+    wget -nv https://vault.habana.ai/artifactory/gaudi-installer/1.17.1/habanalabs-installer.sh
     chmod +x habanalabs-installer.sh
 
     . $MILABENCH_WORDIR/env/bin/activate
-    pip install -e $MILABENCH_WORDIR/milabench
-
-
-    #
-    # Install milabench's benchmarks in their venv
-    #
-    milabench install
+    pip install -e $MILABENCH_SOURCE
 
     which pip
 
     # Override dependencies for HPU
     # milabench needs pyhlml
     export HABANALABS_VIRTUAL_DIR=$MILABENCH_VENV
-    ./habanalabs-installer.sh install -t dependencies --venv -y
-    ./habanalabs-installer.sh install -t pytorch --venv -y
+    ./habanalabs-installer.sh install -t dependencies --venv -y || true
+    ./habanalabs-installer.sh install -t pytorch --venv -y || true
+
+    #
+    # Install milabench's benchmarks in their venv
+    #
+    # milabench pin --variant hpu --from-scratch $ARGS 
+    milabench install $ARGS
 
     (
         . $BENCHMARK_VENV/bin/activate
         which pip
-        pip install -e $MILABENCH_WORDIR/optimum-habana
-
-        (
-            cd $MILABENCH_WORDIR/milabench/benchmarks/dlrm/dlrm;
-            git remote add me https://github.com/Delaunay/dlrm.git
-            git fetch me
-            git checkout me/main
-        )
+        pip install --no-deps -e $MILABENCH_WORDIR/optimum-habana
 
         # Override dependencies for HPU
         # benchmarks need pytorch
-        pip uninstall torch torchvision torchaudio
+        pip uninstall torch torchvision torchaudio -y
         export HABANALABS_VIRTUAL_DIR=$BENCHMARK_VENV
-        ./habanalabs-installer.sh install -t dependencies --venv -y
-        ./habanalabs-installer.sh install -t pytorch --venv -y
+        ./habanalabs-installer.sh install -t dependencies --venv -y || true
+        ./habanalabs-installer.sh install -t pytorch --venv -y || true
+
+        if [ -z "${MILABENCH_HF_TOKEN}" ]; then
+            echo "Missing token"
+        else
+            huggingface-cli login --token $MILABENCH_HF_TOKEN
+        fi
     )
 
     #
     #   Generate/download datasets, download models etc...
     #
-    milabench prepare
+    # sed -i 's/pic.numpy(force=True)/pic.numpy()/' $BENCHMARK_VENV/lib/python3.10/dist-packages/torchvision/transforms/functional.py
+    # sed -i 's/range(hpu.device_count())/range(len(available_modules))/' $BENCHMARK_VENV/lib/site-packages/habana_frameworks/torch/hpu/_utils.py
+    milabench prepare $ARGS
 }
 
 if [ ! -d "$MILABENCH_WORDIR" ]; then
@@ -81,12 +97,28 @@ else
 fi
 
 
+(
+    . $BENCHMARK_VENV/bin/activate
+    pip install lightning-habana
+    pip install habana-media-loader
+    # git clone https://github.com/Delaunay/torchcompat.git
+    # git clone https://github.com/Delaunay/voir.git -b hpu
+    pip uninstall torchcompat voir -y
+    pip install -e $MILABENCH_WORDIR/torchcompat
+    pip install -e $MILABENCH_WORDIR/voir
+    pip install -e $MILABENCH_WORDIR/optimum-habana
+    # pip install habana_dataloader
+)
+
 if [ "$MILABENCH_PREPARE" -eq 0 ]; then
     cd $MILABENCH_WORDIR
 
+    # python -c "import torch; print(torch.__version__)"
+    milabench prepare $ARGS --system $MILABENCH_WORDIR/system.yaml
+
     #
     #   Run the benchmarks
-    milabench run "$@"
+    milabench run $ARGS --system $MILABENCH_WORDIR/system.yaml
 
     #
     #   Display report
diff --git a/scripts/article/run_rocm.sh b/scripts/article/run_rocm.sh
index b8a15fb76..0fc2bf16d 100644
--- a/scripts/article/run_rocm.sh
+++ b/scripts/article/run_rocm.sh
@@ -2,30 +2,66 @@
 
 set -ex
 
+# sudo usermod -a -G render,video $LOGNAME
+# sudo chmod u+s /opt/rocm-6.2.2/lib/llvm/bin/amdgpu-arch
+
 export MILABENCH_GPU_ARCH=rocm
 export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH"
-
+export ROCM_PATH="/opt/rocm"
 export MILABENCH_BASE="$MILABENCH_WORDIR/results"
-export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml"
 export MILABENCH_VENV="$MILABENCH_WORDIR/env"
 export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch"
+export MILABENCH_SIZER_SAVE="$MILABENCH_WORDIR/scaling.yaml"
+
+if [ -z "${MILABENCH_SOURCE}" ]; then
+    export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml"
+else
+    export MILABENCH_CONFIG="$MILABENCH_SOURCE/config/standard.yaml"
+fi
+
+
+export GPU="$(/opt/rocm/lib/llvm/bin/amdgpu-arch | head -n 1)"
+export TORCH_ROCM_ARCH_LIST="$GPU"
+export ROCM_TARGETS="$GPU"
+export PYTORCH_ROCM_ARCH="$GPU"
 
 
+ARGS="$@"
+
 install_prepare() {
     mkdir -p $MILABENCH_WORDIR
     cd $MILABENCH_WORDIR
 
     virtualenv $MILABENCH_WORDIR/env
 
-    git clone https://github.com/mila-iqia/milabench.git
+    if [ -z "${MILABENCH_SOURCE}" ]; then
+        if [ ! -d "$MILABENCH_WORDIR/milabench" ]; then
+            git clone https://github.com/mila-iqia/milabench.git -b rocm
+        fi
+        export MILABENCH_SOURCE="$MILABENCH_WORDIR/milabench"
+    fi
 
     . $MILABENCH_WORDIR/env/bin/activate
-    pip install -e $MILABENCH_WORDIR/milabench
+    pip install -e $MILABENCH_SOURCE
 
     #
     # Install milabench's benchmarks in their venv
     #
-    milabench install
+    # pip install torch --index-url https://download.pytorch.org/whl/rocm6.1
+    # milabench pin --variant rocm --from-scratch $ARGS 
+    milabench install $ARGS
 
     #
     # Override/add package to milabench venv here
@@ -36,35 +72,48 @@ install_prepare() {
     (
         . $BENCHMARK_VENV/bin/activate
 
+        pip install ninja
+
+        if [ -z "${MILABENCH_HF_TOKEN}" ]; then
+            echo "Missing token"
+        else
+            huggingface-cli login --token $MILABENCH_HF_TOKEN
+        fi
+
         #
         # Override/add package to the benchmark venv here
         #
         which pip
-        pip uninstall torch torchvision torchaudio
-        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
-        pip uninstall pynvml
 
-        # sudo apt-get install lld
         # https://github.com/ROCm/jax/releases/tag/rocm-jaxlib-v0.4.30
-        # does not really work
         pip install https://github.com/ROCm/jax/releases/download/rocm-jaxlib-v0.4.30/jaxlib-0.4.30+rocm611-cp310-cp310-manylinux2014_x86_64.whl
-        pip install https://github.com/ROCm/jax/archive/refs/tags/rocm-jaxlib-v0.4.30.tar.gz
+        pip install https://github.com/ROCm/jax/archive/refs/tags/rocm-jaxlib-v0.4.30.tar.gz
 
-        # 
-        FORCE_CUDA=1 pip install -U -v --no-build-isolation git+https://github.com/rusty1s/pytorch_cluster.git
-        FORCE_CUDA=1 pip install -U -v --no-build-isolation git+https://github.com/rusty1s/pytorch_scatter.git
-        FORCE_CUDA=1 pip install -U -v --no-build-isolation git+https://github.com/rusty1s/pytorch_sparse.git
+        pip uninstall torch_cluster torch_scatter torch_sparse -y
+        FORCE_ONLY_CUDA=1 pip install -U -v --use-pep517 --no-build-isolation git+https://github.com/rusty1s/pytorch_cluster.git
+        FORCE_ONLY_CUDA=1 pip install -U -v --use-pep517 --no-build-isolation git+https://github.com/rusty1s/pytorch_scatter.git
+        FORCE_ONLY_CUDA=1 pip install -U -v --use-pep517 --no-build-isolation git+https://github.com/rusty1s/pytorch_sparse.git
 
         # takes forever to compile
         # https://github.com/ROCm/xformers
-        pip install -v -U --no-build-isolation --no-deps git+https://github.com/ROCm/xformers.git@develop#egg=xformers
-        pip install -v -U --no-build-isolation --no-deps git+https://github.com/ROCm/flash-attention.git
+        pip uninstall xformers -y
+        pip install xformers --index-url https://download.pytorch.org/whl/rocm6.1
+        # pip install -v -U --no-build-isolation --no-deps git+https://github.com/ROCm/xformers.git@develop#egg=xformers
+        # pip install -v -U --no-build-isolation --no-deps git+https://github.com/facebookresearch/xformers.git
+        # pip install xformers -U --index-url https://download.pytorch.org/whl/rocm6.1
+
+        pip uninstall flash-attention -y
+        pip install -v -U --no-build-isolation --use-pep517 --no-deps git+https://github.com/ROCm/flash-attention.git
+        pip uninstall pynvml nvidia-ml-py -y
+
+        pip install einops
     )
 
+    pip uninstall pynvml nvidia-ml-py -y
     #
     #   Generate/download datasets, download models etc...
     #
-    milabench prepare
+    milabench prepare $ARGS
 }
 
 if [ ! -d "$MILABENCH_WORDIR" ]; then
@@ -74,11 +123,19 @@ else
     . $MILABENCH_WORDIR/env/bin/activate
 fi
 
-cd $MILABENCH_WORDIR
+(
+    . $BENCHMARK_VENV/bin/activate
+    pip install xformers --index-url https://download.pytorch.org/whl/rocm6.1
+)
+
+# milabench install $ARGS --system $MILABENCH_WORDIR/system.yaml
+
+# milabench prepare $ARGS --system $MILABENCH_WORDIR/system.yaml
 
 #
 #   Run the benchmarks
-milabench run "$@"
+milabench run $ARGS --system $MILABENCH_WORDIR/system.yaml
+
 
 #
 #   Display report
diff --git a/tests/config/argerror.yaml b/tests/config/argerror.yaml
index 49ad733cc..59041b72f 100644
--- a/tests/config/argerror.yaml
+++ b/tests/config/argerror.yaml
@@ -9,3 +9,5 @@ benchio:
     n: 1
   argv:
     --start: 0
+  tags:
+    - monogpu
\ No newline at end of file
diff --git a/tests/config/benchio.yaml b/tests/config/benchio.yaml
index f2c694e22..50c352ca8 100644
--- a/tests/config/benchio.yaml
+++ b/tests/config/benchio.yaml
@@ -4,4 +4,6 @@ benchio:
   weight: 2
   plan:
     method: njobs
-    n: 2
\ No newline at end of file
+    n: 2
+  tags:
+    - monogpu
\ No newline at end of file
diff --git a/tests/config/benchio_bad.yaml b/tests/config/benchio_bad.yaml
index ac0b2f820..51b15ac4b 100644
--- a/tests/config/benchio_bad.yaml
+++ b/tests/config/benchio_bad.yaml
@@ -8,4 +8,5 @@ benchio:
 
   argv:
     --bad: true
-
+  tags:
+    - monogpu
\ No newline at end of file
diff --git a/tests/config/scaling.yaml b/tests/config/scaling.yaml
index 664996f79..3f3b032e9 100644
--- a/tests/config/scaling.yaml
+++ b/tests/config/scaling.yaml
@@ -5,3 +5,5 @@ benchio:
     64: 12Go
     128: 24Go
     256: 48Go
+  tags:
+    - monogpu
\ No newline at end of file
diff --git a/tests/test_command_reg/test_command_reg_one_node.txt b/tests/test_command_reg/test_command_reg_one_node.txt
index 3a511bb65..af21f4cde 100644
--- a/tests/test_command_reg/test_command_reg_one_node.txt
+++ b/tests/test_command_reg/test_command_reg_one_node.txt
@@ -16,7 +16,7 @@ export MILABENCH_DIR_RUNS=$BASE/runs
 export MILABENCH_DIR_EXTRA=$BASE/extra/llm
 export MILABENCH_DIR_CACHE=$BASE/cache
 export OMP_NUM_THREADS=0
-export MILABENCH_CONFIG='{"system": {"arch": "cuda", "sshkey": null, "nodes": [{"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}], "self": {"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}}, "dirs": {"base": "$BASE", "venv": "$BASE/venv/torch", "data": "$BASE/data", "runs": "$BASE/runs", "extra": "$BASE/extra/llm", "cache": "$BASE/cache"}, "group": "llm", "install_group": "torch", "install_variant": "cuda", "run_name": "dev", "enabled": true, "capabilities": {"nodes": 1}, "max_duration": 800, "voir": {"options": {"stop": 30, "interval": "1s"}}, "validation": {"usage": {"gpu_load_threshold": 0.5, "gpu_mem_threshold": 0.5}}, "config_base": "$SRC/milabench/config", "config_file": "$SRC/milabench/config/standard.yaml", "definition": "$SRC/milabench/benchmarks/llama", "tags": ["inference", "llm", "monogpu", "nlp", "nobatch"], "plan": {"method": "per_gpu"}, "weight": 1.0, "name": "llama", "tag": ["llama"]}'
+export MILABENCH_CONFIG='{"system": {"arch": "cuda", "sshkey": null, "nodes": [{"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}], "self": {"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}}, "dirs": {"base": "$BASE", "venv": "$BASE/venv/torch", "data": "$BASE/data", "runs": "$BASE/runs", "extra": "$BASE/extra/llm", "cache": "$BASE/cache"}, "group": "llm", "install_group": "torch", "install_variant": "cuda", "run_name": "dev", "enabled": true, "capabilities": {"nodes": 1}, "max_duration": 3600, "voir": {"options": {"stop": 30, "interval": "1s"}}, "validation": {"usage": {"gpu_load_threshold": 0.5, "gpu_mem_threshold": 0.5}}, "config_base": "$SRC/milabench/config", "config_file": "$SRC/milabench/config/standard.yaml", "definition": "$SRC/milabench/benchmarks/llama", "tags": ["inference", "llm", "monogpu", "nlp", "nobatch"], "plan": {"method": "per_gpu"}, "weight": 1.0, "name": "llama", "tag": ["llama"]}'
 
 echo "---"
 echo "llama"
@@ -37,14 +37,14 @@ echo "---"
 echo "fp16"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
   wait
 )
 
@@ -52,14 +52,14 @@ echo "---"
 echo "bf16"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
   wait
 )
 
@@ -67,14 +67,14 @@ echo "---"
 echo "tf32"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
   wait
 )
 
@@ -82,14 +82,14 @@ echo "---"
 echo "fp32"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
   wait
 )
 
@@ -285,14 +285,14 @@ echo "---"
 echo "reformer"
 echo "========"
 time (
-  CUDA_VISIBLE_DEVICES=0 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=1 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=2 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=3 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=4 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=5 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=6 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=7 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
+  CUDA_VISIBLE_DEVICES=0 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=1 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=2 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=3 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=4 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=5 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=6 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=7 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
   wait
 )
 
@@ -415,14 +415,14 @@ echo "---"
 echo "llm-lora-single"
 echo "==============="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
@@ -430,7 +430,7 @@ echo "---"
 echo "llm-lora-ddp-gpus"
 echo "================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-gpus/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-gpus/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
@@ -438,7 +438,7 @@ echo "---"
 echo "llm-lora-ddp-nodes"
 echo "=================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
@@ -446,7 +446,7 @@ echo "---"
 echo "llm-lora-mp-gpus"
 echo "================"
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_lora.yaml epochs=1 output_dir=$BASE/extra/llm-lora-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ safetensors=true metric_logger.log_dir=$BASE/extra/llm-lora-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" batch_size=8 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_lora.yaml epochs=1 output_dir=$BASE/extra/llm-lora-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ safetensors=true metric_logger.log_dir=$BASE/extra/llm-lora-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" batch_size=8 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
@@ -454,7 +454,7 @@ echo "---"
 echo "llm-full-mp-gpus"
 echo "================"
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
@@ -462,7 +462,7 @@ echo "---"
 echo "llm-full-mp-nodes"
 echo "================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
@@ -470,14 +470,14 @@ echo "---"
 echo "dqn"
 echo "==="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
   wait
 )
 
@@ -485,14 +485,29 @@ echo "---"
 echo "ppo"
 echo "==="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  wait
+)
+
+echo "---"
+echo "pna"
+echo "==="
+time (
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
   wait
 )
 
@@ -500,14 +515,14 @@ echo "---"
 echo "dimenet"
 echo "======="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
   wait
 )
 
diff --git a/tests/test_command_reg/test_command_reg_two_nodes.txt b/tests/test_command_reg/test_command_reg_two_nodes.txt
index 3004505de..5e516e3f9 100644
--- a/tests/test_command_reg/test_command_reg_two_nodes.txt
+++ b/tests/test_command_reg/test_command_reg_two_nodes.txt
@@ -16,7 +16,7 @@ export MILABENCH_DIR_RUNS=$BASE/runs
 export MILABENCH_DIR_EXTRA=$BASE/extra/llm
 export MILABENCH_DIR_CACHE=$BASE/cache
 export OMP_NUM_THREADS=0
-export MILABENCH_CONFIG='{"system": {"arch": "cuda", "sshkey": null, "nodes": [{"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}, {"ip": "192.168.0.11", "main": false, "name": "1", "sshport": 22, "user": "username", "hostname": "192.168.0.11"}], "self": {"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}}, "dirs": {"base": "$BASE", "venv": "$BASE/venv/torch", "data": "$BASE/data", "runs": "$BASE/runs", "extra": "$BASE/extra/llm", "cache": "$BASE/cache"}, "group": "llm", "install_group": "torch", "install_variant": "cuda", "run_name": "dev", "enabled": true, "capabilities": {"nodes": 1}, "max_duration": 800, "voir": {"options": {"stop": 30, "interval": "1s"}}, "validation": {"usage": {"gpu_load_threshold": 0.5, "gpu_mem_threshold": 0.5}}, "config_base": "$SRC/milabench/config", "config_file": "$SRC/milabench/config/standard.yaml", "definition": "$SRC/milabench/benchmarks/llama", "tags": ["inference", "llm", "monogpu", "nlp", "nobatch"], "plan": {"method": "per_gpu"}, "weight": 1.0, "name": "llama", "tag": ["llama"]}'
+export MILABENCH_CONFIG='{"system": {"arch": "cuda", "sshkey": null, "nodes": [{"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}, {"ip": "192.168.0.11", "main": false, "name": "1", "sshport": 22, "user": "username", "hostname": "192.168.0.11"}], "self": {"ip": "127.0.0.1", "main": true, "name": "0", "sshport": 22, "user": "username", "hostname": "127.0.0.1"}}, "dirs": {"base": "$BASE", "venv": "$BASE/venv/torch", "data": "$BASE/data", "runs": "$BASE/runs", "extra": "$BASE/extra/llm", "cache": "$BASE/cache"}, "group": "llm", "install_group": "torch", "install_variant": "cuda", "run_name": "dev", "enabled": true, "capabilities": {"nodes": 1}, "max_duration": 3600, "voir": {"options": {"stop": 30, "interval": "1s"}}, "validation": {"usage": {"gpu_load_threshold": 0.5, "gpu_mem_threshold": 0.5}}, "config_base": "$SRC/milabench/config", "config_file": "$SRC/milabench/config/standard.yaml", "definition": "$SRC/milabench/benchmarks/llama", "tags": ["inference", "llm", "monogpu", "nlp", "nobatch"], "plan": {"method": "per_gpu"}, "weight": 1.0, "name": "llama", "tag": ["llama"]}'
 
 echo "---"
 echo "llama"
@@ -37,14 +37,14 @@ echo "---"
 echo "fp16"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 30 --repeat 90 --m 8192 --n 8192 --dtype fp16 &
   wait
 )
 
@@ -52,14 +52,14 @@ echo "---"
 echo "bf16"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype bf16 &
   wait
 )
 
@@ -67,14 +67,14 @@ echo "---"
 echo "tf32"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 --tf32 &
   wait
 )
 
@@ -82,14 +82,14 @@ echo "---"
 echo "fp32"
 echo "===="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/flops/main.py --number 10 --repeat 90 --m 8192 --n 8192 --dtype fp32 &
   wait
 )
 
@@ -285,14 +285,14 @@ echo "---"
 echo "reformer"
 echo "========"
 time (
-  CUDA_VISIBLE_DEVICES=0 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=1 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=2 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=3 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=4 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=5 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=6 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
-  CUDA_VISIBLE_DEVICES=7 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 64 &
+  CUDA_VISIBLE_DEVICES=0 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=1 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=2 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=3 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=4 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=5 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=6 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
+  CUDA_VISIBLE_DEVICES=7 python -m bench --precision tf32-fp16 --num-workers 8 --model Reformer --batch-size 32 &
   wait
 )
 
@@ -357,11 +357,11 @@ time (
   wait
 )
 
 echo "---"
 echo "diffusion-nodes"
 echo "==============="
 time (
   $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=16 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache &
   ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=16 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache &
   wait
 )
@@ -416,14 +416,14 @@ echo "---"
 echo "llm-lora-single"
 echo "==============="
 time (
-  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/llm/recipes/lora_finetune_single_device.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-single/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-single/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
@@ -431,16 +431,16 @@ echo "---"
 echo "llm-lora-ddp-gpus"
 echo "================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-gpus/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-gpus/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
 echo "---"
 echo "llm-lora-ddp-nodes"
 echo "=================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=0 --local-addr=127.0.0.1 --rdzv-conf=rank=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
-  ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=1 --local-addr=192.168.0.11 --rdzv-conf=rank=1 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 &
+  $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=0 --local-addr=127.0.0.1 --rdzv-conf=rank=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
+  ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=1 --local-addr=192.168.0.11 --rdzv-conf=rank=1 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml epochs=1 output_dir=$BASE/extra/llm-lora-ddp-nodes/output tokenizer.path=$BASE/data/llama3_8B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_8B/original checkpointer.output_dir=$BASE/data/llama3_8B/ metric_logger.log_dir=$BASE/extra/llm-lora-ddp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-8B" batch_size=8 gradient_accumulation_steps=8 device=cuda &
   wait
 )
 
@@ -448,7 +448,7 @@ echo "---"
 echo "llm-lora-mp-gpus"
 echo "================"
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_lora.yaml epochs=1 output_dir=$BASE/extra/llm-lora-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ safetensors=true metric_logger.log_dir=$BASE/extra/llm-lora-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" batch_size=8 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/lora_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_lora.yaml epochs=1 output_dir=$BASE/extra/llm-lora-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ safetensors=true metric_logger.log_dir=$BASE/extra/llm-lora-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" batch_size=8 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
@@ -456,16 +456,16 @@ echo "---"
 echo "llm-full-mp-gpus"
 echo "================"
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=1 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-gpus/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-gpus/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
 echo "---"
 echo "llm-full-mp-nodes"
 echo "================="
 time (
-  $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=0 --local-addr=127.0.0.1 --rdzv-conf=rank=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 &
-  ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=1 --local-addr=192.168.0.11 --rdzv-conf=rank=1 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 &
+  $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=0 --local-addr=127.0.0.1 --rdzv-conf=rank=0 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 device=cuda &
+  ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $BASE/venv/torch/bin/tune run --nnodes=2 --rdzv-backend=static --rdzv-endpoint=127.0.0.1:29400 --master-addr=127.0.0.1 --master-port=29400 --local-ranks-filter=0 --node-rank=1 --local-addr=192.168.0.11 --rdzv-conf=rank=1 --nproc-per-node=8 -- $SRC/milabench/benchmarks/llm/recipes/full_finetune_distributed.py --config $SRC/milabench/benchmarks/llm/configs/llama3_70B_full.yaml epochs=1 output_dir=$BASE/extra/llm-full-mp-nodes/output tokenizer.path=$BASE/data/llama3_70B/original/tokenizer.model checkpointer.checkpoint_dir=$BASE/data/llama3_70B checkpointer.output_dir=$BASE/data/llama3_70B/ metric_logger.log_dir=$BASE/extra/llm-full-mp-nodes/metrics repo_id="meta-llama/Meta-Llama-3.1-70B" safetensors=true batch_size=2 gradient_accumulation_steps=1 device=cuda &
   wait
 )
 
@@ -473,14 +473,14 @@ echo "---"
 echo "dqn"
 echo "==="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_size 131072 --buffer_batch_size 65536 --env_name CartPole-v1 --training_interval 10 &
   wait
 )
 
@@ -488,14 +488,29 @@ echo "---"
 echo "ppo"
 echo "==="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 2000000 &
+  wait
+)
+
+echo "---"
+echo "pna"
+echo "==="
+time (
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model PNA --num-samples 100000 --batch-size 4096 --num-workers 0 &
   wait
 )
 
@@ -503,14 +572,14 @@ echo "---"
 echo "dimenet"
 echo "======="
 time (
-  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
-  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 10000 --use3d &
+  CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
+  CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/geo_gnn/main.py --model DimeNet --num-samples 100000 --use3d --batch-size 16 --num-workers 0 &
   wait
 )
 
diff --git a/tests/test_scaler.py b/tests/test_scaler.py
index f00a89793..07cdb2ed2 100644
--- a/tests/test_scaler.py
+++ b/tests/test_scaler.py
@@ -76,28 +76,28 @@ def fakeexec(pack):
 
 def test_scaler_enabled(multipack, config):
     from milabench.system import system_global
-    import contextvars
-
-    ctx = contextvars.copy_context()
-
-    def update_ctx():
-        sizer = Sizer(
-            SizerOptions(
-                size=None,
-                autoscale=True,
-                multiple=8,
-            ),
-            config("scaling"),
-        )
-        sizer_global.set(sizer)
-        system = system_global.get()
-        gpu = system.setdefault("gpu", dict())
-        gpu["capacity"] = "41920 MiB"
-
-    ctx.run(update_ctx)
+    from milabench.system import apply_system
+
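+    # System override under test: a 41920 MiB GPU, batch sizes rounded to multiples of 8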
+    conf = {
+        "gpu": {
+            "capacity": "41920 MiB"
+        },
+        "options": {
+            "sizer": {
+                "multiple": 8
+            }
+        }
+    }
 
     for k, pack in multipack.packs.items():
-        assert ctx.run(lambda: fakeexec(pack)) == ["--batch_size", "232"]
+        # Sizer is only enabled when config is applied
+        assert fakeexec(pack) == []
+
+    with apply_system(conf):
+        for k, pack in multipack.packs.items():
+            assert fakeexec(pack) == ["--batch_size", "232"]
 
-        # Sizer is only enabled inside the context
+    for k, pack in multipack.packs.items():
+        # Sizer is only enabled when config is applied
         assert fakeexec(pack) == []
diff --git a/tests/test_summary/test_compare.txt b/tests/test_summary/test_compare.txt
index c4dd7f6dc..c3bb5bf0f 100644
--- a/tests/test_summary/test_compare.txt
+++ b/tests/test_summary/test_compare.txt
@@ -1,5 +1,5 @@
                                        |   rijubigo |   sedumoje
-                                       | 2023-03-24 | 2023-03-24
-bench                |          metric |   13:45:27 |   13:57:35
+                                       | 2024-08-23 | 2024-08-23
+bench                |          metric |   09:22:03 |   09:22:03
 ----------------------------------------------------------------
 benchio              |      train_rate |    8780.41 |    8286.03
diff --git a/tests/test_summary/test_report.txt b/tests/test_summary/test_report.txt
index b9f6ce02a..2f4d3fe4b 100644
--- a/tests/test_summary/test_report.txt
+++ b/tests/test_summary/test_report.txt
@@ -11,4 +11,4 @@ benchio |    0 |   4 |    0 |    7979.82 |   2.9% |  17.2% |         nan |    79
 Scores
 ------
 Failure rate:       0.00% (PASS)
-Score:            7979.82
+Score:            7980.82
diff --git a/tests/test_summary/test_report_folder_does_average.txt b/tests/test_summary/test_report_folder_does_average.txt
index 9fda7a9c2..8884a73a6 100644
--- a/tests/test_summary/test_report_folder_does_average.txt
+++ b/tests/test_summary/test_report_folder_does_average.txt
@@ -11,4 +11,4 @@ benchio |    0 |   6 |    0 |    7878.45 |   2.5% |  18.0% |       24456 |    78
 Scores
 ------
 Failure rate:       0.00% (PASS)
-Score:            7878.45
+Score:            7879.45
diff --git a/tests/test_system_matrix.py b/tests/test_system_matrix.py
new file mode 100644
index 000000000..ed5378815
--- /dev/null
+++ b/tests/test_system_matrix.py
@@ -0,0 +1,33 @@
+from milabench.system import multirun, build_system_config, enable_offline, option, apply_system, SizerOptions
+
+from milabench.testing import official_config
+
+
+def test_system_matrix():
+    with enable_offline(True):
+        sys = build_system_config(official_config("examples/system"))
+        
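+        # multirun() expands the system matrix into individual (name, config) runs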
+        n = 0
+        for name, conf in multirun():
+            print(name, conf)
+            n += 1
+
+        assert n == 39
+
+
+def test_apply_system_matrix():
+    with enable_offline(True):
+        sys = build_system_config(official_config("examples/system"))
+
+        for name, conf in multirun():
+            with apply_system(conf):
+                # Apply system worked and changed the config
+                for k, v in conf.items():
+                    assert option(k, lambda x: x) == v
+
+                assert SizerOptions().save == option("sizer.save", lambda x: x)
+
+
+if __name__ == "__main__":
+    test_apply_system_matrix()
diff --git a/tests/test_validation.py b/tests/test_validation.py
index d5f1007b8..9ed9000aa 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -76,39 +76,40 @@ def test_planning_layer_per_gpu_bad(replayfolder, monkeypatch):
 
 def test_memory_tracking(replayfolder, config, tmp_path):
-    import contextvars
-
-    from milabench.sizer import (
-        MemoryUsageExtractor,
-        Sizer,
-        SizerOptions,
-        sizer_global,
-        system_global,
-    )
-
-    ctx = contextvars.copy_context()
-
-    def update_ctx():
-        sizer = Sizer(
-            SizerOptions(
-                size=None,
-                autoscale=True,
-                multiple=8,
-            ),
-            config("scaling"),
+    import yaml
+    from milabench.system import apply_system
+    
+    conf = {
+        "gpu": {
+            "capacity": "41920 MiB"
+        },
+        "options": {
+            "sizer": {
+                "multiple": 8,
+                "autoscale": 1
+            }
+        }
+    }
+    
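+    # Same capacity override, this time with autoscaling enabled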
+    with apply_system(conf):
+        from milabench.sizer import (
+            MemoryUsageExtractor,
+            system_global,
         )
-        sizer_global.set(sizer)
-        system_global.set({"gpu": {"capacity": "41920 MiB"}})
-
-    ctx.run(update_ctx)
-    layer = ctx.run(lambda: MemoryUsageExtractor())
-
-    layer.filepath = f"{tmp_path}/dummy"
-
-    assert 123 not in layer.memory["benchio"]["model"]
-
-    ctx.run(lambda: replay_validation_scenario(replayfolder, layer, filename="usage"))
-
-    assert 123 in layer.memory["benchio"]["model"]
+        
+        layer = MemoryUsageExtractor()
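+        # Preload the on-disk scaling config so the extractor has a table to update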
+        with open(config("scaling"), "r") as sconf:
+            layer.memory = yaml.safe_load(sconf)
+            
+        layer.filepath = f"{tmp_path}/dummy"
+
+        print(system_global.get())
+        assert 123 not in layer.memory["benchio"]["model"]
+
+        replay_validation_scenario(replayfolder, layer, filename="usage")
+
+        assert 123 in layer.memory["benchio"]["model"]
 
 
 def test_exception_tracking(replayfolder, file_regression, capsys):
diff --git a/tests/test_capabilities.py b/tests/test_validation/test_capabilities.py
similarity index 100%
rename from tests/test_capabilities.py
rename to tests/test_validation/test_capabilities.py