Skip to content

Commit

Permalink
Merge pull request #1523 from lattice/hotfix/nface
Browse files Browse the repository at this point in the history
Hotfix/nface
  • Loading branch information
maddyscientist authored Dec 18, 2024
2 parents 2a7fa0f + ab4d968 commit 7d21433
Show file tree
Hide file tree
Showing 17 changed files with 70 additions and 46 deletions.
11 changes: 10 additions & 1 deletion include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ namespace quda
/** Used to keep local track of allocated ghost_precision in createGhostZone */
mutable QudaPrecision ghost_precision_allocated = QUDA_INVALID_PRECISION;

/** Used to keep local track of nFace in createGhostZone */
mutable int nFace_allocated = 0;

/** Used to keep local track of spin_project in createGhostZone */
mutable bool spin_project_allocated = false;

int nColor = 0;
int nSpin = 0;
int nVec = 0;
Expand Down Expand Up @@ -771,9 +777,12 @@ namespace quda
/**
@brief Create a dummy field used for batched communication
@param[in] v Vector of fields we wish to batch together
@param[in] nFace The depth of the face in each dimension and direction
@param[in] nFace The depth of the face in each dimension and direction
@param[in] spin_project Whether we are spin projecting
@return Dummy (nDim+1)-dimensional field
*/
static FieldTmp<ColorSpinorField> create_comms_batch(cvector_ref<const ColorSpinorField> &v);
static FieldTmp<ColorSpinorField> create_comms_batch(cvector_ref<const ColorSpinorField> &v, int nFace = 1, bool spin_project = true);

/**
@brief Create a field that aliases this field's storage. The
Expand Down
11 changes: 11 additions & 0 deletions include/comm_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,18 @@ namespace quda
*/
void comm_broadcast(void *data, size_t nbytes, int root = 0);

/**
@brief Multi-process barrier that applies to the present
communicator
*/
void comm_barrier(void);

/**
@brief Multi-process barrier that is global regardless of the
present communicator
*/
void comm_barrier_global(void);

void comm_abort(int status);
void comm_abort_(int status);

Expand Down
20 changes: 7 additions & 13 deletions include/dirac_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2347,8 +2347,7 @@ namespace quda {
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->M(out, in);
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand All @@ -2369,8 +2368,7 @@ namespace quda {
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->MdagM(out, in);
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override
Expand Down Expand Up @@ -2421,8 +2419,7 @@ namespace quda {
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->MMdag(out, in);
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override
Expand All @@ -2448,8 +2445,7 @@ namespace quda {
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->Mdag(out, in);
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand Down Expand Up @@ -2496,7 +2492,7 @@ namespace quda {
@param[in,out] vec vector to which gamma5 is applied in place
*/
void applyGamma5(ColorSpinorField &vec) const
void applyGamma5(cvector_ref<ColorSpinorField> &vec) const
{
auto dirac_type = dirac->getDiracType();
auto pc_type = dirac->getMatPCType();
Expand Down Expand Up @@ -2573,10 +2569,8 @@ namespace quda {
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->M(out, in);
for (auto i = 0u; i < in.size(); i++) {
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
applyGamma5(out[i]);
}
if (shift != 0.0) blas::axpy(shift, in, out);
applyGamma5(out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand Down
2 changes: 1 addition & 1 deletion include/dslash_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ namespace quda
@return checkerboard space-time index
*/
template <QudaPCType pc_type, KernelType kernel_type, typename Arg, int nface_ = 1>
__host__ __device__ inline auto getCoords(const Arg &arg, int &idx, int s, int parity, int &dim)
__host__ __device__ __forceinline__ auto getCoords(const Arg &arg, int &idx, int s, int parity, int &dim)
{
constexpr auto nDim = Arg::nDim;
Coord<nDim> coord;
Expand Down
4 changes: 2 additions & 2 deletions include/kernels/random_init.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace quda {
struct rngArg : kernel_param<> {
int commCoord[QUDA_MAX_DIM];
int X[QUDA_MAX_DIM];
int X_global[QUDA_MAX_DIM];
uint64_t X_global[QUDA_MAX_DIM];
RNGState *state;
unsigned long long seed;
rngArg(RNGState *state, unsigned long long seed, const LatticeField &meta) :
Expand Down Expand Up @@ -46,7 +46,7 @@ namespace quda {
int x[4];
getCoords(x, id, arg.X, parity);
for (int i = 0; i < 4; i++) x[i] += arg.commCoord[i] * arg.X[i];
int idd = (((x[3] * arg.X_global[2] + x[2]) * arg.X_global[1]) + x[1]) * arg.X_global[0] + x[0];
auto idd = (((x[3] * arg.X_global[2] + x[2]) * arg.X_global[1]) + x[1]) * arg.X_global[0] + x[0];
random_init(arg.seed, idd, 0, arg.state[parity * arg.threads.x + id]);
}
};
Expand Down
19 changes: 14 additions & 5 deletions lib/color_spinor_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ namespace quda
alloc = std::exchange(src.alloc, false);
reference = std::exchange(src.reference, false);
ghost_precision_allocated = std::exchange(src.ghost_precision_allocated, QUDA_INVALID_PRECISION);
nFace_allocated = std::exchange(src.nFace_allocated, 0);
nColor = std::exchange(src.nColor, 0);
nSpin = std::exchange(src.nSpin, 0);
nVec = std::exchange(src.nVec, 0);
Expand Down Expand Up @@ -307,7 +308,8 @@ namespace quda
void ColorSpinorField::createGhostZone(int nFace, bool spin_project) const
{
if (ghost_precision == QUDA_INVALID_PRECISION) errorQuda("Invalid requested ghost precision");
if (ghost_precision_allocated == ghost_precision) return;
if (ghost_precision_allocated == ghost_precision && nFace_allocated == nFace &&
spin_project_allocated == spin_project) return;

bool is_fixed = (ghost_precision == QUDA_HALF_PRECISION || ghost_precision == QUDA_QUARTER_PRECISION);
int nSpinGhost = (nSpin == 4 && spin_project) ? 2 : nSpin;
Expand Down Expand Up @@ -400,7 +402,10 @@ namespace quda
dc.dims[3][1] = X[1];
dc.dims[3][2] = X[2];
}

spin_project_allocated = spin_project;
ghost_precision_allocated = ghost_precision;
nFace_allocated = nFace;
} // createGhostZone

void ColorSpinorField::zero() { qudaMemsetAsync(v, 0, bytes, device::get_default_stream()); }
Expand Down Expand Up @@ -819,7 +824,8 @@ namespace quda
if (siteSubset == QUDA_FULL_SITE_SUBSET) y[0] = savey0;
}

FieldTmp<ColorSpinorField> ColorSpinorField::create_comms_batch(cvector_ref<const ColorSpinorField> &v)
FieldTmp<ColorSpinorField> ColorSpinorField::create_comms_batch(cvector_ref<const ColorSpinorField> &v, int nFace,
bool spin_project)
{
// first create a dummy batched field
ColorSpinorParam param(v[0]);
Expand All @@ -837,9 +843,12 @@ namespace quda
FieldKey<ColorSpinorField> key;
key.volume = v.VolString();
key.aux = v.AuxString();
char aux[32];
strcpy(aux, ",ghost_batch=");
u32toa(aux + 13, v.size());
char aux[48];
strcpy(aux, ",nFace=");
u32toa(aux + 7, nFace);
strcpy(aux + 8, ",ghost_batch=");
u32toa(aux + 21, v.size());
if (spin_project && v.Nspin() > 1) strcat(aux, ",spin_project");
key.aux += aux;

return FieldTmp<ColorSpinorField>(key, param);
Expand Down
9 changes: 7 additions & 2 deletions lib/communicator_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ namespace quda
// used to store the size of the tunecache at the point of splitting
static size_t tune_cache_size = 0;

// destroy any message handles associated with the prior communicator
LatticeField::freeGhostBuffer();
ColorSpinorField::freeGhostBuffer();
FieldTmp<ColorSpinorField>::destroy();

auto search = communicator_stack.find(split_key);
if (search == communicator_stack.end()) {
communicator_stack.emplace(std::piecewise_construct, std::forward_as_tuple(split_key),
std::forward_as_tuple(get_default_communicator(), split_key.data()));
}

LatticeField::freeGhostBuffer(); // Destroy the (IPC) Comm buffers with the old communicator.

auto split_key_old = current_key;
current_key = split_key;

Expand Down Expand Up @@ -362,6 +365,8 @@ namespace quda

// Multi-process barrier on the communicator that is currently active.
void comm_barrier(void)
{
  auto &&active_comm = get_current_communicator();
  active_comm.comm_barrier();
}

// Multi-process barrier on the default communicator, independent of
// whichever communicator is currently active.
void comm_barrier_global(void)
{
  auto &&default_comm = get_default_communicator();
  default_comm.comm_barrier();
}

// Thin wrapper that forwards the status code to Communicator::comm_abort_.
// (The stray ';' that followed the function body was an empty declaration
// and has been dropped to keep -Wextra-semi / -Wpedantic builds clean.)
void comm_abort_(int status) { Communicator::comm_abort_(status); }

// Return the current communicator's commDim value for dimension `dim`.
int commDim(int dim)
{
  auto &&active_comm = get_current_communicator();
  return active_comm.commDim(dim);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/covariant_derivative.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ namespace quda

{
constexpr int nDim = 4;
auto halo = ColorSpinorField::create_comms_batch(in);
auto halo = ColorSpinorField::create_comms_batch(in, 1, false);
if (in.Nspin() == 4) {
CovDevArg<Float, 4, nColor, recon, nDim> arg(out, in, halo, U, mu, parity, dagger, comm_override);
CovDev<decltype(arg)> covDev(arg, out, in, halo);
Expand Down
2 changes: 1 addition & 1 deletion lib/dslash_coarse.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace quda {

if constexpr (is_enabled_multigrid()) {
// create a halo ndim+1 field for batched comms
auto halo = ColorSpinorField::create_comms_batch(inA);
auto halo = ColorSpinorField::create_comms_batch(inA, 1, false);

// Since use_mma = false, put a dummy 1 here for nVec
DslashCoarseLaunch<D, dagger, coarseColor, use_mma, 1> Dslash(out, inA, inB, halo, Y, X, kappa, parity, dslash,
Expand Down
2 changes: 1 addition & 1 deletion lib/dslash_coarse_mma.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace quda
{
if constexpr (is_enabled_multigrid()) {
// create a halo ndim+1 field for batched comms
auto halo = ColorSpinorField::create_comms_batch(inA);
auto halo = ColorSpinorField::create_comms_batch(inA, 1, false);

DslashCoarseLaunch<D, dagger, coarseColor, use_mma, nVec> Dslash(out, inA, inB, halo, Y, X, kappa, parity, dslash,
clover, commDim, halo_precision);
Expand Down
2 changes: 1 addition & 1 deletion lib/dslash_improved_staggered.cu
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ namespace quda
constexpr int nDim = 4;
constexpr bool improved = true;
constexpr QudaReconstructType recon_u = QUDA_RECONSTRUCT_NO;
auto halo = ColorSpinorField::create_comms_batch(in);
auto halo = ColorSpinorField::create_comms_batch(in, 3);
StaggeredArg<Float, nColor, nDim, recon_u, recon_l, improved> arg(out, in, halo, U, L, a, x, parity, dagger,
comm_override);
Staggered<decltype(arg)> staggered(arg, out, in, halo, L);
Expand Down
2 changes: 1 addition & 1 deletion lib/laplace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ namespace quda
const int *comm_override, TimeProfile &profile)
{
constexpr int nDim = 4;
auto halo = ColorSpinorField::create_comms_batch(in);
auto halo = ColorSpinorField::create_comms_batch(in, 1, false);
if (in.Nspin() == 1) {
constexpr int nSpin = 1;
LaplaceArg<Float, nSpin, nColor, nDim, recon> arg(out, in, halo, U, dir, a, b, x, parity, comm_override);
Expand Down
2 changes: 1 addition & 1 deletion lib/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ namespace quda {
memcpy(out.true_res, true_res.data(), true_res.size() * sizeof(double));
memcpy(out.true_res_hq, true_res_hq.data(), true_res_hq.size() * sizeof(double));

out.iter = in.iter;
out.iter = split_rank == 0 ? in.iter : 0;
comm_allreduce_int(out.iter);

out.ca_lambda_min = in.ca_lambda_min;
Expand Down
2 changes: 1 addition & 1 deletion lib/staggered_quark_smearing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ namespace quda
constexpr int nDim = 4;
constexpr int nSpin = 1;

auto halo = ColorSpinorField::create_comms_batch(in);
auto halo = ColorSpinorField::create_comms_batch(in, 3);
StaggeredQSmearArg<Float, nSpin, nColor, nDim, recon> arg(out, in, halo, U, t0, is_tslice_kernel, parity, dir,
dagger, comm_override);
StaggeredQSmear<decltype(arg)> staggered_qsmear(arg, out, in, halo);
Expand Down
5 changes: 5 additions & 0 deletions lib/tune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ namespace quda
*/
TuneParam tuneLaunch(Tunable &tunable, bool enabled, QudaVerbosity verbosity)
{
pushVerbosity(verbosity);
#ifdef LAUNCH_TIMER
launchTimer.TPSTART(QUDA_PROFILE_TOTAL);
launchTimer.TPSTART(QUDA_PROFILE_INIT);
Expand Down Expand Up @@ -940,6 +941,7 @@ namespace quda
Tunable::flops_global(Tunable::flops_global() + tunable.flops()); // increment flops counter
Tunable::bytes_global(Tunable::bytes_global() + tunable.bytes()); // increment bytes counter
}
popVerbosity();
return param_tuned;
}

Expand All @@ -962,6 +964,7 @@ namespace quda
Tunable::flops_global(Tunable::flops_global() + tunable.flops()); // increment flops counter
Tunable::bytes_global(Tunable::bytes_global() + tunable.bytes()); // increment bytes counter
}
popVerbosity();
return param_default;
} else if (!tuning) {

Expand Down Expand Up @@ -1179,6 +1182,8 @@ namespace quda
Tunable::flops_global(Tunable::flops_global() + tunable.flops()); // increment flops counter
Tunable::bytes_global(Tunable::bytes_global() + tunable.bytes()); // increment bytes counter
}

popVerbosity();
return param;
}

Expand Down
7 changes: 1 addition & 6 deletions tests/invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,7 @@ std::vector<std::array<double, 2>> solve(test_t param)
inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i];
}

quda::comm_allreduce_int(inv_param.iter);
inv_param.iter /= quda::comm_size() / num_sub_partition;
quda::comm_allreduce_sum(inv_param.gflops);
inv_param.gflops /= quda::comm_size() / num_sub_partition;
quda::comm_allreduce_max(inv_param.secs);
printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile);
if (inv_param.energy > 0) {
printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n",
Expand Down
14 changes: 5 additions & 9 deletions tests/staggered_invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,7 @@ std::vector<std::array<double, 2>> solve(test_t param)
inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i];
}

quda::comm_allreduce_int(inv_param.iter);
inv_param.iter /= comm_size() / num_sub_partition;
quda::comm_allreduce_sum(inv_param.gflops);
inv_param.gflops /= comm_size() / num_sub_partition;
quda::comm_allreduce_max(inv_param.secs);
printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile);
if (inv_param.energy > 0) {
printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n",
Expand Down Expand Up @@ -567,9 +562,10 @@ int main(int argc, char **argv)
if (quda::comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
result = RUN_ALL_TESTS();
} else {
solve(test_t {inv_type, solution_type, solve_type, prec_sloppy, multishift, solution_accumulator_pipeline,
schwarz_t {precon_schwarz_type, inv_multigrid ? QUDA_MG_INVERTER : precon_type, prec_precondition},
inv_param.residual_type});
for (int rep = 0; rep < nrepeat; rep++)
solve(test_t {inv_type, solution_type, solve_type, prec_sloppy, multishift, solution_accumulator_pipeline,
schwarz_t {precon_schwarz_type, inv_multigrid ? QUDA_MG_INVERTER : precon_type, prec_precondition},
inv_param.residual_type});
}

cleanup();
Expand Down

0 comments on commit 7d21433

Please sign in to comment.