Skip to content

Commit

Permalink
compiler: fix dtype for mpi routines
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 18, 2025
1 parent 9f62b65 commit aef841b
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 28 deletions.
6 changes: 3 additions & 3 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype,
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype,
flatten, generator, is_integer, split)
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
CompositeObject, CustomDimension)
Expand Down Expand Up @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None):
entry.sizes = (c_int*len(shape))(*shape)

# Allocate the send/recv buffers
size = reduce(mul, shape)*dtype_len(self.target.dtype)
ctype = dtype_to_ctype(f.dtype)
ctype, c_scale = dtype_alloc_ctype(f.dtype)
size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype)
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)

Expand Down
4 changes: 2 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def __setstate__(self, state):
self._lib.name = soname

self._allocator = default_allocator(
'%s.%s.%s' % (self._compiler.name, self._language, self._platform)
'%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform)
)


Expand Down Expand Up @@ -1404,7 +1404,7 @@ def parse_kwargs(**kwargs):

# `allocator`
kwargs['allocator'] = default_allocator(
'%s.%s.%s' % (kwargs['compiler'].name,
'%s.%s.%s' % (kwargs['compiler'].__class__.__name__,
kwargs['language'],
kwargs['platform'])
)
Expand Down
3 changes: 2 additions & 1 deletion devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

from sympy import S
import numpy as np

from devito.finite_differences import IndexDerivative
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
Expand Down Expand Up @@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
# NOTE: created before recurring so that we ultimately get a sound ordering
try:
s = reusables.pop()
assert s.dtype is dtype
assert np.can_cast(s.dtype, dtype)
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=dtype)
Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
import numpy as np

from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit
from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
int2, int3, int4, ctypes_vector_mapper)

Expand Down Expand Up @@ -64,7 +64,7 @@ class CustomType(ReservedWord):
cls = type(v.upper(), (Cast,), {'_base_typ': v})
globals()[cls.__name__] = cls

clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls})
clsp = type('%sP' % v.upper(), (Cast,), {'base': cls})
globals()[clsp.__name__] = clsp


Expand All @@ -75,7 +75,7 @@ def no_dtype(kwargs):
def cast_mapper(arg):
try:
assert len(arg) == 2 and arg[1] == '*'
return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw))
return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw))
except TypeError:
return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw))

Expand Down
23 changes: 10 additions & 13 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,17 @@ def __process_dtype__(cls, dtype):
if isinstance(dtype, str):
return dtype
dtype = ctypes_vector_mapper.get(dtype, dtype)

try:
dtype = ctypes_to_cstr(dtype_to_ctype(dtype))
except:
pass
ctype = dtype_to_ctype(dtype)
except TypeError:
ctype = dtype

return dtype
try:
cstr = ctypes_to_cstr(ctype)
except TypeError:
cstr = ctype
return cstr

def _hashable_content(self):
return super()._hashable_content() + (self._stars,)
Expand Down Expand Up @@ -768,14 +773,6 @@ def __str__(self):
__repr__ = __str__


# *** Casting

class CastStar:

def __new__(cls, base, dtype=None, ase=''):
return Cast(base, dtype=dtype, stars='*')


# Some other utility objects
Null = Macro('NULL')

Expand All @@ -789,7 +786,7 @@ def __new__(cls, intype, stars=None, **kwargs):
stars = stars or ''
argument = Keyword(f'{intype}{stars}')
newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs)
newobj.intype = intype
newobj.intype = Cast.__process_dtype__(intype)
newobj.stars = stars

return newobj
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def normalize_args(args):
for k, v in args.items():
try:
retval[k] = sympify(v, strict=True)
except SympifyError:
except (TypeError, SympifyError):
continue

return retval
Expand Down
3 changes: 2 additions & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def _print_Abs(self, expr):
# Unary function, single argument
arg = expr.args[0]
# AOMPCC errors with abs, always use fabs
if isinstance(self.compiler, AOMPCompiler):
if isinstance(self.compiler, AOMPCompiler) and \
not np.issubdtype(self._prec(expr), np.integer):
return "fabs(%s)" % self._print(arg)
func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}'
return f"{self._ns}{func}({self._print(arg)})"
Expand Down
6 changes: 5 additions & 1 deletion devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# NOTE: the following is inspired by pyopencl.cltypes

mapper = {
"half": np.float16,
"int": np.int32,
"float": np.float32,
"double": np.float64
Expand Down Expand Up @@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype):
np.int32: 'MPI_INT',
np.float32: 'MPI_FLOAT',
np.int64: 'MPI_LONG',
np.float64: 'MPI_DOUBLE'
np.float64: 'MPI_DOUBLE',
np.float16: 'MPI_UNSIGNED_SHORT'
}[dtype]


Expand Down Expand Up @@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p):

ctypes_vector_mapper = {}
for base_name, base_dtype in mapper.items():
if base_dtype is np.float16:
continue
base_ctype = dtype_to_ctype(base_dtype)

for count in counts:
Expand Down
6 changes: 5 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,17 +792,21 @@ def _halo_exchange(self):
# Gather send data
data = self._data_in_region(OWNED, d, i)
sendbuf = np.ascontiguousarray(data)
if self.dtype == np.float16:
sendbuf = sendbuf.view(np.uint16)

# Setup recv buffer
shape = self._data_in_region(HALO, d, i.flip()).shape
recvbuf = np.ndarray(shape=shape, dtype=self.dtype)
if self.dtype == np.float16:
recvbuf = recvbuf.view(np.uint16)

# Communication
comm.Sendrecv(sendbuf, dest=dest, recvbuf=recvbuf, source=source)

# Scatter received data
if recvbuf is not None and source != MPI.PROC_NULL:
self._data_in_region(HALO, d, i.flip())[:] = recvbuf
self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype)

self._is_halo_dirty = False

Expand Down
4 changes: 2 additions & 2 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@ def test_equation(self, pickle):

eq = Eq(f, f+1, implicit_dims=xs)

pkl_eq = pickle0.dumps(eq)
new_eq = pickle0.loads(pkl_eq)
pkl_eq = pickle.dumps(eq)
new_eq = pickle.loads(pkl_eq)

assert new_eq.lhs.name == f.name
assert str(new_eq.rhs) == 'f(x) + 1'
Expand Down

0 comments on commit aef841b

Please sign in to comment.