Skip to content

Commit

Permalink
[ADD] All missing workload specific batch sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
runame authored Jul 20, 2024
1 parent d7f5c5a commit dcc529c
Showing 1 changed file with 46 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,12 @@ def get_batch_size(workload_name):
on_vector_cluster = int(os.environ.get("RUNNING_ON_VECTOR_CLUSTER", default=0)) == 1

# Return the global batch size.
if workload_name == "criteo1tb":
if workload_name in {
"criteo1tb",
"criteo1tb_layernorm",
"criteo1tb_embed_init",
"criteo1tb_resnet",
}:
return 262_144
elif workload_name in {
"fastmri",
Expand All @@ -227,12 +232,27 @@ def get_batch_size(workload_name):
batch_size = 1024 // 32 if on_vector_cluster else 1024
elif workload_name in {
"librispeech_conformer",
"librispeech_conformer_attention_temperature",
"librispeech_conformer_layernorm",
"librispeech_conformer_gelu",
"librispeech_deepspeech",
"librispeech_deepspeech_tanh",
"librispeech_deepspeech_no_resnet",
"librispeech_deepspeech_norm_and_spec_aug",
}:
batch_size = 256
elif workload_name == "ogbg":
elif workload_name in {
"ogbg",
"ogbg_gelu",
"ogbg_silu",
"ogbg_model_size",
}:
batch_size = 512
elif workload_name == "wmt":
elif workload_name in {
"wmt",
"wmt_attention_temp",
"wmt_glu_tanh",
}:
batch_size = 128
elif workload_name == "mnist":
batch_size = 16
Expand Down Expand Up @@ -271,7 +291,12 @@ def get_eval_batch_size(workload_name):
on_vector_cluster = int(os.environ.get("RUNNING_ON_VECTOR_CLUSTER", default=0)) == 1

# Return the global eval batch size.
if workload_name == "criteo1tb":
if workload_name in {
"criteo1tb",
"criteo1tb_layernorm",
"criteo1tb_embed_init",
"criteo1tb_resnet",
}:
return 524288
elif workload_name in {
"fastmri",
Expand All @@ -296,12 +321,27 @@ def get_eval_batch_size(workload_name):
batch_size = 2048 // 512 if on_vector_cluster else 2048
elif workload_name in {
"librispeech_conformer",
"librispeech_conformer_attention_temperature",
"librispeech_conformer_layernorm",
"librispeech_conformer_gelu",
"librispeech_deepspeech",
"librispeech_deepspeech_tanh",
"librispeech_deepspeech_no_resnet",
"librispeech_deepspeech_norm_and_spec_aug",
}:
batch_size = 256
elif workload_name == "ogbg":
elif workload_name in {
"ogbg",
"ogbg_gelu",
"ogbg_silu",
"ogbg_model_size",
}:
batch_size = 32768
elif workload_name == "wmt":
elif workload_name in {
"wmt",
"wmt_attention_temp",
"wmt_glu_tanh",
}:
batch_size = 128
elif workload_name == "mnist":
batch_size = 10000
Expand Down

0 comments on commit dcc529c

Please sign in to comment.