diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch.cpp b/tt_metal/impl/dispatch/kernel_config/dispatch.cpp index 74251b35841..6235957f70e 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch.cpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch.cpp @@ -44,9 +44,6 @@ void DispatchKernel::GenerateStaticConfigs() { static_config_.my_downstream_cb_sem_id = 0; // unused static_config_.split_dispatch_page_preamble_size = 0; // unused - static_config_.split_prefetch = false; // split_prefetcher - dependent_config_.prefetch_h_noc_xy = 0; // unused prefetch noc_xy - dependent_config_.prefetch_h_local_downstream_sem_addr = 0; // unused prefetch_local_downstream_sem_addr static_config_.prefetch_h_max_credits = 0; // unused prefetch_downstream_buffer_pages static_config_.packed_write_max_unicast_sub_cmds = @@ -94,7 +91,6 @@ void DispatchKernel::GenerateStaticConfigs() { static_config_.my_downstream_cb_sem_id = 0; // Unused static_config_.split_dispatch_page_preamble_size = 0; - static_config_.split_prefetch = true; // TODO: why is this hard-coded to 1 CQ on Galaxy? if (tt::Cluster::instance().is_galaxy_cluster()) { static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(1); @@ -144,9 +140,6 @@ void DispatchKernel::GenerateStaticConfigs() { GetCoreType()); // Apparently unused static_config_.split_dispatch_page_preamble_size = sizeof(dispatch_packet_header_t); - static_config_.split_prefetch = true; - dependent_config_.prefetch_h_noc_xy = 0; - dependent_config_.prefetch_h_local_downstream_sem_addr = 1; static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()); static_config_.packed_write_max_unicast_sub_cmds = @@ -184,6 +177,19 @@ void DispatchKernel::GenerateDependentConfigs() { dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id; dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id; + if (prefetch_kernel->GetStaticConfig().is_h_variant.value() && + prefetch_kernel->GetStaticConfig().is_d_variant.value()) { + dependent_config_.split_prefetch = false; + dependent_config_.prefetch_h_noc_xy = 0; + dependent_config_.prefetch_h_local_downstream_sem_addr = 0; + } else { + dependent_config_.split_prefetch = true; + dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding( + prefetch_kernel->GetVirtualCore().x, prefetch_kernel->GetVirtualCore().y); + dependent_config_.prefetch_h_local_downstream_sem_addr = + prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value(); + } + // Downstream if (DispatchQueryManager::instance().dispatch_s_enabled()) { TT_ASSERT(downstream_kernels_.size() == 1); @@ -215,9 +221,10 @@ void DispatchKernel::GenerateDependentConfigs() { // write to when resuming sending of commands post exec_buf stall. TT_ASSERT(downstream_kernels_.size() == 1); auto prefetch_h_kernel = dynamic_cast(downstream_kernels_[0]); - TT_ASSERT(prefetch_h_kernel); + TT_ASSERT(prefetch_h_kernel && prefetch_h_kernel->GetStaticConfig().is_h_variant.value()); dependent_config_.downstream_logical_core = UNUSED_LOGICAL_CORE; dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE; + dependent_config_.split_prefetch = true; dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding( prefetch_h_kernel->GetVirtualCore().x, prefetch_h_kernel->GetVirtualCore().y); dependent_config_.prefetch_h_local_downstream_sem_addr = @@ -233,6 +240,20 @@ void DispatchKernel::GenerateDependentConfigs() { dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore(); dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id; dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id; + + if (prefetch_kernel->GetStaticConfig().is_h_variant.value() && + prefetch_kernel->GetStaticConfig().is_d_variant.value()) { + dependent_config_.split_prefetch = false; + dependent_config_.prefetch_h_noc_xy = 0; + dependent_config_.prefetch_h_local_downstream_sem_addr = 0; + } else { + dependent_config_.split_prefetch = true; + dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding( + prefetch_kernel->GetVirtualCore().x, prefetch_kernel->GetVirtualCore().y); + dependent_config_.prefetch_h_local_downstream_sem_addr = + prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value(); + } + // Downstream, expect a MUX_D and a DISPATCH_S if enabled auto dispatch_s_kernel = dynamic_cast(downstream_kernels_[0]); auto mux_kernel = dynamic_cast(downstream_kernels_[0]); @@ -285,7 +306,7 @@ void DispatchKernel::CreateKernel() { dependent_config_.downstream_cb_sem_id.value(), static_config_.split_dispatch_page_preamble_size.value(), - static_config_.split_prefetch.value(), + dependent_config_.split_prefetch.value(), dependent_config_.prefetch_h_noc_xy.value(), dependent_config_.prefetch_h_local_downstream_sem_addr.value(), static_config_.prefetch_h_max_credits.value(), diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch.hpp b/tt_metal/impl/dispatch/kernel_config/dispatch.hpp index c2a18bdfa3e..00195dae6e8 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch.hpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch.hpp @@ -18,8 +18,7 @@ typedef struct dispatch_static_config { std::optional my_downstream_cb_sem_id; std::optional split_dispatch_page_preamble_size; // 14 - std::optional split_prefetch; - std::optional prefetch_h_max_credits; + std::optional prefetch_h_max_credits; // Used if split_prefetch is true std::optional packed_write_max_unicast_sub_cmds; // 19 std::optional dispatch_s_sync_sem_base_addr; @@ -50,8 +49,9 @@ typedef struct dispatch_dependent_config { std::optional downstream_cb_size; // Dependent std::optional downstream_cb_sem_id; // Dependant - std::optional prefetch_h_noc_xy; // Dependent - std::optional prefetch_h_local_downstream_sem_addr; // Dependent + std::optional split_prefetch; // If upstream is NOT a prefetch_HD + std::optional prefetch_h_noc_xy; // Dependent. Used if split_prefetch is true + std::optional prefetch_h_local_downstream_sem_addr; // Dependent. Used if split_prefetch is true } dispatch_dependent_config_t; class DispatchKernel : public FDKernel {