From 4baee46d43b1dcdbf61cedbef8cd220ac8f8d3ff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Apr 2023 16:12:43 -0700 Subject: [PATCH] This is the first of several CLs implementing automatic mesh selection for the auto-sharding pass. This CL adds a method to create multiple mesh shapes to try as part of this effort. PiperOrigin-RevId: 526161397 --- .../xla/hlo/experimental/auto_sharding/BUILD | 2 + .../auto_sharding/auto_sharding_util.cc | 47 +++++++++++++++++++ .../auto_sharding/auto_sharding_util.h | 6 +++ 3 files changed, 55 insertions(+) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD index 4bb03af31a7202..1fbde7d6569ddc 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -123,6 +124,7 @@ cc_library( "//tensorflow/compiler/xla/hlo/utils:hlo_sharding_util", "//tensorflow/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 072b6f0b8cc404..a26cb911606f14 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -1959,5 +1960,51 @@ bool IsEntryComputationInputOrOutput(const HloModule* module, } return false; } + +void CreateDifferentMeshShapesToTryHelper( + int64_t num_devices, size_t num_mesh_dims, + std::vector current_shape, + std::vector>& all_shapes) { + if (current_shape.size() == num_mesh_dims - 1) { + current_shape.push_back(num_devices); + if (spmd::VectorGreaterThanOneElementCount(current_shape) <= 2) { + all_shapes.push_back(current_shape); + } + return; + } else { + int64_t current_dim = 1; + while (current_dim <= num_devices) { + std::vector new_shape(current_shape); + new_shape.push_back(current_dim); + CreateDifferentMeshShapesToTryHelper( + num_devices / current_dim, num_mesh_dims, new_shape, all_shapes); + current_dim *= 2; + } + } +} + +std::vector> CreateDifferentMeshShapesToTry( + const int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims) { + std::vector> result; + CreateDifferentMeshShapesToTryHelper(num_devices, num_mesh_dims, {}, result); + + if (symmetrical_mesh_dims) { + absl::flat_hash_set> dedup_result; + for (const auto& mesh_shape : result) { + dedup_result.insert( + absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); + } + + result.clear(); + + for (const auto& mesh_shape_set : dedup_result) { + result.push_back( + std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); + } + } + + return result; +} + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index ae430d24a41144..20a225ebe92247 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -582,6 +582,12 @@ bool OutputInputSameShapes(const HloInstruction* ins); bool IsEntryComputationInputOrOutput(const HloModule* module, const HloInstruction* ins); + +// Given a number of devices (`num_devices`), create a list different mesh +// shapes of a given rank (`num_mesh_dims`) to try, if the option to try +// multiple mesh shapes is enabled. +std::vector> CreateDifferentMeshShapesToTry( + int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims); } // namespace spmd } // namespace xla