diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 099613fcf..98193f01f 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -307,6 +307,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index ef0c11c0d..66fdc4ebb 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -307,6 +307,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 01cffc52e..ebc49d428 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -309,6 +309,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 530dd3acf..524bc20af 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -309,6 +309,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index be8b2f7e5..4f53afb56 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -322,6 +322,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 9ed09a615..60a1f784d 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -322,6 +322,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 57da48167..f8e87ec2a 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -324,6 +324,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index ef6e84c94..1de26417f 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -324,6 +324,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 11212c1a0..80a963600 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -165,6 +165,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 75a4abbef..32353e5b4 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -133,6 +133,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 4139ebcf6..cccb3c1b5 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -199,6 +199,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index b7d87924d..ec5c0b31c 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -152,6 +152,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 099613fcf..98193f01f 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -307,6 +307,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 01cffc52e..ebc49d428 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -309,6 +309,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 35cebba1f..f3b0aeed4 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -199,6 +199,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 45feb8645..fe9154934 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -152,6 +152,10 @@ def get_batch_size(workload_name): return 32 elif workload_name == 'imagenet_resnet': return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 elif workload_name == 'imagenet_vit': return 1024 elif workload_name == 'librispeech_conformer': diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 0448a46ed..5ef195db5 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -49,7 +49,8 @@ def get_batch_size(workload_name): Args: workload_name (str): Valid workload_name values are: "wmt", "ogbg", "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", - "librispeech_deepspeech", "librispeech_conformer". + "librispeech_deepspeech", "librispeech_conformer" or any of the + variants. Returns: int: batch_size Raises: