From 7cee7fbc917d353955f5f89007d328097dde3744 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 May 2024 09:58:54 -0400 Subject: [PATCH] compiler: fix internal language specific types and cast --- devito/arch/compiler.py | 3 ++- devito/passes/iet/misc.py | 25 ++++++++++++++----------- devito/symbolics/extended_sympy.py | 29 ++++++++++++++++++++++++++++- tests/test_gpu_common.py | 2 -- tests/test_operator.py | 2 -- 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index e056ff82b7a..5df3891074d 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -248,7 +248,7 @@ def version(self): @property def _complex_ctype(self): """ - Type definition for complex numbers. THese two cases cover 99% of the cases since + Type definition for complex numbers. These two cases cover 99% of the cases since - Hip is now using std::complex https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - Sycl supports std::complex @@ -996,6 +996,7 @@ def __new_with__(self, **kwargs): 'nvc++': NvidiaCompiler, 'nvidia': NvidiaCompiler, 'cuda': CudaCompiler, + 'nvcc': CudaCompiler, 'osx': ClangCompiler, 'intel': OneapiCompiler, 'icx': OneapiCompiler, diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 146dfe456a2..b81b8f63282 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -192,9 +192,6 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'thrust/complex.h'} - - @iet_pass def complex_include(iet, language, compiler): """ @@ -205,22 +202,28 @@ def complex_include(iet, language, compiler): if not np.issubdtype(max_dtype, np.complexfloating): return iet, {} - lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) + is_cuda = language == 'cuda' + if is_cuda: + lib = ('thrust/complex.h',) + else: + lib = ('complex' if compiler._cpp else 'complex.h',) headers = {} - # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise + # For (cpp), need to define constant _Complex_I and missing mix-type + # std::complex arithmetic if compiler._cpp: + namespace = 'thrust' if is_cuda else 'std' c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) # Constant I - headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} - # Mix arithmetic definitions - dest = compiler.get_jit_dir() - hfile = dest.joinpath('stdcomplex_arith.h') - if not hfile.is_file(): + headers = {('_Complex_I', ('%s::complex<%s>(0.0, 1.0)' % (namespace, c_str)))} + # Mix arithmetic definitions, only for std, thrust has it defined + if not is_cuda: + dest = compiler.get_jit_dir() + hfile = dest.joinpath('stdcomplex_arith.h') with open(str(hfile), 'w') as ff: ff.write(str(_stdcomplex_defs)) - lib += (str(hfile),) + lib += (str(hfile),) return iet, {'includes': lib, 'headers': headers} diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 7ed801d17a0..03fec7438af 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority +from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -811,6 +812,20 @@ class VOID(Cast): _base_typ = 'void' +class CFLOAT(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('float') + + +class CDOUBLE(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('double') + + class CHARP(CastStar): base = CHAR @@ -827,6 +842,14 @@ class USHORTP(CastStar): base = USHORT +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + cast_mapper = { np.int8: CHAR, np.uint8: UCHAR, @@ -839,6 +862,8 @@ class USHORTP(CastStar): np.float32: FLOAT, # noqa float: DOUBLE, # noqa np.float64: DOUBLE, # noqa + np.complex64: CFLOAT, # noqa + np.complex128: CDOUBLE, # noqa (np.int8, '*'): CHARP, (np.uint8, '*'): UCHARP, @@ -849,7 +874,9 @@ class USHORTP(CastStar): (np.int64, '*'): INTP, # noqa (np.float32, '*'): FLOATP, # noqa (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP # noqa + (np.float64, '*'): DOUBLEP, # noqa + (np.complex64, '*'): CFLOATP, # noqa + (np.complex128, '*'): CDOUBLEP, # noqa } for base_name in ['int', 'float', 'double']: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e3f53828ee1..d49d94074cc 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -74,9 +74,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index c1a88093796..283249aac16 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -647,9 +647,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - # print(op) op() # Check against numpy