
Feature/mrhs misc #1515

Merged 32 commits on Dec 5, 2024
Changes from 27 commits
Commits
8292b17
Initial support for register for staggered dslash kernel
maddyscientist Oct 11, 2024
e816fb1
Sanity check for QUDA_MAX_MULTI_RHS_TILE
maddyscientist Oct 11, 2024
eeea3b0
Possible WAR for ROCm performance issue with MRHS staggered kernels
maddyscientist Oct 11, 2024
467a493
FieldTmp now supports creating temporaries using a T::param_type
maddyscientist Oct 11, 2024
36fcf3d
DslashCoarse now uses getFieldTmp for its mma temporaries to avoid all…
maddyscientist Oct 11, 2024
db865e4
Small cleanup of staggered dslash
maddyscientist Oct 14, 2024
8f72560
Add default initialization to array::data
maddyscientist Oct 15, 2024
b2cebb0
Fix multi-gpu bug
maddyscientist Oct 15, 2024
704f6b2
Fixed a logic bug in the MR convergence check
weinbe2 Oct 14, 2024
57c631c
Revert "Add default initialization to array::data"
maddyscientist Oct 15, 2024
9ddd25b
n_src should be a member of DslashArg
maddyscientist Oct 16, 2024
01725e6
Fix fused pack + dslash kernels with n_src_tile > 1
maddyscientist Oct 17, 2024
bbd07a7
Merge branch 'feature/trlm_3d' of github.com:lattice/quda into featur…
maddyscientist Nov 18, 2024
77b6b59
Fix bug in dslash_test_utils.h
maddyscientist Nov 18, 2024
27339e4
nvc++ no longer needs to use constant memory args for dslash
maddyscientist Nov 18, 2024
936c2c4
Fix for nvc++ and remove unneeded target specific thread_array.h files
maddyscientist Nov 18, 2024
fd00d08
Remove WAR for nvc++ in reduce_helper which is no longer needed
maddyscientist Nov 19, 2024
fc5ac97
Add versioned CPM files to .gitignore
maddyscientist Nov 20, 2024
a10d6f3
Fix complex_quda.h to be C++20 compliant
maddyscientist Nov 20, 2024
c47e1f0
Add new variant of heterogeneous reductions: with some compilers, opt…
maddyscientist Nov 20, 2024
8b61db2
Update to use latest release of Eigen 3.4: this fixes some bugs with …
maddyscientist Nov 20, 2024
490930b
Apply ROCm perf WAR to Laplace operator
maddyscientist Nov 20, 2024
26c3e59
Fix nvc++ compiler warning
maddyscientist Nov 20, 2024
1af8c52
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-misc
maddyscientist Nov 22, 2024
c2547fc
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-misc
maddyscientist Nov 27, 2024
2b81309
Fix compiler warning introduced with 3debb29
maddyscientist Nov 27, 2024
0a3f608
Fix process divergence issues (could hang when autotuning) in generic…
maddyscientist Nov 27, 2024
22393c5
Add pragma unroll
maddyscientist Dec 3, 2024
8ffcdd5
Do not run BiCGStab(l) with staggered half precision since it is unst…
maddyscientist Dec 4, 2024
a860b1c
sentinal -> sentinel
maddyscientist Dec 5, 2024
c3fbf50
Apply review comments
maddyscientist Dec 5, 2024
b6f5edb
Apply clang format
maddyscientist Dec 5, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ include/jitify_options.hpp
.tags*
autom4te.cache/*
.vscode
cmake/CPM_*.cmake
19 changes: 8 additions & 11 deletions CMakeLists.txt
@@ -216,6 +216,11 @@ if(QUDA_MAX_MULTI_BLAS_N GREATER 32)
message(SEND_ERROR "Maximum QUDA_MAX_MULTI_BLAS_N is 32.")
endif()

set(QUDA_MAX_MULTI_RHS_TILE "1" CACHE STRING "maximum tile size for MRHS kernels")
if(QUDA_MAX_MULTI_RHS_TILE GREATER QUDA_MAX_MULTI_RHS)
message(SEND_ERROR "QUDA_MAX_MULTI_RHS_TILE is greater than QUDA_MAX_MULTI_RHS")
endif()

set(QUDA_PRECISION
"14"
CACHE STRING "which precisions to instantiate in QUDA (4-bit number - double, single, half, quarter)")
@@ -275,6 +280,7 @@ mark_as_advanced(QUDA_ALTERNATIVE_I_TO_F)

mark_as_advanced(QUDA_MAX_MULTI_BLAS_N)
mark_as_advanced(QUDA_MAX_MULTI_RHS)
mark_as_advanced(QUDA_MAX_MULTI_RHS_TILE)
mark_as_advanced(QUDA_PRECISION)
mark_as_advanced(QUDA_RECONSTRUCT)
mark_as_advanced(QUDA_CLOVER_CHOLESKY_PROMOTE)
@@ -420,21 +426,12 @@ if(QUDA_DOWNLOAD_EIGEN)
CPMAddPackage(
NAME Eigen
VERSION ${QUDA_EIGEN_VERSION}
URL https://gitlab.com/libeigen/eigen/-/archive/${QUDA_EIGEN_VERSION}/eigen-${QUDA_EIGEN_VERSION}.tar.bz2
URL_HASH SHA256=B4C198460EBA6F28D34894E3A5710998818515104D6E74E5CC331CE31E46E626
URL https://gitlab.com/libeigen/eigen/-/archive/e67c494cba7180066e73b9f6234d0b2129f1cdf5.tar.bz2
URL_HASH SHA256=98d244932291506b75c4ae7459af29b1112ea3d2f04660686a925d9ef6634583
DOWNLOAD_ONLY YES
SYSTEM YES)
target_include_directories(Eigen SYSTEM INTERFACE ${Eigen_SOURCE_DIR})
install(DIRECTORY ${Eigen_SOURCE_DIR}/Eigen TYPE INCLUDE)

# Eigen 3.4 needs to be patched on Neon with nvc++
if (${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
set(CMAKE_PATCH_EIGEN OFF CACHE BOOL "Internal use only; do not modify")
if (NOT CMAKE_PATCH_EIGEN)
execute_process(COMMAND patch -N "${Eigen_SOURCE_DIR}/Eigen/src/Core/arch/NEON/Complex.h" "${CMAKE_SOURCE_DIR}/cmake/eigen34_neon.diff")
set(CMAKE_PATCH_EIGEN ON CACHE BOOL "Internal use only; do not modify" FORCE)
endif()
endif()
else()
# fall back to using find_package
find_package(Eigen QUIET)
8 changes: 0 additions & 8 deletions cmake/eigen34_neon.diff

This file was deleted.

38 changes: 16 additions & 22 deletions include/complex_quda.h
@@ -360,23 +360,19 @@ struct complex
typedef ValueType value_type;

// Constructors
__host__ __device__ inline complex<ValueType>(const ValueType &re = ValueType(), const ValueType &im = ValueType())
__host__ __device__ inline complex(const ValueType &re = ValueType(), const ValueType &im = ValueType())
{
real(re);
imag(im);
}

template <class X>
__host__ __device__
inline complex<ValueType>(const complex<X> & z)
template <class X> __host__ __device__ inline complex(const complex<X> &z)
{
real(z.real());
imag(z.imag());
}

template <class X>
__host__ __device__
inline complex<ValueType>(const std::complex<X> & z)
template <class X> __host__ __device__ inline complex(const std::complex<X> &z)
{
real(z.real());
imag(z.imag());
@@ -436,12 +432,11 @@ struct complex
template <> struct complex<float> : public float2 {
public:
typedef float value_type;
complex<float>() = default;
constexpr complex<float>(const float &re, const float &im = float()) : float2 {re, im} { }
complex() = default;
constexpr complex(const float &re, const float &im = float()) : float2 {re, im} { }

template <typename X>
constexpr complex<float>(const std::complex<X> &z) :
float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
constexpr complex(const std::complex<X> &z) : float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
{
}

@@ -500,16 +495,15 @@ template <> struct complex<float> : public float2 {
template <> struct complex<double> : public double2 {
public:
typedef double value_type;
complex<double>() = default;
constexpr complex<double>(const double &re, const double &im = double()) : double2 {re, im} { }
complex() = default;
constexpr complex(const double &re, const double &im = double()) : double2 {re, im} { }

template <typename X>
constexpr complex<double>(const std::complex<X> &z) :
double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
constexpr complex(const std::complex<X> &z) : double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
{
}

template <typename T> __host__ __device__ inline complex<double> &operator=(const complex<T> &z)
template <typename T> __host__ __device__ inline complex &operator=(const complex<T> &z)
{
real(z.real());
imag(z.imag());
@@ -572,9 +566,9 @@ template <> struct complex<int8_t> : public char2 {
public:
typedef int8_t value_type;

complex<int8_t>() = default;
complex() = default;

constexpr complex<int8_t>(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }
constexpr complex(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }

__host__ __device__ inline complex<int8_t> &operator+=(const complex<int8_t> &z)
{
@@ -608,9 +602,9 @@ struct complex <short> : public short2
public:
typedef short value_type;

complex<short>() = default;
complex() = default;

constexpr complex<short>(const short &re, const short &im = short()) : short2 {re, im} { }
constexpr complex(const short &re, const short &im = short()) : short2 {re, im} { }

__host__ __device__ inline complex<short> &operator+=(const complex<short> &z)
{
@@ -644,9 +638,9 @@ struct complex <int> : public int2
public:
typedef int value_type;

complex<int>() = default;
complex() = default;

constexpr complex<int>(const int &re, const int &im = int()) : int2 {re, im} { }
constexpr complex(const int &re, const int &im = int()) : int2 {re, im} { }

__host__ __device__ inline complex<int> &operator+=(const complex<int> &z)
{
12 changes: 11 additions & 1 deletion include/dslash.h
@@ -67,6 +67,10 @@ namespace quda
if (arg.xpay) strcat(aux_base, ",xpay");
if (arg.dagger) strcat(aux_base, ",dagger");
setRHSstring(aux_base, in.size());
strcat(aux_base, ",n_rhs_tile=");
char tile_str[16];
i32toa(tile_str, Arg::n_src_tile);
strcat(aux_base, tile_str);
}

/**
@@ -329,7 +333,13 @@

Dslash(Arg &arg, cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
const ColorSpinorField &halo, const std::string &app_base = "") :
TunableKernel3D(in[0], halo.X(4), arg.nParity), arg(arg), out(out), in(in), halo(halo), nDimComms(4), dslashParam(arg)
TunableKernel3D(in[0], (halo.X(4) + Arg::n_src_tile - 1) / Arg::n_src_tile, arg.nParity),
arg(arg),
out(out),
in(in),
halo(halo),
nDimComms(4),
dslashParam(arg)
{
if (checkLocation(out, in) == QUDA_CPU_FIELD_LOCATION)
errorQuda("CPU Fields not supported in Dslash framework yet");
15 changes: 7 additions & 8 deletions include/dslash_helper.cuh
@@ -12,12 +12,7 @@
#include <kernel_helper.h>
#include <tune_quda.h>

#if defined(_NVHPC_CUDA)
#include <constant_kernel_arg.h>
constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::FALSE;
#else
constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE;
#endif

#include <kernel.h>

@@ -241,11 +236,12 @@ namespace quda
return true;
}

template <typename Float_, int nDim_> struct DslashArg {
template <typename Float_, int nDim_, int n_src_tile_ = 1> struct DslashArg {

using Float = Float_;
using real = typename mapper<Float>::type;
static constexpr int nDim = nDim_;
static constexpr int n_src_tile = n_src_tile_; // how many RHS per thread

const int parity; // only use this for single parity fields
const int nParity; // number of parities we're working on
@@ -269,6 +265,7 @@ namespace quda
int threadDimMapLower[4];
int threadDimMapUpper[4];

int_fastdiv n_src;
int_fastdiv Ls;

// these are set with symmetric preconditioned twisted-mass dagger
@@ -327,6 +324,7 @@ namespace quda
exterior_threads(0),
threadDimMapLower {},
threadDimMapUpper {},
n_src(in.size()),
Ls(halo.X(4) / in.size()),
twist_a(0.0),
twist_b(0.0),
@@ -650,8 +648,9 @@ namespace quda
Arg arg;

dslash_functor_arg(const Arg &arg, unsigned int threads_x) :
kernel_param(dim3(threads_x, arg.dc.Ls, arg.nParity)),
arg(arg) { }
kernel_param(dim3(threads_x, (arg.dc.Ls + Arg::n_src_tile - 1) / Arg::n_src_tile, arg.nParity)), arg(arg)
{
}
};

/**
4 changes: 0 additions & 4 deletions include/eigen_helper.h
@@ -5,10 +5,6 @@
#define EIGEN_USE_BLAS
#endif

#if defined(__NVCOMPILER) // WAR for nvc++ until we update to latest Eigen
#define EIGEN_DONT_VECTORIZE
#endif

#include <math.h>

// hide annoying warning
27 changes: 26 additions & 1 deletion include/field_cache.h
@@ -15,7 +15,7 @@ namespace quda {
*/
template <typename T>
struct FieldKey {
std::string volume; /** volume kstring */
std::string volume; /** volume string */
std::string aux; /** auxiliary string */

FieldKey() = default;
@@ -78,6 +78,18 @@ namespace quda {
*/
FieldTmp(const FieldKey<T> &key, const typename T::param_type &param);

/**
@brief Create a field temporary that corresponds to the field
constructed from the param struct. If a matching field is
present in the cache, it will be popped from the cache. If no
such temporary exists a temporary will be allocated.
@param[in] param Parameter structure used to allocate
the temporary
*/
FieldTmp(typename T::param_type param);

/**
@brief Copy constructor is deleted to prevent accidental cache
bloat
@@ -111,6 +123,18 @@
*/
template <typename T> auto getFieldTmp(const T &a) { return FieldTmp<T>(a); }

/**
@brief Get a field temporary that is identical to the field
instance argument. If a matching field is present in the cache,
it will be popped from the cache. If no such temporary exists, a
temporary will be allocated. When the destructor for the
FieldTmp is called, e.g., the returned object goes out of scope,
the temporary will be pushed onto the cache.

@param[in] a Field we wish to create a matching temporary for
*/
template <typename T> auto getFieldTmp(const typename T::param_type &param) { return FieldTmp<T>(param); }

/**
@brief Get a vector of field temporaries that are identical to
the vector instance argument. If enough matching fields are
@@ -130,4 +154,5 @@ namespace quda {
for (auto i = 0u; i < a.size(); i++) tmp.push_back(std::move(getFieldTmp(a[i])));
return tmp;
}

}