Skip to content

Commit

Permalink
Merge pull request #1085 from AI-Hypercomputer:mattdavidow-pp-axis-order
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705693532
  • Loading branch information
maxtext authors committed Dec 13, 2024
2 parents 8e55ab1 + d87121b commit e4e0a4f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 27 deletions.
3 changes: 0 additions & 3 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'ex
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
Expand Down
25 changes: 2 additions & 23 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,34 +530,13 @@ def create_device_mesh(config, devices=None):

multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_pipeline_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_fsdp_transpose_parallelism,
config.dcn_sequence_parallelism,
config.dcn_tensor_parallelism,
config.dcn_expert_parallelism,
config.dcn_autoregressive_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_pipeline_parallelism,
config.ici_fsdp_parallelism,
config.ici_fsdp_transpose_parallelism,
config.ici_sequence_parallelism,
config.ici_tensor_parallelism,
config.ici_expert_parallelism,
config.ici_autoregressive_parallelism,
]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism, num_devices_per_slice, "ICI")

allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False

if multi_slice_env:
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
dcn_parallelism = fill_unspecified_mesh_axes(config.dcn_parallelism, num_slices, "DCN")
if is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
mesh = create_custom_device_mesh(ici_parallelism, dcn_parallelism, devices, config.custom_mesh)
else:
Expand Down
74 changes: 74 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def user_init(raw_keys):

raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
raw_keys = create_parallelisms_list(raw_keys)
raw_keys = set_and_validate_pipeline_config(raw_keys)

if raw_keys["dataset_type"] == "c4_mlperf":
Expand Down Expand Up @@ -470,6 +471,32 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
return updated_keys


def create_parallelisms_list(raw_keys):
  """Aggregate per-axis ICI and DCN parallelism degrees into ordered lists.

  Builds ``raw_keys["ici_parallelism"]`` and ``raw_keys["dcn_parallelism"]``
  from the individual ``ici_*_parallelism`` / ``dcn_*_parallelism`` entries,
  in the fixed mesh-axis order: data, pipeline, fsdp, fsdp_transpose,
  sequence, tensor, expert, autoregressive.

  Args:
    raw_keys: mutable config mapping holding the per-axis parallelism values.

  Returns:
    The same mapping, mutated in place with the two aggregate list entries.
  """
  # Single source of truth for the axis order so the ICI and DCN lists can
  # never drift out of alignment with each other.
  axis_order = (
      "data",
      "pipeline",
      "fsdp",
      "fsdp_transpose",
      "sequence",
      "tensor",
      "expert",
      "autoregressive",
  )
  raw_keys["ici_parallelism"] = [raw_keys[f"ici_{axis}_parallelism"] for axis in axis_order]
  raw_keys["dcn_parallelism"] = [raw_keys[f"dcn_{axis}_parallelism"] for axis in axis_order]
  return raw_keys


def validate_multiple_slices(raw_keys):
if (
math.fabs(
Expand All @@ -493,7 +520,54 @@ def validate_multiple_slices(raw_keys):

def set_and_validate_pipeline_config(raw_keys):
if using_pipeline_parallelism(raw_keys):

def modify_activation_embed_and_logits_batch(logical_axis_rules):
  """Reshard the embed/logits batch rule for pipeline parallelism.

  Finds the first rule named "activation_embed_and_logits_batch" and
  replaces it in place so the batch dimension is sharded by "stage" first.
  For pipeline parallelism the pre and post decoder layer tensors' batch
  dimension is sharded by stages. Microbatches are sharded by stage, so
  moving out of and into this sharding should be a local reshape. The
  "stage" needs to be listed first since the microbatch dimension is first
  before the reshape.

  Returns the (mutated) rules list.
  """
  target = "activation_embed_and_logits_batch"
  match_idx = next(
      (i for i, rule in enumerate(logical_axis_rules) if rule[0] == target),
      None,
  )
  if match_idx is not None:
    logical_axis_rules[match_idx] = [
        target,
        ["stage", "data", "fsdp", "fsdp_transpose", "expert"],
    ]
  return logical_axis_rules

def pipeline_first_axis(raw_keys):
  """Reorder the mesh so the pipeline ("stage") axis comes first.

  Rewrites ``ici_parallelism`` / ``dcn_parallelism`` with the pipeline
  degree leading, and sets ``mesh_axes`` / ``data_sharding`` to the
  matching stage-first axis names. Mutates and returns ``raw_keys``.
  """
  # We have seen better performance when axes used for DCN are earlier in
  # this list than ICI, see (b/339009148) for details.
  ordered_axes = (
      "pipeline",
      "data",
      "fsdp",
      "fsdp_transpose",
      "sequence",
      "tensor",
      "expert",
      "autoregressive",
  )
  raw_keys["ici_parallelism"] = [raw_keys[f"ici_{name}_parallelism"] for name in ordered_axes]
  raw_keys["dcn_parallelism"] = [raw_keys[f"dcn_{name}_parallelism"] for name in ordered_axes]

  # The mesh names the pipeline axis "stage"; data_sharding gets its own
  # independent copy of the list.
  stage_first_mesh = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
  raw_keys["mesh_axes"] = stage_first_mesh
  raw_keys["data_sharding"] = [list(stage_first_mesh)]
  return raw_keys

raw_keys["using_pipeline_parallelism"] = True
raw_keys["logical_axis_rules"] = modify_activation_embed_and_logits_batch(raw_keys["logical_axis_rules"])
raw_keys = pipeline_first_axis(raw_keys)
num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"])
if raw_keys["num_pipeline_repeats"] == -1:
num_pipeline_repeats, remainder = divmod(
Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
# TF allocates extraneous GPU memory when using TFDS data
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
tf.config.set_visible_devices([], 'GPU')
tf.config.set_visible_devices([], "GPU")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
Expand Down

0 comments on commit e4e0a4f

Please sign in to comment.