Skip to content

Commit

Permalink
LightMetal - Use as_*() instead of static_cast directly for some unio…
Browse files Browse the repository at this point in the history
…n type handling in from_flatbuffer()

 - CoreSpec and KernelConfig are unions, do need to use static_cast
   but for now will prefer to hide that detail by using
   cmd->core_spec_as_CoreCoord() instead.
 - Downside here is that field names core_spec and kernel_config must be
   consistent across all users of these from_flatbuffer() funcs now.
  • Loading branch information
kmabeeTT committed Feb 2, 2025
1 parent 152b88b commit 36df59d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 51 deletions.
47 changes: 3 additions & 44 deletions tt_metal/impl/flatbuffer/program_types_from_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,8 @@

namespace tt::tt_metal {

std::variant<CoreCoord, CoreRange, CoreRangeSet> from_flatbuffer(
const flatbuffer::CoreSpec core_spec, const void* flatbuffer_union) {
switch (core_spec) {
case flatbuffer::CoreSpec::CoreCoord: {
auto core_coord = static_cast<const flatbuffer::CoreCoord*>(flatbuffer_union);
TT_FATAL(core_coord, "Invalid CoreCoord data");
return CoreCoord{core_coord->x(), core_coord->y()};
}
case flatbuffer::CoreSpec::CoreRange: {
auto core_range = static_cast<const flatbuffer::CoreRange*>(flatbuffer_union);
TT_FATAL(core_range, "Invalid CoreRange data");
return CoreRange{
{core_range->start()->x(), core_range->start()->y()}, {core_range->end()->x(), core_range->end()->y()}};
}
case flatbuffer::CoreSpec::CoreRangeSet: {
auto core_range_set = static_cast<const flatbuffer::CoreRangeSet*>(flatbuffer_union);
TT_FATAL(core_range_set, "Invalid CoreRangeSet data");
std::vector<CoreRange> ranges;
for (const auto range : *core_range_set->ranges()) {
ranges.emplace_back(
CoreCoord{range->start()->x(), range->start()->y()},
CoreCoord{range->end()->x(), range->end()->y()});
}
return CoreRangeSet{ranges};
}
default: throw std::runtime_error("Unhandled CoreSpec type in from_flatbuffer");
}
}

DataMovementConfig from_flatbuffer(const flatbuffer::DataMovementConfig* fb_config) {
TT_FATAL(fb_config, "Invalid DataMovementConfig data from flatbuffer.");
DataMovementConfig config;

// Extract processor, noc, and noc_mode
Expand All @@ -58,6 +30,7 @@ DataMovementConfig from_flatbuffer(const flatbuffer::DataMovementConfig* fb_conf
}

ComputeConfig from_flatbuffer(const flatbuffer::ComputeConfig* fb_config) {
TT_FATAL(fb_config, "Invalid ComputeConfig data from flatbuffer.");
ComputeConfig config;

// Extract math_fidelity and boolean flags
Expand Down Expand Up @@ -88,6 +61,7 @@ ComputeConfig from_flatbuffer(const flatbuffer::ComputeConfig* fb_config) {
}

EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config) {
TT_FATAL(fb_config, "Invalid EthernetConfig data from flatbuffer.");
EthernetConfig config;

// Extract eth_mode, noc, and processor
Expand All @@ -108,21 +82,6 @@ EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config) {
return config;
}

std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> from_flatbuffer(
const flatbuffer::KernelConfig config_type, const void* flatbuffer_union) {
switch (config_type) {
case flatbuffer::KernelConfig::DataMovementConfig:
return from_flatbuffer(static_cast<const flatbuffer::DataMovementConfig*>(flatbuffer_union));
case flatbuffer::KernelConfig::ComputeConfig:
return from_flatbuffer(static_cast<const flatbuffer::ComputeConfig*>(flatbuffer_union));
case flatbuffer::KernelConfig::EthernetConfig:
return from_flatbuffer(static_cast<const flatbuffer::EthernetConfig*>(flatbuffer_union));
case flatbuffer::KernelConfig::NONE:
throw std::runtime_error("Unhandled KernelConfig type in from_flatbuffer.");
}
TT_THROW("Unhandled KernelConfig type in from_flatbuffer.");
}

std::vector<SubDeviceId> from_flatbuffer(const flatbuffers::Vector<uint8_t>* fb_sub_device_ids) {
std::vector<SubDeviceId> sub_device_ids(fb_sub_device_ids ? fb_sub_device_ids->size() : 0);

Expand Down
51 changes: 44 additions & 7 deletions tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,53 @@

namespace tt::tt_metal {

std::variant<CoreCoord, CoreRange, CoreRangeSet> from_flatbuffer(
const flatbuffer::CoreSpec core_spec, const void* flatbuffer_union);

DataMovementConfig from_flatbuffer(const flatbuffer::DataMovementConfig* fb_config);
ComputeConfig from_flatbuffer(const flatbuffer::ComputeConfig* fb_config);
EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config);

std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> from_flatbuffer(
const flatbuffer::KernelConfig config_type, const void* flatbuffer_union);

std::vector<SubDeviceId> from_flatbuffer(const flatbuffers::Vector<uint8_t>* fb_sub_device_ids);

template <typename CommandType>
std::variant<CoreCoord, CoreRange, CoreRangeSet> core_spec_from_flatbuffer(const CommandType* cmd) {
switch (cmd->core_spec_type()) {
case flatbuffer::CoreSpec::CoreCoord: {
const auto* core_coord = cmd->core_spec_as_CoreCoord();
TT_FATAL(core_coord, "Invalid CoreCoord data from flatbuffer.");
return CoreCoord{core_coord->x(), core_coord->y()};
}
case flatbuffer::CoreSpec::CoreRange: {
const auto* core_range = cmd->core_spec_as_CoreRange();
TT_FATAL(core_range, "Invalid CoreRange data from flatbuffer.");
return CoreRange{
{core_range->start()->x(), core_range->start()->y()}, {core_range->end()->x(), core_range->end()->y()}};
}
case flatbuffer::CoreSpec::CoreRangeSet: {
const auto* core_range_set = cmd->core_spec_as_CoreRangeSet();
TT_FATAL(core_range_set, "Invalid CoreRangeSet data from flatbuffer.");

std::vector<CoreRange> ranges;
for (const auto* range : *core_range_set->ranges()) {
ranges.emplace_back(
CoreCoord{range->start()->x(), range->start()->y()},
CoreCoord{range->end()->x(), range->end()->y()});
}
return CoreRangeSet{ranges};
}
case flatbuffer::CoreSpec::NONE: TT_THROW("Invalid CoreSpec type. NONE cannot be processed.");
}
TT_THROW("Unhandled CoreSpec type in from_flatbuffer.");
}

template <typename CommandType>
std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> kernel_config_from_flatbuffer(const CommandType* cmd) {
switch (cmd->kernel_config_type()) {
case flatbuffer::KernelConfig::DataMovementConfig:
return from_flatbuffer(cmd->kernel_config_as_DataMovementConfig());
case flatbuffer::KernelConfig::ComputeConfig: return from_flatbuffer(cmd->kernel_config_as_ComputeConfig());
case flatbuffer::KernelConfig::EthernetConfig: return from_flatbuffer(cmd->kernel_config_as_EthernetConfig());
case flatbuffer::KernelConfig::NONE:
throw std::runtime_error("Unhandled KernelConfig type in from_flatbuffer.");
}
TT_THROW("Unhandled KernelConfig type in from_flatbuffer.");
}

} // namespace tt::tt_metal

0 comments on commit 36df59d

Please sign in to comment.