Skip to content

Commit

Permalink
compiler: fix internal language specific types and cast
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 31, 2024
1 parent abcbd6b commit 7cee7fb
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 17 deletions.
3 changes: 2 additions & 1 deletion devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def version(self):
@property
def _complex_ctype(self):
"""
Type definition for complex numbers. THese two cases cover 99% of the cases since
Type definition for complex numbers. These two cases cover 99% of the cases since
- Hip is now using std::complex
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
- Sycl supports std::complex
Expand Down Expand Up @@ -996,6 +996,7 @@ def __new_with__(self, **kwargs):
'nvc++': NvidiaCompiler,
'nvidia': NvidiaCompiler,
'cuda': CudaCompiler,
'nvcc': CudaCompiler,
'osx': ClangCompiler,
'intel': OneapiCompiler,
'icx': OneapiCompiler,
Expand Down
25 changes: 14 additions & 11 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@ def minimize_symbols(iet):
return iet, {}


_complex_lib = {'cuda': 'thrust/complex.h'}


@iet_pass
def complex_include(iet, language, compiler):
"""
Expand All @@ -205,22 +202,28 @@ def complex_include(iet, language, compiler):
if not np.issubdtype(max_dtype, np.complexfloating):
return iet, {}

lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)
is_cuda = language == 'cuda'
if is_cuda:
lib = ('thrust/complex.h',)
else:
lib = ('complex' if compiler._cpp else 'complex.h',)

headers = {}

# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
# For (cpp), need to define constant _Complex_I and missing mix-type
# std::complex arithmetic
if compiler._cpp:
namespace = 'thrust' if is_cuda else 'std'
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
# Constant I
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
# Mix arithmetic definitions
dest = compiler.get_jit_dir()
hfile = dest.joinpath('stdcomplex_arith.h')
if not hfile.is_file():
headers = {('_Complex_I', ('%s::complex<%s>(0.0, 1.0)' % (namespace, c_str)))}
# Mix arithmetic definitions, only for std, thrust has it defined
if not is_cuda:
dest = compiler.get_jit_dir()
hfile = dest.joinpath('stdcomplex_arith.h')
with open(str(hfile), 'w') as ff:
ff.write(str(_stdcomplex_defs))
lib += (str(hfile),)
lib += (str(hfile),)

return iet, {'includes': lib, 'headers': headers}

Expand Down
29 changes: 28 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sympy import Expr, Function, Number, Tuple, sympify
from sympy.core.decorators import call_highest_priority

from devito import configuration
from devito.finite_differences.elementary import Min, Max
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
float3, float4, double2, double3, double4, int2, int3,
Expand Down Expand Up @@ -811,6 +812,20 @@ class VOID(Cast):
_base_typ = 'void'


class CFLOAT(Cast):

@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('float')


class CDOUBLE(Cast):

@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('double')


class CHARP(CastStar):
base = CHAR

Expand All @@ -827,6 +842,14 @@ class USHORTP(CastStar):
base = USHORT


class CFLOATP(CastStar):
base = CFLOAT


class CDOUBLEP(CastStar):
base = CDOUBLE


cast_mapper = {
np.int8: CHAR,
np.uint8: UCHAR,
Expand All @@ -839,6 +862,8 @@ class USHORTP(CastStar):
np.float32: FLOAT, # noqa
float: DOUBLE, # noqa
np.float64: DOUBLE, # noqa
np.complex64: CFLOAT, # noqa
np.complex128: CDOUBLE, # noqa

(np.int8, '*'): CHARP,
(np.uint8, '*'): UCHARP,
Expand All @@ -849,7 +874,9 @@ class USHORTP(CastStar):
(np.int64, '*'): INTP, # noqa
(np.float32, '*'): FLOATP, # noqa
(float, '*'): DOUBLEP, # noqa
(np.float64, '*'): DOUBLEP # noqa
(np.float64, '*'): DOUBLEP, # noqa
(np.complex64, '*'): CFLOATP, # noqa
(np.complex128, '*'): CDOUBLEP, # noqa
}

for base_name in ['int', 'float', 'double']:
Expand Down
2 changes: 0 additions & 2 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def test_complex(self, dtype):
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
print(op)
op()

# Check against numpy
Expand Down
2 changes: 0 additions & 2 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,7 @@ def test_complex(self, dtype):
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
# print(op)
op()

# Check against numpy
Expand Down

0 comments on commit 7cee7fb

Please sign in to comment.