Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated XLA Flags for SparseCore #1210

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions benchmarks/xla_flags_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
" --xla_enable_async_all_gather=true"
)

# Continuation Fusion (CF)for All Reduce Collectives
# Continuation Fusion (CF) for All Reduce Collectives
# Continuation Fusion is a form of parallelizing compute work with collectives.
CF_FOR_ALL_REDUCE = (
" --xla_tpu_enable_async_collective_fusion=true"
Expand All @@ -64,33 +64,56 @@
" --xla_tpu_enable_async_collective_fusion_multiple_steps=true"
)

# Enable SparseCore All Gather (SC AG).
# Either one of CF AG or SC AG can be enabled at a time.
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER = (
" --xla_sc_disable_megacore_partitioning=true"
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_all_gather_offload_tracing=true"

# Base Flags needed when enabling sparsecore offloading
ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS = (
" --xla_tpu_use_tc_device_shape_on_sc=true"
" --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
" --xla_sc_enable_instruction_fusion=false"
" --xla_sc_disjoint_spmem=false"
" --xla_sc_disable_megacore_partitioning=true"
" --2a886c8_chip_config_name=megachip_tccontrol"
" --xla_tpu_enable_sparse_core_collective_offload_all_reduce=false"
)

# Enable SparseCore All Gather (1D), Reduce Scatter (1D) and All Reduce (ND)
ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR = (
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false"
" --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false"

" --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
" --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true"
" --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true"

" --xla_tpu_enable_all_gather_offload_tracing=true"
" --xla_tpu_enable_reduce_scatter_offload_tracing=true"
" --xla_tpu_enable_all_reduce_offload_tracing=true"
) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore Reduce Scatter (SC RS)
# Either one of CF or SC can be enabled at a time.
ENABLE_SPARSECORE_OFFLOADING_FOR_REDUCE_SCATTER = (
" --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false"
" --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true"
" --xla_tpu_enable_reduce_scatter_offload_tracing=true"
) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore All Gather (SC AG).
# Either one of CF or SC can be enabled at a time.
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER = (
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
" --xla_tpu_enable_all_gather_offload_tracing=true"
) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore All Reduce (SC AR)
# Either one of CF AR or SC AR can be enabled at a time.
# Either one of CF or SC can be enabled at a time.
# This is useful for reducing the gradient reduction all-reduce time with
# overlapping with compute during that time.
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE = (
" --xla_sc_disable_megacore_partitioning=true"
" --xla_tpu_enable_all_reduce_offload_tracing=true"
" --xla_tpu_use_tc_device_shape_on_sc=true"
" --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false"
" --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true"
" --xla_sc_enable_instruction_fusion=false"
" --xla_sc_disjoint_spmem=false"
" --2a886c8_chip_config_name=megachip_tccontrol"
)
" --xla_tpu_enable_all_reduce_offload_tracing=true"
) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Better memory layout for all-reduce (AR).
LAYOUT_FOR_ALL_REDUCE_SCATTER = (
Expand Down
Loading