Skip to content

Commit

Permalink
This is the first of several CLs implementing automatic mesh selectio…
Browse files Browse the repository at this point in the history
…n for the auto-sharding pass. This CL adds a method to create multiple mesh shapes to try as part of this effort.

PiperOrigin-RevId: 526161397
  • Loading branch information
tensorflower-gardener committed Apr 21, 2023
1 parent 7d6e1c3 commit 4baee46
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cc_library(
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
Expand Down Expand Up @@ -123,6 +124,7 @@ cc_library(
"//tensorflow/compiler/xla/hlo/utils:hlo_sharding_util",
"//tensorflow/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
Expand Down Expand Up @@ -1959,5 +1960,51 @@ bool IsEntryComputationInputOrOutput(const HloModule* module,
}
return false;
}

// Recursively enumerates candidate mesh shapes of rank `num_mesh_dims` and
// appends each completed candidate to `all_shapes`.
//
// `current_shape` holds the leading dimensions chosen so far, and
// `num_devices` is the number of devices still to be distributed over the
// remaining dimensions. Every non-final dimension is a power of two that
// evenly divides the remaining device count; the final dimension receives
// whatever is left. As a result, the product of each emitted shape equals the
// device count passed to the outermost call.
//
// A completed shape is kept only if spmd::VectorGreaterThanOneElementCount()
// returns at most 2 for it.
void CreateDifferentMeshShapesToTryHelper(
    int64_t num_devices, size_t num_mesh_dims,
    std::vector<int64_t> current_shape,
    std::vector<std::vector<int64_t>>& all_shapes) {
  // A rank-0 mesh has no shapes. Without this guard, `num_mesh_dims - 1`
  // wraps around (it is unsigned) and the recursion would never terminate.
  if (num_mesh_dims == 0) {
    return;
  }

  // Base case: only the last dimension is left, so it takes all remaining
  // devices.
  if (current_shape.size() == num_mesh_dims - 1) {
    current_shape.push_back(num_devices);
    if (spmd::VectorGreaterThanOneElementCount(current_shape) <= 2) {
      all_shapes.push_back(current_shape);
    }
    return;
  }

  // Recursive case: try every power-of-two size for the next dimension.
  for (int64_t current_dim = 1; current_dim <= num_devices;
       current_dim *= 2) {
    // Skip sizes that do not evenly divide the remaining devices. Integer
    // division would otherwise silently drop devices and yield a mesh whose
    // product differs from the requested device count.
    if (num_devices % current_dim != 0) {
      continue;
    }
    std::vector<int64_t> new_shape(current_shape);
    new_shape.push_back(current_dim);
    CreateDifferentMeshShapesToTryHelper(num_devices / current_dim,
                                         num_mesh_dims, new_shape, all_shapes);
  }
}

std::vector<std::vector<int64_t>> CreateDifferentMeshShapesToTry(
const int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims) {
std::vector<std::vector<int64_t>> result;
CreateDifferentMeshShapesToTryHelper(num_devices, num_mesh_dims, {}, result);

if (symmetrical_mesh_dims) {
absl::flat_hash_set<absl::btree_multiset<int64_t>> dedup_result;
for (const auto& mesh_shape : result) {
dedup_result.insert(
absl::btree_multiset<int64_t>(mesh_shape.begin(), mesh_shape.end()));
}

result.clear();

for (const auto& mesh_shape_set : dedup_result) {
result.push_back(
std::vector<int64_t>(mesh_shape_set.begin(), mesh_shape_set.end()));
}
}

return result;
}

} // namespace spmd
} // namespace xla
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ bool OutputInputSameShapes(const HloInstruction* ins);

bool IsEntryComputationInputOrOutput(const HloModule* module,
const HloInstruction* ins);

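// Given a number of devices (`num_devices`), creates a list of different mesh
// shapes of a given rank (`num_mesh_dims`) to try, if the option to try
// multiple mesh shapes is enabled. If `symmetrical_mesh_dims` is true, mesh
// shapes that are permutations of one another are deduplicated and each
// returned shape lists its dimensions in ascending order.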
std::vector<std::vector<int64_t>> CreateDifferentMeshShapesToTry(
int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims);
} // namespace spmd
} // namespace xla

Expand Down

0 comments on commit 4baee46

Please sign in to comment.