Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#16364: split prefetch dependent config #17548

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions tt_metal/impl/dispatch/kernel_config/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ void DispatchKernel::GenerateStaticConfigs() {
static_config_.my_downstream_cb_sem_id = 0; // unused

static_config_.split_dispatch_page_preamble_size = 0; // unused
static_config_.split_prefetch = false; // split_prefetcher
dependent_config_.prefetch_h_noc_xy = 0; // unused prefetch noc_xy
dependent_config_.prefetch_h_local_downstream_sem_addr = 0; // unused prefetch_local_downstream_sem_addr
static_config_.prefetch_h_max_credits = 0; // unused prefetch_downstream_buffer_pages

static_config_.packed_write_max_unicast_sub_cmds =
Expand Down Expand Up @@ -94,7 +91,6 @@ void DispatchKernel::GenerateStaticConfigs() {
static_config_.my_downstream_cb_sem_id = 0; // Unused

static_config_.split_dispatch_page_preamble_size = 0;
static_config_.split_prefetch = true;
// TODO: why is this hard-coded to 1 CQ on Galaxy?
if (tt::Cluster::instance().is_galaxy_cluster()) {
static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(1);
Expand Down Expand Up @@ -144,9 +140,6 @@ void DispatchKernel::GenerateStaticConfigs() {
GetCoreType()); // Apparently unused

static_config_.split_dispatch_page_preamble_size = sizeof(dispatch_packet_header_t);
static_config_.split_prefetch = true;
dependent_config_.prefetch_h_noc_xy = 0;
dependent_config_.prefetch_h_local_downstream_sem_addr = 1;
static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs());

static_config_.packed_write_max_unicast_sub_cmds =
Expand Down Expand Up @@ -184,6 +177,19 @@ void DispatchKernel::GenerateDependentConfigs() {
dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id;
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id;

if (prefetch_kernel->GetStaticConfig().is_h_variant.value() &&
prefetch_kernel->GetStaticConfig().is_d_variant.value()) {
dependent_config_.split_prefetch = false;
dependent_config_.prefetch_h_noc_xy = 0;
dependent_config_.prefetch_h_local_downstream_sem_addr = 0;
} else {
dependent_config_.split_prefetch = true;
dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding(
prefetch_kernel->GetVirtualCore().x, prefetch_kernel->GetVirtualCore().y);
dependent_config_.prefetch_h_local_downstream_sem_addr =
prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value();
}

// Downstream
if (DispatchQueryManager::instance().dispatch_s_enabled()) {
TT_ASSERT(downstream_kernels_.size() == 1);
Expand Down Expand Up @@ -215,9 +221,10 @@ void DispatchKernel::GenerateDependentConfigs() {
// write to when resuming sending of commands post exec_buf stall.
TT_ASSERT(downstream_kernels_.size() == 1);
auto prefetch_h_kernel = dynamic_cast<PrefetchKernel*>(downstream_kernels_[0]);
TT_ASSERT(prefetch_h_kernel);
TT_ASSERT(prefetch_h_kernel && prefetch_h_kernel->GetStaticConfig().is_h_variant.value());
dependent_config_.downstream_logical_core = UNUSED_LOGICAL_CORE;
dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE;
dependent_config_.split_prefetch = true;
dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding(
prefetch_h_kernel->GetVirtualCore().x, prefetch_h_kernel->GetVirtualCore().y);
dependent_config_.prefetch_h_local_downstream_sem_addr =
Expand All @@ -233,6 +240,20 @@ void DispatchKernel::GenerateDependentConfigs() {
dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore();
dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id;
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id;

if (prefetch_kernel->GetStaticConfig().is_h_variant.value() &&
prefetch_kernel->GetStaticConfig().is_d_variant.value()) {
dependent_config_.split_prefetch = false;
dependent_config_.prefetch_h_noc_xy = 0;
dependent_config_.prefetch_h_local_downstream_sem_addr = 0;
} else {
dependent_config_.split_prefetch = true;
dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding(
prefetch_kernel->GetVirtualCore().x, prefetch_kernel->GetVirtualCore().y);
dependent_config_.prefetch_h_local_downstream_sem_addr =
prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value();
}

// Downstream, expect a MUX_D and a DISPATCH_S if enabled
auto dispatch_s_kernel = dynamic_cast<DispatchSKernel*>(downstream_kernels_[0]);
auto mux_kernel = dynamic_cast<MuxKernel*>(downstream_kernels_[0]);
Expand Down Expand Up @@ -285,7 +306,7 @@ void DispatchKernel::CreateKernel() {
dependent_config_.downstream_cb_sem_id.value(),

static_config_.split_dispatch_page_preamble_size.value(),
static_config_.split_prefetch.value(),
dependent_config_.split_prefetch.value(),
dependent_config_.prefetch_h_noc_xy.value(),
dependent_config_.prefetch_h_local_downstream_sem_addr.value(),
static_config_.prefetch_h_max_credits.value(),
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/impl/dispatch/kernel_config/dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ typedef struct dispatch_static_config {
std::optional<uint32_t> my_downstream_cb_sem_id;

std::optional<uint32_t> split_dispatch_page_preamble_size; // 14
std::optional<uint32_t> split_prefetch;
std::optional<uint32_t> prefetch_h_max_credits;
std::optional<uint32_t> prefetch_h_max_credits; // Used if split_prefetch is true

std::optional<uint32_t> packed_write_max_unicast_sub_cmds; // 19
std::optional<uint32_t> dispatch_s_sync_sem_base_addr;
Expand Down Expand Up @@ -50,8 +49,9 @@ typedef struct dispatch_dependent_config {
std::optional<uint32_t> downstream_cb_size; // Dependent
std::optional<uint32_t> downstream_cb_sem_id; // Dependant

std::optional<uint32_t> prefetch_h_noc_xy; // Dependent
std::optional<uint32_t> prefetch_h_local_downstream_sem_addr; // Dependent
std::optional<uint32_t> split_prefetch; // If upstream is NOT a prefetch_HD
std::optional<uint32_t> prefetch_h_noc_xy; // Dependent. Used if split_prefetch is true
std::optional<uint32_t> prefetch_h_local_downstream_sem_addr; // Dependent. Used if split_prefetch is true
} dispatch_dependent_config_t;

class DispatchKernel : public FDKernel {
Expand Down
Loading