Skip to content

Commit

Permalink
Updated XLA Flags for SparseCore
Browse files Browse the repository at this point in the history
  • Loading branch information
Obliviour committed Jan 30, 2025
1 parent aab4ed7 commit 438697f
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions benchmarks/xla_flags_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
" --xla_enable_async_all_gather=true"
)

# Continuation Fusion (CF) for All Reduce Collectives
# Continuation Fusion is a form of parallelizing compute work with collectives.
CF_FOR_ALL_REDUCE = (
" --xla_tpu_enable_async_collective_fusion=true"
Expand All @@ -64,33 +64,56 @@
" --xla_tpu_enable_async_collective_fusion_multiple_steps=true"
)

# Base flags needed whenever any collective is offloaded to SparseCore.
# Every ENABLE_SPARSECORE_OFFLOADING_FOR_* constant below appends this bundle
# after its collective-specific flags, so only collective-agnostic settings
# belong here: per-collective offload/tracing flags appended here would
# override the earlier collective-specific values (later duplicate wins).
ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS = (
    " --xla_tpu_use_tc_device_shape_on_sc=true"
    " --xla_sc_enable_instruction_fusion=false"
    " --xla_sc_disjoint_spmem=false"
    " --xla_sc_disable_megacore_partitioning=true"
    " --2a886c8_chip_config_name=megachip_tccontrol"
)

# Enable SparseCore All Gather (1D), Reduce Scatter (1D) and All Reduce (ND).
# Turns off continuation fusion for the three collectives (CF and SC offload
# are mutually exclusive per the sibling constants), routes them to
# SparseCore, enables offload tracing, then appends the shared base flags.
ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR = "".join((
    # CF must be disabled for any collective offloaded to SparseCore.
    " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false",
    " --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false",
    " --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false",
    # Offload all three collectives to SparseCore.
    " --xla_tpu_enable_sparse_core_collective_offload_all_gather=true",
    " --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true",
    " --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true",
    # Tracing for the offloaded collectives.
    " --xla_tpu_enable_all_gather_offload_tracing=true",
    " --xla_tpu_enable_reduce_scatter_offload_tracing=true",
    " --xla_tpu_enable_all_reduce_offload_tracing=true",
)) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore Reduce Scatter (SC RS).
# Only one of CF or SC offload may be enabled at a time, so continuation
# fusion for reduce-scatter is switched off before offloading it.
ENABLE_SPARSECORE_OFFLOADING_FOR_REDUCE_SCATTER = "".join((
    " --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false",
    " --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true",
    " --xla_tpu_enable_reduce_scatter_offload_tracing=true",
)) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore All Gather (SC AG).
# Only one of CF or SC offload may be enabled at a time, so continuation
# fusion for all-gather is switched off before offloading it.
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER = "".join((
    " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false",
    " --xla_tpu_enable_sparse_core_collective_offload_all_gather=true",
    " --xla_tpu_enable_all_gather_offload_tracing=true",
)) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Enable SparseCore All Reduce (SC AR).
# Either one of CF or SC can be enabled at a time.
# This is useful for reducing the gradient reduction all-reduce time by
# overlapping it with compute.
# NOTE(review): the scraped diff interleaved the old (removed) body with the
# new one, leaving an unmatched ")"; reconstructed here to match the sibling
# SC RS / SC AG constants — fusion off, offload on, tracing on, plus the
# shared base flags (which now carry the xla_sc_* / device-shape settings).
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE = (
    " --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false"
    " --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true"
    " --xla_tpu_enable_all_reduce_offload_tracing=true"
) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS

# Better memory layout for all-reduce (AR).
LAYOUT_FOR_ALL_REDUCE_SCATTER = (
Expand Down

0 comments on commit 438697f

Please sign in to comment.