From a7fe5bd57249d2fc0fc8533daac131af296cdd38 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 5 Nov 2024 17:05:35 -0800 Subject: [PATCH] Fix non-improved staggered dslash tests - long link should not be created --- lib/gauge_field.cpp | 2 ++ tests/host_reference/staggered_dslash_reference.cpp | 3 ++- tests/staggered_dslash_test_utils.h | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/gauge_field.cpp b/lib/gauge_field.cpp index 20cf7621ac..ed414239cd 100644 --- a/lib/gauge_field.cpp +++ b/lib/gauge_field.cpp @@ -81,6 +81,8 @@ namespace quda { errorQuda("Cannot request a 12/8 reconstruct type without SU(3) link type"); if (param.reconstruct == QUDA_RECONSTRUCT_10 && param.link_type != QUDA_ASQTAD_MOM_LINKS) errorQuda("10-reconstruction only supported with momentum links"); + if (param.nFace > x[0] || param.nFace > x[1] || param.nFace > x[2] || param.nFace > x[3]) + errorQuda("Halo depth %d is greater than local lattice x = {%d %d %d %d}", param.nFace, x[0], x[1], x[2], x[3]); nColor = param.nColor; nFace = param.nFace; diff --git a/tests/host_reference/staggered_dslash_reference.cpp b/tests/host_reference/staggered_dslash_reference.cpp index 212549dc38..c5846d1ca7 100644 --- a/tests/host_reference/staggered_dslash_reference.cpp +++ b/tests/host_reference/staggered_dslash_reference.cpp @@ -107,7 +107,7 @@ void staggeredDslashReference(real_t *res, const real_t *const *fatlink, const r } // 4-d volume } -void stag_dslash(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, +void stag_dslash(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link_, const ColorSpinorField &in, int oddBit, int daggerBit, QudaDslashType dslash_type) { // assert sPrecision and gPrecision must be the same @@ -131,6 +131,7 @@ void stag_dslash(ColorSpinorField &out, const GaugeField &fat_link, const GaugeF in.exchangeGhost(otherparity, nFace, daggerBit); + auto &long_link = dslash_type == QUDA_ASQTAD_DSLASH ? long_link_ : fat_link; void *qdp_fatlink[] = {fat_link.data(0), fat_link.data(1), fat_link.data(2), fat_link.data(3)}; void *qdp_longlink[] = {long_link.data(0), long_link.data(1), long_link.data(2), long_link.data(3)}; void *ghost_fatlink[] diff --git a/tests/staggered_dslash_test_utils.h b/tests/staggered_dslash_test_utils.h index d7e98e01d4..5280184613 100644 --- a/tests/staggered_dslash_test_utils.h +++ b/tests/staggered_dslash_test_utils.h @@ -275,7 +275,7 @@ struct StaggeredDslashTestWrapper { GaugeFieldParam cpuLongParam(gauge_param, qdp_longlink); cpuLongParam.order = QUDA_QDP_GAUGE_ORDER; cpuLongParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; - cpuLong = GaugeField(cpuLongParam); + if (dslash_type == QUDA_ASQTAD_DSLASH) cpuLong = GaugeField(cpuLongParam); // Override link reconstruct as appropriate for staggered or asqtad if (is_staggered(dslash_type)) {