From 6a8ce257c53d3af42f98a0919d488822639afc8b Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 2 Aug 2023 10:15:54 -0400 Subject: [PATCH 01/58] api: add support for complex dtype --- devito/data/allocators.py | 8 +++++-- devito/finite_differences/differentiable.py | 2 +- devito/operator/operator.py | 6 +++++ devito/passes/clusters/factorization.py | 4 ++-- devito/passes/iet/misc.py | 26 ++++++++++++++++++++- devito/symbolics/inspection.py | 5 ++++ devito/tools/dtypes_lowering.py | 4 +++- devito/types/basic.py | 23 +++++++++++++++--- 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index aff28ef108..c887e85de0 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -92,8 +92,12 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - datasize = int(reduce(mul, shape)) - ctype = dtype_to_ctype(dtype) + # For complex number, allocate double the size of its real/imaginary part + alloc_dtype = dtype(0).real.__class__ + c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 + + datasize = int(reduce(mul, shape) * c_scale) + ctype = dtype_to_ctype(alloc_dtype) # Add padding, if any try: diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 8a32df7289..234de306c7 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -73,7 +73,7 @@ def grid(self): @cached_property def dtype(self): - dtypes = {f.dtype for f in self.find(Indexed)} - {None} + dtypes = {f.dtype for f in self._functions} - {None} return infer_dtype(dtypes) @cached_property diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 1cf99625ca..631a5c0a2c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -477,6 +477,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) + + # 
Complex header if needed. Needs to be done specialization + # as some specific cases requires complex to be loaded first + complex_include(graph) + + # Specialize graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 8007d3fee2..041a6d38c8 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,6 +1,7 @@ from collections import defaultdict from sympy import Add, Mul, S, collect +from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import (BasicWrapperMixin, estimate_cost, reuse_if_untouched, @@ -195,8 +196,7 @@ def _collect_nested(expr, strategy): Recursion helper for `collect_nested`. """ # Return semantic (rebuilt expression, factorization candidates) - - if expr.is_Number: + if expr.kind is NumberKind: return expr, {'coeffs': expr} elif q_routine(expr): # E.g., a DefFunction diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 28e1cc4f7b..3a22890f74 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,9 @@ import cgen import numpy as np import sympy +import numpy as np +from devito import configuration from devito.finite_differences import Max, Min from devito.finite_differences.differentiable import SafeInv from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder, @@ -18,7 +20,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols'] + 'generate_macros', 'minimize_symbols', 'complex_include'] @iet_pass @@ -252,6 +254,28 @@ def minimize_symbols(iet): return iet, {} +@iet_pass +def complex_include(iet): + """ + Add headers for complex arithmetic + """ + if configuration['language'] == 'cuda': + lib = 'cuComplex.h' + elif configuration['language'] == 'hip': + lib = 'hip/hip_complex.h' + 
else: + lib = 'complex.h' + + functions = FindSymbols().visit(iet) + for f in functions: + try: + if np.issubdtype(f.dtype, np.complexfloating): + return iet, {'includes': (lib,)} + except TypeError: + pass + return iet, {} + + def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 18e2623764..d4273f8016 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,6 +3,8 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) +from sympy.core.operations import AssocOp +from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative @@ -168,6 +170,7 @@ def _(expr, estimate, seen): return 0, True +@_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) @_estimate_cost.register(ReservedWord) def _(expr, estimate, seen): @@ -190,6 +193,8 @@ def _(expr, estimate, seen): flops, flags = _estimate_cost.registry[object](expr, estimate, seen) if {S.One, S.NegativeOne}.intersection(expr.args): flops -= 1 + if ImaginaryUnit in expr.args: + flops *= 2 return flops, flags diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index b5b564a4d7..cb83147f14 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -133,6 +133,9 @@ def dtype_to_cstr(dtype): def dtype_to_ctype(dtype): """Translate numpy.dtype into a ctypes type.""" + if isinstance(dtype, CustomDtype): + return dtype + try: return ctypes_vector_mapper[dtype] except KeyError: @@ -232,7 +235,6 @@ def ctypes_to_cstr(ctype, toarray=None): retval = '%s[%d]' % (ctypes_to_cstr(ctype._type_, toarray), ctype._length_) elif ctype.__name__.startswith('c_'): name = ctype.__name__[2:] - # A primitive datatype # 
FIXME: Is there a better way of extracting the C typename ? # Here, we're following the ctypes convention that each basic type has diff --git a/devito/types/basic.py b/devito/types/basic.py index 4dcf1dad95..458f82ef4b 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -14,7 +14,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, + CustomDtype) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -436,7 +437,16 @@ def _C_name(self): @property def _C_ctype(self): - return dtype_to_ctype(self.dtype) + if isinstance(self.dtype, CustomDtype): + return self.dtype + elif np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + else: + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1547,7 +1557,14 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - return POINTER(dtype_to_ctype(self.dtype)) + if np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return POINTER(r) + else: + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype From 6461869ea4731bdd1fd5ab3fbffc59b5b08066a5 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:01:27 -0400 Subject: [PATCH 02/58] api: fix printer for complex dtype --- devito/finite_differences/differentiable.py | 4 +-- devito/passes/iet/misc.py | 1 - devito/symbolics/inspection.py | 1 - devito/symbolics/printer.py | 10 ++++++++ devito/types/basic.py | 2 +- tests/test_operator.py | 28 +++++++++++++++++---- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 234de306c7..922acf0f84 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -18,9 +18,7 @@ from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) -from devito.types import (Array, DimensionTuple, Evaluable, Indexed, - StencilDimension) -from devito.types.basic import AbstractFunction +from devito.types import Array, DimensionTuple, Evaluable, StencilDimension __all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', 'Weights'] diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 3a22890f74..cfa849fdfc 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,6 @@ import cgen import numpy as np import sympy -import numpy as np from devito import configuration from devito.finite_differences import Max, Min diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index d4273f8016..ac9e7850a5 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,7 +3,6 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) -from sympy.core.operations import AssocOp from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative diff --git 
a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 2c366a389e..e33216ca43 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -47,6 +47,10 @@ def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16] + def complex_prec(self, expr=None): + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return np.issubdtype(dtype, np.complexfloating) + def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): return "(%s)" % self._print(item) @@ -115,6 +119,8 @@ def _print_math_func(self, expr, nest=False, known=None): if self.single_prec(expr): cname = '%sf' % cname + if self.complex_prec(expr): + cname = 'c%s' % cname if nest and len(expr.args) > 2: args = ', '.join([self._print(expr.args[0]), @@ -290,8 +296,12 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) + if self.single_prec(): func_name = '%sf' % func_name + if self.complex_prec(): + func_name = 'c%s' % func_name + return '%s(%s)' % (func_name, self._print(*expr.args)) def _print_DefFunction(self, expr): diff --git a/devito/types/basic.py b/devito/types/basic.py index 458f82ef4b..2f3bd9a16c 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1559,7 +1559,7 @@ def _C_ctype(self): try: if np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return POINTER(r) diff --git a/tests/test_operator.py b/tests/test_operator.py index d5759c1c92..db962b7d6b 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -9,7 +9,7 @@ SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension, NODE, CELL, dimensions, configuration, TensorFunction, TensorTimeFunction, VectorFunction, 
VectorTimeFunction, - div, grad, switchconfig) + div, grad, switchconfig, exp) from devito import Inc, Le, Lt, Ge, Gt # noqa from devito.exceptions import InvalidOperator from devito.finite_differences.differentiable import diff2sympy @@ -640,6 +640,24 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq, opt='noop') + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestAllocation: @@ -724,10 +742,10 @@ def verify_parameters(self, parameters, expected): """ boilerplate = ['timers'] parameters = [p.name for p in parameters] - for exp in expected: - if exp not in parameters + boilerplate: - error("Missing parameter: %s" % exp) - assert exp in parameters + boilerplate + for expi in expected: + if expi not in parameters + boilerplate: + error("Missing parameter: %s" % expi) + assert expi in parameters + boilerplate extra = [p for p in parameters if p not in expected and p not in boilerplate] if len(extra) > 0: error("Redundant parameters: %s" % str(extra)) From e8d74dfea0fda367e107895b46313ddd5e0762e9 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:17:39 -0400 Subject: [PATCH 03/58] compiler: fix alias dtype with complex numbers --- devito/symbolics/inspection.py | 8 +++++++- devito/types/basic.py | 2 +- tests/test_gpu_common.py | 18 ++++++++++++++++++ tests/test_operator.py | 2 +- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 
ac9e7850a5..6c006c0820 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -305,4 +305,10 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass - return infer_dtype(dtypes) + dtype = infer_dtype(dtypes) + + # Promote if complex + if expr.has(ImaginaryUnit): + dtype = np.promote_types(dtype, np.complex64).type + + return dtype diff --git a/devito/types/basic.py b/devito/types/basic.py index 2f3bd9a16c..21b6ac589a 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -441,7 +441,7 @@ def _C_ctype(self): return self.dtype elif np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return r diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index a22ab93df4..2ee7a38de1 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -76,6 +76,24 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq) + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestPassesOptional: diff --git a/tests/test_operator.py b/tests/test_operator.py index db962b7d6b..5d975685ce 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -648,7 +648,7 @@ def test_complex(self): eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) # Currently wrong alias type - op = Operator(eq, opt='noop') + op = Operator(eq) 
op() # Check against numpy From f60e08bf746391abde447ed94e2ab0fc94429f36 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:25:51 -0400 Subject: [PATCH 04/58] api: move complex ctype to dtype lowering --- devito/operator/operator.py | 2 +- devito/passes/clusters/factorization.py | 3 +-- devito/passes/iet/misc.py | 24 +++++++++++++----------- devito/symbolics/printer.py | 3 +++ devito/tools/dtypes_lowering.py | 8 ++++++++ devito/types/basic.py | 23 +++-------------------- tests/test_gpu_common.py | 2 +- tests/test_operator.py | 1 + 8 files changed, 31 insertions(+), 35 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 631a5c0a2c..92cf1d6ef8 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -480,7 +480,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done specialization # as some specific cases requires complex to be loaded first - complex_include(graph) + complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 041a6d38c8..45d140a253 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,7 +1,6 @@ from collections import defaultdict from sympy import Add, Mul, S, collect -from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import (BasicWrapperMixin, estimate_cost, reuse_if_untouched, @@ -196,7 +195,7 @@ def _collect_nested(expr, strategy): Recursion helper for `collect_nested`. 
""" # Return semantic (rebuilt expression, factorization candidates) - if expr.kind is NumberKind: + if expr.is_Number: return expr, {'coeffs': expr} elif q_routine(expr): # E.g., a DefFunction diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index cfa849fdfc..bd413fe163 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -4,7 +4,6 @@ import numpy as np import sympy -from devito import configuration from devito.finite_differences import Max, Min from devito.finite_differences.differentiable import SafeInv from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder, @@ -253,25 +252,28 @@ def minimize_symbols(iet): return iet, {} +_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} + + @iet_pass -def complex_include(iet): +def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - if configuration['language'] == 'cuda': - lib = 'cuComplex.h' - elif configuration['language'] == 'hip': - lib = 'hip/hip_complex.h' - else: - lib = 'complex.h' + lib = _complex_lib.get(language, 'complex.h') - functions = FindSymbols().visit(iet) - for f in functions: + headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise + if compiler._cpp: + headers = {('_Complex_I', ('1.0fi'))} + + for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,)} + return iet, {'includes': (lib,), 'headers': headers} except TypeError: pass + return iet, {} diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index e33216ca43..f191056c48 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -245,6 +245,9 @@ def _print_Float(self, expr): return rv + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py 
index cb83147f14..43d1e02347 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -136,6 +136,14 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype + # Complex data + if np.issubdtype(dtype, np.complexfloating): + rtype = dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index 21b6ac589a..4dcf1dad95 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -14,8 +14,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, - CustomDtype) + frozendict, memoized_meth, sympy_mutex) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -437,16 +436,7 @@ def _C_name(self): @property def _C_ctype(self): - if isinstance(self.dtype, CustomDtype): - return self.dtype - elif np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - else: - return dtype_to_ctype(self.dtype) + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1557,14 +1547,7 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - if np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return POINTER(r) - else: - return POINTER(dtype_to_ctype(self.dtype)) + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 2ee7a38de1..464ca52125 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -7,7 +7,7 @@ from conftest import assert_structure from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, Dimension, MatrixSparseTimeFunction, SparseTimeFunction, - SubDimension, SubDomain, SubDomainSet, TimeFunction, + SubDimension, SubDomain, SubDomainSet, TimeFunction, exp, Operator, configuration, switchconfig, TensorTimeFunction, Buffer, assign) from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64 diff --git a/tests/test_operator.py b/tests/test_operator.py index 5d975685ce..9cdf34e313 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -655,6 +655,7 @@ def test_complex(self): dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) + print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 4ea805a3ce67dda7c1f1a12cb35248eabcf34e33 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 May 2024 13:00:56 -0400 Subject: [PATCH 05/58] compiler: generate std:complex for cpp compilers --- devito/ir/iet/visitors.py | 43 +++++++++++++++++++++++---------- devito/passes/iet/misc.py | 4 +-- devito/symbolics/printer.py | 8 ++++++ devito/tools/dtypes_lowering.py | 7 ++---- tests/test_gpu_common.py | 3 ++- tests/test_operator.py | 2 +- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index cd405833f5..2210815c3c 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,6 +10,7 @@ import ctypes import cgen as c +import numpy as np from sympy import IndexedBase from sympy.core.function import Application @@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' + def _complex_type(self, ctypestr, dtype): + # Not complex + 
try: + if not np.issubdtype(dtype, np.complexfloating): + return ctypestr + except TypeError: + return ctypestr + # Complex only supported for float and double + if ctypestr not in ('float', 'double'): + return ctypestr + if self._compiler._cpp: + return 'std::complex<%s>' % ctypestr + else: + return '%s _Complex' % ctypestr + def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. @@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = obj._C_typedata + strtype = self._complex_type(obj._C_typedata, obj.dtype) strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = ctypes_to_cstr(obj._C_ctype) + strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -376,10 +392,11 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if f.is_PointerArray: # lvalue - lvalue = c.Value(i._C_typedata, '**%s' % f.name) + lvalue = c.Value(cstr, '**%s' % f.name) # rvalue if isinstance(o.obj, ArrayObject): @@ -388,7 +405,7 @@ def visit_PointerCast(self, o): v = f._C_name else: assert False - rvalue = '(%s**) %s' % (i._C_typedata, v) + rvalue = '(%s**) %s' % (cstr, v) else: # lvalue @@ -399,10 +416,10 @@ def visit_PointerCast(self, o): if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in o.castshape) rshape = '(*)%s' % shape - lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: rshape = '*' - lvalue = c.Value(i._C_typedata, '*%s' % v) + lvalue = c.Value(cstr, '*%s' % v) if o.alignment and f._data_alignment: lvalue = c.AlignedAttribute(f._data_alignment, lvalue) @@ -415,14 +432,14 @@ def 
visit_PointerCast(self, o): else: assert False - rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v) + rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v) else: if isinstance(o.obj, Pointer): v = o.obj.name else: v = f._C_name - rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v) + rvalue = '(%s %s) %s' % (cstr, rshape, v) return c.Initializer(lvalue, rvalue) @@ -430,15 +447,15 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) - rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name, + rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, - '(*restrict %s)%s' % (a0.name, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) else: - rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name) + rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name) + lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) else: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index bd413fe163..efaf0d1e9b 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -260,12 +260,12 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex.h') + lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h') headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: - headers = {('_Complex_I', ('1.0fi'))} + headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} for f in FindSymbols().visit(iet): try: diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 
f191056c48..b06477bad3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -299,8 +299,16 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) +<<<<<<< HEAD if self.single_prec(): +======= + dtype = self.dtype + if np.issubdtype(dtype, np.complexfloating): + func_name = 'c%s' % func_name + dtype = self.dtype(0).real.dtype.type + if dtype == np.float32: +>>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 43d1e02347..2fd3175f76 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,10 +139,7 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r + return dtype_to_ctype(rtype) try: return ctypes_vector_mapper[dtype] @@ -219,7 +216,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None): +def ctypes_to_cstr(ctype, toarray=None, cpp=False): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 464ca52125..0705167b8b 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -2,6 +2,7 @@ import pytest import numpy as np +import sympy import scipy.sparse from conftest import assert_structure @@ -82,7 +83,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() diff 
--git a/tests/test_operator.py b/tests/test_operator.py index 9cdf34e313..61b117bcc6 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -646,7 +646,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() From f39c23c1ec99c3583ca451fb0e6f4fe7b57f02cb Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 12:33:30 -0400 Subject: [PATCH 06/58] compiler: add std::complex arithmetic defs for unsupported types --- devito/ir/iet/visitors.py | 3 ++- devito/passes/iet/misc.py | 33 +++++++++++++++++++++++++++++++-- devito/symbolics/printer.py | 10 +--------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2210815c3c..d32b6b9eaf 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -14,6 +14,7 @@ from sympy import IndexedBase from sympy.core.function import Application +from devito.parameters import configuration from devito.exceptions import CompilationError from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -177,7 +178,7 @@ class CGen(Visitor): def __init__(self, *args, compiler=None, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler + self._compiler = compiler or configuration['compiler'] # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index efaf0d1e9b..090c3e2d38 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -260,17 +260,26 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex' if compiler._cpp else 
'complex.h') + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + # Constant I headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + # Mix arithmetic definitions + dest = compiler.get_jit_dir() + hfile = dest.joinpath('stdcomplex_arith.h') + if not hfile.is_file(): + with open(str(hfile), 'w') as ff: + ff.write(str(_stdcomplex_defs)) + lib += (str(hfile),) for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,), 'headers': headers} + return iet, {'includes': lib, 'headers': headers} except TypeError: pass @@ -374,3 +383,23 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} + + +_stdcomplex_defs = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} +""" diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index b06477bad3..6e7f33bba2 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -45,7 +45,7 @@ def compiler(self): def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16] + return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype @@ -299,16 +299,8 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): 
func_name = str(expr.func) -<<<<<<< HEAD if self.single_prec(): -======= - dtype = self.dtype - if np.issubdtype(dtype, np.complexfloating): - func_name = 'c%s' % func_name - dtype = self.dtype(0).real.dtype.type - if dtype == np.float32: ->>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name From 61651a0ed71b93f0ac0464f5469a721689cc7a00 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 14:08:12 -0400 Subject: [PATCH 07/58] compiler: fix alias dtype with complex numbers --- devito/__init__.py | 7 ++++--- devito/arch/compiler.py | 24 +++++++++++++++++++-- devito/ir/iet/visitors.py | 34 +++++++++++++----------------- devito/operator/operator.py | 7 ++++--- devito/passes/iet/misc.py | 37 +++++++++++++++++++++++---------- devito/symbolics/inspection.py | 6 ++++-- devito/tools/dtypes_lowering.py | 12 ++++++++--- tests/test_gpu_common.py | 7 ++++--- tests/test_operator.py | 8 +++---- 9 files changed, 91 insertions(+), 51 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index b8d7621297..52918db94a 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -56,7 +56,8 @@ def reinit_compiler(val): """ Re-initialize the Compiler. 
""" - configuration['compiler'].__init__(suffix=configuration['compiler'].suffix, + configuration['compiler'].__init__(name=configuration['compiler'].name, + suffix=configuration['compiler'].suffix, mpi=configuration['mpi']) return val @@ -64,8 +65,8 @@ def reinit_compiler(val): # Setup target platform and compiler configuration.add('platform', 'cpu64', list(platform_registry), callback=lambda i: platform_registry[i]()) -configuration.add('compiler', 'custom', compiler_registry, - callback=lambda i: compiler_registry[i]()) +configuration.add('compiler', 'custom', list(compiler_registry), + callback=lambda i: compiler_registry[i](name=i)) # Setup language for shared-memory parallelism preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 58ee41c437..d405a4326b 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -183,6 +183,8 @@ def __init__(self): _cpp = False def __init__(self, **kwargs): + self._name = kwargs.pop('name', self.__class__.__name__) + super().__init__(**kwargs) self.__lookup_cmds__() @@ -226,13 +228,13 @@ def __new_with__(self, **kwargs): Create a new Compiler from an existing one, inherenting from it the flags that are not specified via ``kwargs``. """ - return self.__class__(suffix=kwargs.pop('suffix', self.suffix), + return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix), mpi=kwargs.pop('mpi', configuration['mpi']), **kwargs) @property def name(self): - return self.__class__.__name__ + return self._name @property def version(self): @@ -248,6 +250,20 @@ def version(self): return version + @property + def _complex_ctype(self): + """ + Type definition for complex numbers. 
THese two cases cover 99% of the cases since + - Hip is now using std::complex +https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings + - Sycl supports std::complex + - C's _Complex is part of C99 + """ + if self._cpp: + return lambda dtype: 'std::complex<%s>' % str(dtype) + else: + return lambda dtype: '%s _Complex' % str(dtype) + def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -716,6 +732,10 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' + @property + def _complex_ctype(self): + return lambda dtype: 'thrust::complex<%s>' % str(dtype) + class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index d32b6b9eaf..a93c2e1385 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,11 +10,10 @@ import ctypes import cgen as c -import numpy as np from sympy import IndexedBase from sympy.core.function import Application -from devito.parameters import configuration +from devito.parameters import configuration, switchconfig from devito.exceptions import CompilationError from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -190,20 +189,15 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' - def _complex_type(self, ctypestr, dtype): - # Not complex - try: - if not np.issubdtype(dtype, np.complexfloating): - return ctypestr - except TypeError: - return ctypestr - # Complex only supported for float and double - if ctypestr not in ('float', 'double'): - return ctypestr - if self._compiler._cpp: - return 'std::complex<%s>' % ctypestr - else: - return '%s _Complex' % ctypestr + @property + def compiler(self): + return self._compiler + + def visit(self, o, *args, **kwargs): + # Make sure the visitor always is within the generating compiler + # in case the configuration is accessed + with 
switchconfig(compiler=self.compiler.name): + return super().visit(o, *args, **kwargs) def _gen_struct_decl(self, obj, masked=()): """ @@ -260,10 +254,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = self._complex_type(obj._C_typedata, obj.dtype) + strtype = obj._C_typedata strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) + strtype = ctypes_to_cstr(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -393,7 +387,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if f.is_PointerArray: # lvalue @@ -448,7 +442,7 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 92cf1d6ef8..a170bcd318 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -478,8 +478,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # Complex header if needed. Needs to be done specialization - # as some specific cases requires complex to be loaded first + # Complex header if needed. 
Needs to be done before specialization + # as some specific cases require complex to be loaded first complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize @@ -1394,7 +1394,8 @@ def parse_kwargs(**kwargs): raise InvalidOperator("Illegal `compiler=%s`" % str(compiler)) kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'], language=kwargs['language'], - mpi=configuration['mpi']) + mpi=configuration['mpi'], + name=compiler) elif any([platform, language]): kwargs['compiler'] =\ configuration['compiler'].__new_with__(platform=kwargs['platform'], diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 090c3e2d38..c0e1e8f2b3 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -14,7 +14,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split +from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', @@ -252,7 +252,7 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} +_complex_lib = {'cuda': 'thrust/complex.h'} @iet_pass @@ -260,14 +260,20 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ + # Check if there is complex numbers that always take dtype precedence + max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) + if not np.issubdtype(max_dtype, np.complexfloating): + return iet, {} + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) # Constant I - headers = 
{('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} # Mix arithmetic definitions dest = compiler.get_jit_dir() hfile = dest.joinpath('stdcomplex_arith.h') @@ -276,14 +282,7 @@ def complex_include(iet, language, compiler): ff.write(str(_stdcomplex_defs)) lib += (str(hfile),) - for f in FindSymbols().visit(iet): - try: - if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': lib, 'headers': headers} - except TypeError: - pass - - return iet, {} + return iet, {'includes': lib, 'headers': headers} def remove_redundant_moddims(iet): @@ -393,8 +392,19 @@ def _rename_subdims(target, dimensions): return std::complex<_Tp>(b.real() * a, b.imag() * a); } +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + template std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ return std::complex<_Tp>(b.real() / a, b.imag() / a); } @@ -402,4 +412,9 @@ def _rename_subdims(target, dimensions): std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ return std::complex<_Tp>(b.real() + a, b.imag()); } + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} """ diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 6c006c0820..57aead89d2 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -305,10 +305,12 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass + dtype = infer_dtype(dtypes) - # Promote if complex - if expr.has(ImaginaryUnit): + # Promote if we missed complex number, i.e f 
+ I + is_im = np.issubdtype(dtype, np.complexfloating) + if expr.has(ImaginaryUnit) and not is_im: dtype = np.promote_types(dtype, np.complex64).type return dtype diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 2fd3175f76..117feaff0b 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,7 +139,12 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - return dtype_to_ctype(rtype) + from devito import configuration + make = configuration['compiler']._complex_ctype + ctname = make(dtype_to_cstr(rtype)) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r try: return ctypes_vector_mapper[dtype] @@ -216,7 +221,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None, cpp=False): +def ctypes_to_cstr(ctype, toarray=None): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ @@ -310,7 +315,8 @@ def infer_dtype(dtypes): # Resolve the vector types, if any dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes} - fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)} + fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or + np.issubdtype(i, np.complexfloating)} if len(fdtypes) > 1: return max(fdtypes, key=lambda i: np.dtype(i).itemsize) elif len(fdtypes) == 1: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 0705167b8b..45b864df05 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -77,15 +77,16 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = 
Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index 61b117bcc6..c1a8809379 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -640,22 +640,22 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + # print(op) op() # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 55759e2bdcbe0a0c8c14ccd5d06caa6959d8fdb7 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 May 2024 09:58:54 -0400 Subject: [PATCH 08/58] compiler: fix internal language specific types and cast wip --- devito/arch/compiler.py | 3 +- devito/operator/operator.py | 2 +- devito/passes/iet/__init__.py | 1 + devito/passes/iet/misc.py | 71 +----------------------------- devito/symbolics/extended_sympy.py | 29 +++++++++++- tests/test_gpu_common.py | 2 - tests/test_operator.py | 2 - 7 files changed, 33 insertions(+), 77 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index d405a4326b..15c49270f3 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -253,7 +253,7 @@ def version(self): @property def _complex_ctype(self): """ - Type definition 
for complex numbers. THese two cases cover 99% of the cases since + Type definition for complex numbers. These two cases cover 99% of the cases since - Hip is now using std::complex https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - Sycl supports std::complex @@ -1033,6 +1033,7 @@ def __contains__(self, k): 'nvc++': NvidiaCompiler, 'nvidia': NvidiaCompiler, 'cuda': CudaCompiler, + 'nvcc': CudaCompiler, 'osx': ClangCompiler, 'intel': OneapiCompiler, 'icx': OneapiCompiler, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index a170bcd318..a5d3cae070 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -480,7 +480,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done before specialization # as some specific cases require complex to be loaded first - complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) + include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index c09db00c9b..6b4ada0b73 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,3 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa +from .complex import * # noqa diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index c0e1e8f2b3..e49b13de72 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -18,7 +18,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols', 'complex_include'] + 'generate_macros', 'minimize_symbols'] @iet_pass @@ -252,39 +252,6 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'thrust/complex.h'} - - -@iet_pass -def complex_include(iet, language, 
compiler): - """ - Add headers for complex arithmetic - """ - # Check if there is complex numbers that always take dtype precedence - max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) - if not np.issubdtype(max_dtype, np.complexfloating): - return iet, {} - - lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) - - headers = {} - - # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise - if compiler._cpp: - c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) - # Constant I - headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} - # Mix arithmetic definitions - dest = compiler.get_jit_dir() - hfile = dest.joinpath('stdcomplex_arith.h') - if not hfile.is_file(): - with open(str(hfile), 'w') as ff: - ff.write(str(_stdcomplex_defs)) - lib += (str(hfile),) - - return iet, {'includes': lib, 'headers': headers} - - def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] @@ -382,39 +349,3 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} - - -_stdcomplex_defs = """ -#include - -template -std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ - _Tp denom = b.real() * b.real () + b.imag() * b.imag() - return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); -} - -template -std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() / a, b.imag() / a); -} - -template -std::complex<_Tp> operator + (const _Ti & 
a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} - -template -std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} -""" diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 4087bbc72c..72483e2004 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority +from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -819,6 +820,20 @@ class VOID(Cast): _base_typ = 'void' +class CFLOAT(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('float') + + +class CDOUBLE(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('double') + + class CHARP(CastStar): base = CHAR @@ -835,6 +850,14 @@ class USHORTP(CastStar): base = USHORT +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + cast_mapper = { np.int8: CHAR, np.uint8: UCHAR, @@ -847,6 +870,8 @@ class USHORTP(CastStar): np.float32: FLOAT, # noqa float: DOUBLE, # noqa np.float64: DOUBLE, # noqa + np.complex64: CFLOAT, # noqa + np.complex128: CDOUBLE, # noqa (np.int8, '*'): CHARP, (np.uint8, '*'): UCHARP, @@ -857,7 +882,9 @@ class USHORTP(CastStar): (np.int64, '*'): INTP, # noqa (np.float32, '*'): FLOATP, # noqa (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP # noqa + (np.float64, '*'): DOUBLEP, # noqa + (np.complex64, '*'): CFLOATP, # noqa + (np.complex128, '*'): CDOUBLEP, # noqa } for base_name in ['int', 'float', 'double']: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 45b864df05..bc0d02fc61 
100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -84,9 +84,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index c1a8809379..283249aac1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -647,9 +647,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - # print(op) op() # Check against numpy From e8fc1bd0fd0c0c3cfeedda3fc711e41115085aa9 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 20 Jun 2024 10:28:38 -0400 Subject: [PATCH 09/58] compiler: rework dtype lowering --- devito/arch/compiler.py | 20 +--- devito/ir/iet/visitors.py | 2 +- devito/operator/operator.py | 4 - devito/passes/iet/__init__.py | 2 +- devito/passes/iet/definitions.py | 12 ++- devito/passes/iet/dtypes.py | 58 +++++++++++ devito/passes/iet/langbase.py | 11 ++ devito/passes/iet/languages/C.py | 12 ++- devito/passes/iet/languages/CXX.py | 69 +++++++++++++ devito/passes/iet/languages/openacc.py | 5 +- devito/passes/iet/misc.py | 2 +- devito/symbolics/__init__.py | 1 + devito/symbolics/extended_dtypes.py | 123 +++++++++++++++++++++++ devito/symbolics/extended_sympy.py | 134 +------------------------ devito/symbolics/inspection.py | 3 +- devito/symbolics/printer.py | 12 ++- devito/tools/dtypes_lowering.py | 24 +++-- devito/types/basic.py | 33 ++++-- devito/types/misc.py | 2 +- 19 files changed, 344 insertions(+), 185 deletions(-) create mode 100644 devito/passes/iet/dtypes.py create mode 100644 devito/passes/iet/languages/CXX.py create mode 100644 devito/symbolics/extended_dtypes.py diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 15c49270f3..32c563c041 100644 --- 
a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -250,20 +250,6 @@ def version(self): return version - @property - def _complex_ctype(self): - """ - Type definition for complex numbers. These two cases cover 99% of the cases since - - Hip is now using std::complex -https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - - Sycl supports std::complex - - C's _Complex is part of C99 - """ - if self._cpp: - return lambda dtype: 'std::complex<%s>' % str(dtype) - else: - return lambda dtype: '%s _Complex' % str(dtype) - def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -618,7 +604,7 @@ def __init_finalize__(self, **kwargs): self.cflags.remove('-O3') self.cflags.remove('-Wall') - self.cflags.append('-std=c++11') + self.cflags.append('-std=c++14') language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) @@ -732,10 +718,6 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' - @property - def _complex_ctype(self): - return lambda dtype: 'thrust::complex<%s>' % str(dtype) - class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index a93c2e1385..a288dd065a 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -602,7 +602,7 @@ def visit_MultiTraversable(self, o): return c.Collection(body) def visit_UsingNamespace(self, o): - return c.Statement('using namespace %s' % ccode(o.namespace)) + return c.Statement('using namespace %s' % str(o.namespace)) def visit_Lambda(self, o): body = [] diff --git a/devito/operator/operator.py b/devito/operator/operator.py index a5d3cae070..ef77373f8c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -478,10 +478,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # Complex header if needed. 
Needs to be done before specialization - # as some specific cases require complex to be loaded first - include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) - # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index 6b4ada0b73..1cdb97c794 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,4 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa -from .complex import * # noqa +from .dtypes import * # noqa diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 3532169754..b4f4aca101 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,6 +13,7 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create +from devito.passes.iet.dtypes import lower_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -75,10 +76,12 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. """ - def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, + compiler=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform + self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ @@ -462,12 +465,18 @@ def place_casts(self, iet, **kwargs): return iet, {} + @iet_pass + def make_langtypes(self, iet): + iet, metadata = lower_complex(iet, self.lang, self.compiler) + return iet, metadata + def process(self, graph): """ Apply the `place_definitions` and `place_casts` passes. 
""" self.place_definitions(graph, globs=set()) self.place_casts(graph) + self.make_langtypes(graph) class DeviceAwareDataManager(DataManager): @@ -609,6 +618,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) + self.make_langtypes(graph) def make_zero_init(obj, rcompile, sregistry): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py new file mode 100644 index 0000000000..912f707afd --- /dev/null +++ b/devito/passes/iet/dtypes.py @@ -0,0 +1,58 @@ +import numpy as np +import ctypes + +from devito.ir import FindSymbols, Uxreplace + +__all__ = ['lower_complex'] + + +def lower_complex(iet, lang, compiler): + """ + Add headers for complex arithmetic + """ + # Check if there is complex numbers that always take dtype precedence + types = {f.dtype for f in FindSymbols().visit(iet) + if not issubclass(f.dtype, ctypes._Pointer)} + + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return iet, {} + + lib = (lang['header-complex'],) + + metadata = {} + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] + + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) + + iet = _complex_dtypes(iet, lang) + metadata['includes'] = lib + print(metadata) + return iet, metadata + + +def _complex_dtypes(iet, lang): + """ + Lower dtypes to language specific types + """ + mapper = {} + + for s in FindSymbols('indexeds').visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + for s in FindSymbols().visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + body = Uxreplace(mapper).visit(iet.body) + params = 
Uxreplace(mapper).visit(iet.parameters) + iet = iet._rebuild(body=body, parameters=params) + + return iet diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index d27674c419..e34aa2dac3 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -31,6 +31,9 @@ def __getitem__(self, k): raise NotImplementedError("Missing required mapping for `%s`" % k) return self.mapper[k] + def get(self, k): + return self.mapper.get(k) + class LangBB(metaclass=LangMeta): @@ -200,6 +203,14 @@ def initialize(self, iet, options=None): """ return iet, {} + @iet_pass + def make_langtypes(self, iet): + """ + An `iet_pass` which transforms an IET such that the target language + types are introduced. + """ + return iet, {} + @property def Region(self): return self.lang.Region diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4b3358798d..bd5e0e6413 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,18 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType __all__ = ['CBB', 'CDataManager', 'COrchestrator'] +CCFloat = CustomNpType('_Complex float', np.complex64) +CCDouble = CustomNpType('_Complex double', np.complex128) + + class CBB(LangBB): mapper = { @@ -19,7 +26,10 @@ class CBB(LangBB): 'host-free-pin': lambda i: Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: - Call('memcpy', (i, j, k)) + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex.h', + 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py new file mode 100644 index 0000000000..9f833d630b --- /dev/null +++ b/devito/passes/iet/languages/CXX.py @@ -0,0 +1,69 @@ +import numpy as np + +from 
devito.ir import Call, UsingNamespace +from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType + +__all__ = ['CXXBB'] + + +std_arith = """ +#include <complex> + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +""" + +CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') +CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + + +class CXXBB(LangBB): + + mapper = { + 'header-memcpy': 'string.h', + 'host-alloc': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-alloc-pin': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-free': lambda i: + Call('free', (i,)), + 'host-free-pin': lambda i: + Call('free', (i,)), + 'alloc-global-symbol': lambda i, j, k: + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex', + 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'def-complex': std_arith, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + } diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcd2c8d006..bcf5660ac7
100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -122,7 +122,8 @@ class AccBB(PragmaLangBB): 'device-free': lambda i, *a: Call('acc_free', (i,)) } - mapper.update(CBB.mapper) + + mapper.update(CXXBB.mapper) Region = OmpRegion HostIteration = OmpIteration # Host parallelism still goes via OpenMP diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index e49b13de72..28e1cc4f7b 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -14,7 +14,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr +from devito.tools import Bunch, as_mapper, filter_ordered, split from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', diff --git a/devito/symbolics/__init__.py b/devito/symbolics/__init__.py index 0f5c261471..9d7bee01b8 100644 --- a/devito/symbolics/__init__.py +++ b/devito/symbolics/__init__.py @@ -1,4 +1,5 @@ from devito.symbolics.extended_sympy import * # noqa +from devito.symbolics.extended_dtypes import * # noqa from devito.symbolics.queries import * # noqa from devito.symbolics.search import * # noqa from devito.symbolics.printer import * # noqa diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py new file mode 100644 index 0000000000..c558eb4e18 --- 
/dev/null +++ b/devito/symbolics/extended_dtypes.py @@ -0,0 +1,123 @@ +import numpy as np + +from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa + int2, int3, int4) + +__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa + + +limits_mapper = { + np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), + np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), + np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), + np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), +} + + +class CustomType(ReservedWord): + pass + + +# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... +for base_name in ['int', 'float', 'double']: + for i in ['', '2', '3', '4']: + v = '%s%s' % (base_name, i) + cls = type(v.upper(), (Cast,), {'_base_typ': v}) + globals()[cls.__name__] = cls + + clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + globals()[clsp.__name__] = clsp + + +class CHAR(Cast): + _base_typ = 'char' + + +class SHORT(Cast): + _base_typ = 'short' + + +class USHORT(Cast): + _base_typ = 'unsigned short' + + +class UCHAR(Cast): + _base_typ = 'unsigned char' + + +class LONG(Cast): + _base_typ = 'long' + + +class ULONG(Cast): + _base_typ = 'unsigned long' + + +class CFLOAT(Cast): + _base_typ = 'float' + + +class CDOUBLE(Cast): + _base_typ = 'double' + + +class VOID(Cast): + _base_typ = 'void' + + +class CHARP(CastStar): + base = CHAR + + +class UCHARP(CastStar): + base = UCHAR + + +class SHORTP(CastStar): + base = SHORT + + +class USHORTP(CastStar): + base = USHORT + + +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + +cast_mapper = { + np.int8: CHAR, + np.uint8: UCHAR, + np.int16: SHORT, # noqa + np.uint16: USHORT, # noqa + int: INT, # noqa + np.int32: INT, # noqa + np.int64: LONG, + np.uint64: 
ULONG, + np.float32: FLOAT, # noqa + float: DOUBLE, # noqa + np.float64: DOUBLE, # noqa + + (np.int8, '*'): CHARP, + (np.uint8, '*'): UCHARP, + (int, '*'): INTP, # noqa + (np.uint16, '*'): USHORTP, # noqa + (np.int16, '*'): SHORTP, # noqa + (np.int32, '*'): INTP, # noqa + (np.int64, '*'): INTP, # noqa + (np.float32, '*'): FLOATP, # noqa + (float, '*'): DOUBLEP, # noqa + (np.float64, '*'): DOUBLEP, # noqa +} + +for base_name in ['int', 'float', 'double']: + for i in [2, 3, 4]: + v = '%s%d' % (base_name, i) + cls = locals()[v] + cast_mapper[cls] = locals()[v.upper()] + cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 72483e2004..b386a68a79 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,7 +7,6 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority -from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -20,8 +19,7 @@ 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', - 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper'] + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] class CondEq(sympy.Eq): @@ -548,14 +546,6 @@ class ValueLimit(ReservedWord, sympy.Expr): pass -limits_mapper = { - np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), - np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), - np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), - np.float64: 
Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), -} - - class DefFunction(Function, Pickable): """ @@ -773,128 +763,6 @@ def __new__(cls, base=''): return cls.base(base, '*') -# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... -for base_name in ['int', 'float', 'double']: - for i in ['', '2', '3', '4']: - v = '%s%s' % (base_name, i) - cls = type(v.upper(), (Cast,), {'_base_typ': v}) - globals()[cls.__name__] = cls - - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) - globals()[clsp.__name__] = clsp - - -class CHAR(Cast): - _base_typ = 'char' - - -class SHORT(Cast): - _base_typ = 'short' - - -class USHORT(Cast): - _base_typ = 'unsigned short' - - -class UCHAR(Cast): - _base_typ = 'unsigned char' - - -class UINT(Cast): - _base_typ = 'unsigned int' - - -class UINTP(CastStar): - base = UINT - - -class LONG(Cast): - _base_typ = 'long' - - -class ULONG(Cast): - _base_typ = 'unsigned long' - - -class VOID(Cast): - _base_typ = 'void' - - -class CFLOAT(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('float') - - -class CDOUBLE(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('double') - - -class CHARP(CastStar): - base = CHAR - - -class UCHARP(CastStar): - base = UCHAR - - -class SHORTP(CastStar): - base = SHORT - - -class USHORTP(CastStar): - base = USHORT - - -class CFLOATP(CastStar): - base = CFLOAT - - -class CDOUBLEP(CastStar): - base = CDOUBLE - - -cast_mapper = { - np.int8: CHAR, - np.uint8: UCHAR, - np.int16: SHORT, # noqa - np.uint16: USHORT, # noqa - int: INT, # noqa - np.int32: INT, # noqa - np.int64: LONG, - np.uint64: ULONG, - np.float32: FLOAT, # noqa - float: DOUBLE, # noqa - np.float64: DOUBLE, # noqa - np.complex64: CFLOAT, # noqa - np.complex128: CDOUBLE, # noqa - - (np.int8, '*'): CHARP, - (np.uint8, '*'): UCHARP, - (int, '*'): INTP, # noqa - (np.uint16, '*'): USHORTP, # noqa - (np.int16, '*'): SHORTP, # noqa - (np.int32, '*'): 
INTP, # noqa - (np.int64, '*'): INTP, # noqa - (np.float32, '*'): FLOATP, # noqa - (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP, # noqa - (np.complex64, '*'): CFLOATP, # noqa - (np.complex128, '*'): CDOUBLEP, # noqa -} - -for base_name in ['int', 'float', 'double']: - for i in [2, 3, 4]: - v = '%s%d' % (base_name, i) - cls = locals()[v] - cast_mapper[cls] = locals()[v.upper()] - cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] - - # Some other utility objects Null = Macro('NULL') diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 57aead89d2..165a3209be 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -8,7 +8,8 @@ from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning -from devito.symbolics.extended_sympy import (INT, CallFromPointer, Cast, +from devito.symbolics.extended_dtypes import INT +from devito.symbolics.extended_sympy import (CallFromPointer, Cast, DefFunction, ReservedWord) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 6e7f33bba2..486ce9bfa3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -15,6 +15,7 @@ from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter +from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction @@ -41,13 +42,17 @@ def dtype(self): @property def compiler(self): - return self._settings['compiler'] + return self._settings['compiler'] or configuration['compiler'] def single_prec(self, expr=None): + if self.compiler._cpp and expr is not None: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype 
return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): + if self.compiler._cpp: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return np.issubdtype(dtype, np.complexfloating) @@ -246,7 +251,10 @@ def _print_Float(self, expr): return rv def _print_ImaginaryUnit(self, expr): - return '_Complex_I' + if self.compiler._cpp: + return '1i' + else: + return '_Complex_I' def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 117feaff0b..bfc568a8d8 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] # *** Custom np.dtypes @@ -123,6 +123,18 @@ def __repr__(self): __str__ = __repr__ +class CustomNpType(CustomDtype): + """ + Custom dtype for underlying numpy type. 
+ """ + + def __init__(self, name, nptype, template=None, modifier=None): + self.nptype = nptype + super().__init__(name, template, modifier) + + def __call__(self, val): + return self.nptype(val) + # *** np.dtypes lowering @@ -136,16 +148,6 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype - # Complex data - if np.issubdtype(dtype, np.complexfloating): - rtype = dtype(0).real.__class__ - from devito import configuration - make = configuration['compiler']._complex_ctype - ctname = make(dtype_to_cstr(rtype)) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index 4dcf1dad95..11c9e5c535 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -14,7 +14,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, CustomDtype, + Reconstructable) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -84,6 +85,9 @@ def _C_typedata(self): The type of the object in the generated code as a `str`. 
""" _type = self._C_ctype + if isinstance(_type, CustomDtype): + return _type + while issubclass(_type, _Pointer): _type = _type._type_ @@ -877,6 +881,7 @@ def __new__(cls, *args, **kwargs): name = kwargs.get('name') alias = kwargs.get('alias') function = kwargs.get('function') + dtype = kwargs.get('dtype') if alias is True or (function and function.name != name): function = kwargs['function'] = None @@ -884,7 +889,8 @@ def __new__(cls, *args, **kwargs): # definitely a reconstruction if function is not None and \ function.name == name and \ - function.indices == indices: + function.indices == indices and \ + function.dtype == dtype: # Special case: a syntactically identical alias of `function`, so # let's just return `function` itself return function @@ -1234,7 +1240,8 @@ def bound_symbols(self): @cached_property def indexed(self): """The wrapped IndexedData object.""" - return IndexedData(self.name, shape=self._shape, function=self.function) + return IndexedData(self.name, shape=self._shape, function=self.function, + dtype=self.dtype) @cached_property def dmap(self): @@ -1516,13 +1523,14 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable): __rargs__ = ('label', 'shape') __rkwargs__ = ('function',) - def __new__(cls, label, shape, function=None): + def __new__(cls, label, shape, function=None, dtype=None): # Make sure `label` is a devito.Symbol, not a sympy.Symbol if isinstance(label, str): label = Symbol(name=label, dtype=None) with sympy_mutex: obj = sympy.IndexedBase.__new__(cls, label, shape) obj.function = function + obj._dtype = dtype or function.dtype return obj func = Pickable._rebuild @@ -1562,7 +1570,7 @@ def indices(self): @property def dtype(self): - return self.function.dtype + return self._dtype @cached_property def free_symbols(self): @@ -1624,7 +1632,7 @@ def _C_ctype(self): return self.function._C_ctype -class Indexed(sympy.Indexed): +class Indexed(sympy.Indexed, Reconstructable): # The two type flags have changed in upstream sympy as of 
version 1.1, # but the below interpretation is used throughout the compiler to @@ -1636,6 +1644,17 @@ class Indexed(sympy.Indexed): is_Dimension = False + __rargs__ = ('base', 'indices') + __rkwargs__ = ('dtype',) + + def __new__(cls, base, *indices, dtype=None, **kwargs): + if len(indices) == 1: + indices = as_tuple(indices[0]) + newobj = sympy.Indexed.__new__(cls, base, *indices) + newobj._dtype = dtype or base.dtype + + return newobj + @memoized_meth def __str__(self): return super().__str__() @@ -1657,7 +1676,7 @@ def function(self): @property def dtype(self): - return self.function.dtype + return self._dtype @property def name(self): diff --git a/devito/types/misc.py b/devito/types/misc.py index 29514bb99a..38beeaee53 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -83,7 +83,7 @@ class FIndexed(Indexed, Pickable): __rkwargs__ = ('strides_map', 'accessor') def __new__(cls, base, *args, strides_map=None, accessor=None): - obj = super().__new__(cls, base, *args) + obj = super().__new__(cls, base, args) obj.strides_map = frozendict(strides_map or {}) obj.accessor = accessor From 3369ef610e8aaf6fe2d882548d457dc560d12d99 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 27 Jun 2024 07:59:59 -0400 Subject: [PATCH 10/58] compiler: switch to c++14 for complex_literals --- devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/CXX.py | 2 +- devito/symbolics/extended_dtypes.py | 2 +- devito/symbolics/extended_sympy.py | 6 +----- devito/symbolics/printer.py | 15 +++++++++++---- tests/test_gpu_common.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 912f707afd..1932b60f3a 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -33,7 +33,7 @@ def lower_complex(iet, lang, compiler): iet = _complex_dtypes(iet, lang) metadata['includes'] = lib - print(metadata) + return iet, metadata diff --git a/devito/passes/iet/languages/CXX.py 
b/devito/passes/iet/languages/CXX.py index 9f833d630b..5f74070472 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -63,7 +63,7 @@ class CXXBB(LangBB): Call('memcpy', (i, j, k)), # Complex 'header-complex': 'complex', - 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index c558eb4e18..0e8ce0cc98 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -4,7 +4,7 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index b386a68a79..19fcd83d4e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -18,7 +18,7 @@ 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', + 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] @@ -508,10 +508,6 @@ class Keyword(ReservedWord): pass -class CustomType(ReservedWord): - pass - - class String(ReservedWord): pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 486ce9bfa3..894530a219 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -38,14 +38,18 @@ class CodePrinter(C99CodePrinter): 
@property def dtype(self): - return self._settings['dtype'] + try: + return self._settings['dtype'].nptype + except AttributeError: + return self._settings['dtype'] @property def compiler(self): return self._settings['compiler'] or configuration['compiler'] - def single_prec(self, expr=None): - if self.compiler._cpp and expr is not None: + def single_prec(self, expr=None, with_f=False): + no_f = self.compiler._cpp and not with_f + if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] @@ -252,7 +256,10 @@ def _print_Float(self, expr): def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: - return '1i' + if self.single_prec(with_f=True): + return '1if' + else: + return '1i' else: return '_Complex_I' diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index bc0d02fc61..4a11c12556 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -92,7 +92,7 @@ def test_complex(self, dtype): xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + assert np.allclose(u.data, npres.T, rtol=1e-6, atol=0) class TestPassesOptional: From 0304046f019c18f8ae30d92ba907e2b5fd9c1e62 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 8 Jul 2024 12:47:53 -0400 Subject: [PATCH 11/58] compiler: subdtype numpy for dtype lowering --- devito/passes/iet/dtypes.py | 6 +----- devito/passes/iet/languages/C.py | 19 ++++++++++++++++--- devito/passes/iet/languages/CXX.py | 20 +++++++++++++++++--- devito/symbolics/printer.py | 2 +- devito/tools/dtypes_lowering.py | 14 +------------- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1932b60f3a..57eb10c4d8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -43,11 +43,7 @@ def _complex_dtypes(iet, lang): """ 
mapper = {} - for s in FindSymbols('indexeds').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - for s in FindSymbols().visit(iet): + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): if s.dtype in lang['types']: mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index bd5e0e6413..2cee279428 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,16 +1,29 @@ +import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper + __all__ = ['CBB', 'CDataManager', 'COrchestrator'] -CCFloat = CustomNpType('_Complex float', np.complex64) -CCDouble = CustomNpType('_Complex double', np.complex128) +class CCFloat(np.complex64): + pass + + +class CCDouble(np.complex128): + pass + + +c_complex = type('_Complex float', (ct.c_double,), {}) +c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) + +ctypes_vector_mapper[CCFloat] = c_complex +ctypes_vector_mapper[CCDouble] = c_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 5f74070472..fb802acb8b 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,8 +1,9 @@ +import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -43,8 +44,21 @@ """ -CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') -CXXCDouble = CustomNpType('std::complex', np.complex128, 
template='double') + +class CXXCFloat(np.complex64): + pass + + +class CXXCDouble(np.complex128): + pass + + +cxx_complex = type('std::complex', (ct.c_double,), {}) +cxx_double_complex = type('std::complex', (ct.c_longdouble,), {}) + + +ctypes_vector_mapper[CXXCFloat] = cxx_complex +ctypes_vector_mapper[CXXCDouble] = cxx_double_complex class CXXBB(LangBB): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 894530a219..bae6c75fe0 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -52,7 +52,7 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16, np.complex64] + return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) def complex_prec(self, expr=None): if self.compiler._cpp: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index bfc568a8d8..c502936f6c 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype'] # *** Custom np.dtypes @@ -123,18 +123,6 @@ def __repr__(self): __str__ = __repr__ -class CustomNpType(CustomDtype): - """ - Custom dtype for underlying numpy type. 
- """ - - def __init__(self, name, nptype, template=None, modifier=None): - self.nptype = nptype - super().__init__(name, template, modifier) - - def __call__(self, val): - return self.nptype(val) - # *** np.dtypes lowering From d3dce3af433dff3e6aeb0107a83b002c70f9df79 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 9 Jul 2024 19:23:37 +0100 Subject: [PATCH 12/58] compiler: use structs to pass complex arguments --- devito/ir/iet/visitors.py | 3 ++- devito/passes/iet/languages/C.py | 10 +++++----- devito/passes/iet/languages/CXX.py | 5 +++-- devito/symbolics/extended_dtypes.py | 26 +++++++++++++++++++++++++- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index a288dd065a..7578ae5287 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -20,6 +20,7 @@ from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, ListInitializer, ccode, uxreplace) +from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()): while issubclass(ctype, ctypes._Pointer): ctype = ctype._type_ - if not issubclass(ctype, ctypes.Structure): + if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct): return None except TypeError: # E.g., `ctype` is of type `dtypes_lowering.CustomDtype` diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 2cee279428..1c233b0ff8 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,10 +1,10 @@ -import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB 
+from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper @@ -19,11 +19,11 @@ class CCDouble(np.complex128): pass -c_complex = type('_Complex float', (ct.c_double,), {}) -c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) +c99_complex = type('_Complex float', (c_complex,), {}) +c99_double_complex = type('_Complex double', (c_double_complex,), {}) -ctypes_vector_mapper[CCFloat] = c_complex -ctypes_vector_mapper[CCDouble] = c_double_complex +ctypes_vector_mapper[CCFloat] = c99_complex +ctypes_vector_mapper[CCDouble] = c99_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index fb802acb8b..88ed923640 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -3,6 +3,7 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -53,9 +54,9 @@ class CXXCDouble(np.complex128): pass -cxx_complex = type('std::complex', (ct.c_double,), {}) -cxx_double_complex = type('std::complex', (ct.c_longdouble,), {}) +cxx_complex = type('std::complex', (c_complex,), {}) +cxx_double_complex = type('std::complex', (c_double_complex,), {}) ctypes_vector_mapper[CXXCFloat] = cxx_complex ctypes_vector_mapper[CXXCDouble] = cxx_double_complex diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0e8ce0cc98..d63ca92bf5 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,10 +1,11 @@ +import ctypes as ct import numpy as np from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = 
['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'c_complex', 'c_double_complex'] # noqa limits_mapper = { @@ -15,6 +16,29 @@ } +class NoDeclStruct(ct.Structure): + # ctypes.Structure that does not generate a struct definition + pass + + +class c_complex(NoDeclStruct): + # Structure for passing complex float to C/C++ + _fields_ = [('real', ct.c_float), ('imag', ct.c_float)] + + @classmethod + def from_param(cls, val): + return cls(val.real, val.imag) + + +class c_double_complex(NoDeclStruct): + # Structure for passing complex double to C/C++ + _fields_ = [('real', ct.c_double), ('imag', ct.c_double)] + + @classmethod + def from_param(cls, val): + return cls(val.real, val.imag) + + class CustomType(ReservedWord): pass From f9d92dfda32cb98cc1d4339c29ca4f0c1a21103a Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 13:05:11 +0100 Subject: [PATCH 13/58] compiler: add Dereference scalar case --- devito/ir/iet/nodes.py | 23 +++++++++++++++++------ devito/ir/iet/visitors.py | 3 +++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 57becab5ff..09be89de75 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1,6 +1,7 @@ """The Iteration/Expression Tree (IET) hierarchy.""" import abc +import ctypes import inspect from functools import cached_property from collections import OrderedDict, namedtuple @@ -1038,6 +1039,9 @@ class Dereference(ExprStmt, Node): * `pointer` is a PointerArray or TempFunction, and `pointee` is an Array. * `pointer` is an ArrayObject representing a pointer to a C struct, and `pointee` is a field in `pointer`. + * `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and + `pointee` is a Symbol representing the dereferenced value. 
+ """ is_Dereference = True @@ -1056,13 +1060,20 @@ def functions(self): @property def expr_symbols(self): - ret = [self.pointer.indexed] - if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) - ret.extend(self.pointer.free_symbols) - else: + ret = [] + if self.pointer.is_Symbol: + assert (issubclass(self.pointer._C_ctype, ctypes._Pointer), + "Scalar dereference must have a pointer ctype") + ret.append(self.pointer._C_symbol) ret.append(self.pointee._C_symbol) + else: + ret.append(self.pointer.indexed) + if self.pointer.is_PointerArray or self.pointer.is_TempFunction: + ret.append(self.pointee.indexed) + ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) + ret.extend(self.pointer.free_symbols) + else: + ret.append(self.pointee._C_symbol) return tuple(filter_ordered(ret)) @property diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 7578ae5287..d331989c93 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -454,6 +454,9 @@ def visit_Dereference(self, o): lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) + elif a1.is_Symbol: + rvalue = '*%s' % a1.name + lvalue = self._gen_value(a0, 0) else: rvalue = '%s->%s' % (a1.name, a0._C_name) lvalue = self._gen_value(a0, 0) From 95a0a282f5581870a683a904cbbf1f10dcf25f27 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 13:51:55 +0100 Subject: [PATCH 14/58] compiler: implement float16 support --- devito/ir/iet/visitors.py | 4 +- devito/passes/iet/definitions.py | 3 +- devito/passes/iet/dtypes.py | 70 +++++++++++++++++++++-------- devito/passes/iet/languages/C.py | 18 ++++++-- devito/passes/iet/languages/CXX.py | 17 +++++-- devito/symbolics/extended_dtypes.py | 31 +++++++++++-- devito/tools/dtypes_lowering.py | 3 ++ 7 files changed, 116 insertions(+), 
30 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index d331989c93..4bae7af755 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -24,7 +24,7 @@ from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) -from devito.types.basic import AbstractFunction, Basic +from devito.types.basic import AbstractFunction, AbstractSymbol, Basic from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer, IndexedData, DeviceMap) @@ -961,6 +961,7 @@ def default_retval(cls): Drive the search. Accepted: - `symbolics`: Collect all AbstractFunction objects, default - `basics`: Collect all Basic objects + - `scalars`: Collect all AbstractSymbol objects - `dimensions`: Collect all Dimensions - `indexeds`: Collect all Indexed objects - `indexedbases`: Collect all IndexedBase objects @@ -981,6 +982,7 @@ def _defines_aliases(n): rules = { 'symbolics': lambda n: n.functions, 'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)], + 'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)], 'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)], 'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed], 'indexedbases': lambda n: [i for i in n.expr_symbols diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index b4f4aca101..9f79791c67 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,7 +13,7 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_complex +from devito.passes.iet.dtypes import lower_complex, lower_scalar_half from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, 
IndexedPointer, @@ -467,6 +467,7 @@ def place_casts(self, iet, **kwargs): @iet_pass def make_langtypes(self, iet): + iet, _ = lower_scalar_half(iet, self.lang, self.sregistry) iet, metadata = lower_complex(iet, self.lang, self.compiler) return iet, metadata diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 57eb10c4d8..6f35883423 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,8 +2,44 @@ import ctypes from devito.ir import FindSymbols, Uxreplace +from devito.ir.iet.nodes import Dereference +from devito.tools.utils import as_tuple +from devito.types.basic import Symbol -__all__ = ['lower_complex'] +__all__ = ['lower_scalar_half', 'lower_complex'] + + +def lower_scalar_half(iet, lang, sregistry): + """ + Lower half float scalars to pointers (special case, since we can't + pass them directly for lack of a ctypes equivalent) + """ + if lang.get('half_types') is None: + return iet, {} + + # dtype mappings for float16 + half, half_p = lang['half_types'] + + body = [] # derefs to prepend to the body + body_mapper = {} + params_mapper = {} + + for s in FindSymbols('scalars').visit(iet): + if s.dtype != np.float16 or s not in iet.parameters: + continue + + ptr = s._rebuild(dtype=half_p) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, is_const=True) + + params_mapper[s] = ptr + body_mapper[s] = val + body.append(Dereference(val, ptr)) # val = *ptr + + body.extend(as_tuple(Uxreplace(body_mapper).visit(iet.body))) + params = Uxreplace(params_mapper).visit(iet.parameters) + + iet = iet._rebuild(body=body, parameters=params) + return iet, {} def lower_complex(iet, lang, compiler): @@ -14,30 +50,28 @@ def lower_complex(iet, lang, compiler): types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} - if not any(np.issubdtype(d, np.complexfloating) for d in types): - return iet, {} - - lib = (lang['header-complex'],) - metadata = {} - if lang.get('complex-namespace') 
is not None: - metadata['namespaces'] = lang['complex-namespace'] + if any(np.issubdtype(d, np.complexfloating) for d in types): + lib = (lang['header-complex'],) + + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] - # Some languges such as c++11 need some extra arithmetic definitions - if lang.get('def-complex'): - dest = compiler.get_jit_dir() - hfile = dest.joinpath('complex_arith.h') - with open(str(hfile), 'w') as ff: - ff.write(str(lang['def-complex'])) - lib += (str(hfile),) + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) - iet = _complex_dtypes(iet, lang) - metadata['includes'] = lib + metadata['includes'] = lib + iet = _lower_dtypes(iet, lang) return iet, metadata -def _complex_dtypes(iet, lang): +def _lower_dtypes(iet, lang): """ Lower dtypes to language specific types """ diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 1c233b0ff8..57e7864c11 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,10 +1,11 @@ +from ctypes import c_float import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p from devito.tools.dtypes_lowering import ctypes_vector_mapper @@ -19,11 +20,21 @@ class CCDouble(np.complex128): pass +class CHalf(np.float16): + pass + + +class CHalfP(np.float16): + pass + + c99_complex = type('_Complex float', (c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), 
{}) ctypes_vector_mapper[CCFloat] = c99_complex ctypes_vector_mapper[CCDouble] = c99_double_complex +ctypes_vector_mapper[CHalf] = c_float16 +ctypes_vector_mapper[CHalfP] = c_float16_p class CBB(LangBB): @@ -40,9 +51,10 @@ class CBB(LangBB): Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), - # Complex + # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, + 'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf}, + 'half_types': (CHalf, CHalfP), } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 88ed923640..c207b793c2 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,8 @@ -import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -54,12 +53,21 @@ class CXXCDouble(np.complex128): pass +class CXXHalf(np.float16): + pass + + +class CXXHalfP(np.float16): + pass + cxx_complex = type('std::complex', (c_complex,), {}) cxx_double_complex = type('std::complex', (c_double_complex,), {}) ctypes_vector_mapper[CXXCFloat] = cxx_complex ctypes_vector_mapper[CXXCDouble] = cxx_double_complex +ctypes_vector_mapper[CXXHalf] = c_float16 +ctypes_vector_mapper[CXXHalfP] = c_float16_p class CXXBB(LangBB): @@ -76,9 +84,10 @@ class CXXBB(LangBB): Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), - # Complex + # Complex and float16 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + 'types': 
{np.complex128: CXXCDouble, np.complex64: CXXCFloat, np.float16: CXXHalf}, + 'half_types': (CXXHalf, CXXHalfP), } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index d63ca92bf5..e7bb595ed9 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -5,7 +5,9 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'c_complex', 'c_double_complex'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', + 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', + 'c_float16', 'c_float16_p'] limits_mapper = { @@ -28,7 +30,7 @@ class c_complex(NoDeclStruct): @classmethod def from_param(cls, val): return cls(val.real, val.imag) - + class c_double_complex(NoDeclStruct): # Structure for passing complex double to C/C++ @@ -37,7 +39,30 @@ class c_double_complex(NoDeclStruct): @classmethod def from_param(cls, val): return cls(val.real, val.imag) - + + +class _c_half(ct.c_uint16): + # Ctype for non-scalar half floats + @classmethod + def from_param(cls, val): + return cls(np.float16(val).view(np.uint16)) + + +c_float16 = type('_Float16', (_c_half,), {}) + + +class _c_half_p(ct.POINTER(c_float16)): + # Ctype for half scalars; we can't directly pass _Float16 values so + # we use a pointer and dereference (see `passes.iet.dtypes`) + @classmethod + def from_param(cls, val): + arr = np.array(val, dtype=np.float16) + return arr.ctypes.data_as(cls) + + +# ctypes directly parses class dict; can't inherit the _type_ attribute +c_float16_p = type('_Float16 *', (_c_half_p,), {'_type_': c_float16}) + class CustomType(ReservedWord): pass diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index c502936f6c..37acaa1dca 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -147,6 
+147,9 @@ def dtype_to_ctype(dtype): # Bypass np.ctypeslib's normalization rules such as # `np.ctypeslib.as_ctypes_type(ctypes.c_void_p) -> ctypes.c_ulong` return dtype + elif dtype == np.float16: + # Allocator wants a ctype before float16 gets mapped + return ctypes.c_uint16 else: return np.ctypeslib.as_ctypes_type(dtype) From 6e747dbb943d97e9843d045acac4c861c39b8c33 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 14:26:08 +0100 Subject: [PATCH 15/58] symbolics: fix printer for half precision --- devito/symbolics/printer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index bae6c75fe0..d21ff780c9 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -52,7 +52,14 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) + return any(issubclass(dtype, d) for d in [np.float32, np.complex64]) + + def half_prec(self, expr=None, with_f=False): + no_f = self.compiler._cpp and not with_f + if no_f and expr is not None: + return False + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return issubclass(dtype, np.float16) def complex_prec(self, expr=None): if self.compiler._cpp: @@ -126,7 +133,7 @@ def _print_math_func(self, expr, nest=False, known=None): if cname not in self._prec_funcs: return super()._print_math_func(expr, nest=nest, known=known) - if self.single_prec(expr): + if self.single_prec(expr) or self.half_prec(expr): cname = '%sf' % cname if self.complex_prec(expr): cname = 'c%s' % cname @@ -251,6 +258,8 @@ def _print_Float(self, expr): if self.single_prec(): rv = '%sF' % rv + elif self.half_prec(): + rv = '%sF16' % rv return rv @@ -258,6 +267,8 @@ def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: if 
self.single_prec(with_f=True): return '1if' + elif self.half_prec(with_f=True): + return '1if16' else: return '1i' else: @@ -315,7 +326,7 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) - if self.single_prec(): + if self.single_prec() or self.half_prec(): func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name From ae7fb7584c7d33a6924a7d2313ee009a3d4787ad Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 14:36:42 +0100 Subject: [PATCH 16/58] misc: fix formatting --- devito/ir/iet/nodes.py | 8 ++++---- devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/C.py | 4 ++-- devito/passes/iet/languages/CXX.py | 6 ++++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 09be89de75..1270afff81 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1041,7 +1041,6 @@ class Dereference(ExprStmt, Node): `pointee` is a field in `pointer`. * `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and `pointee` is a Symbol representing the dereferenced value. 
- """ is_Dereference = True @@ -1062,15 +1061,16 @@ def functions(self): def expr_symbols(self): ret = [] if self.pointer.is_Symbol: - assert (issubclass(self.pointer._C_ctype, ctypes._Pointer), - "Scalar dereference must have a pointer ctype") + assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ + "Scalar dereference must have a pointer ctype" ret.append(self.pointer._C_symbol) ret.append(self.pointee._C_symbol) else: ret.append(self.pointer.indexed) if self.pointer.is_PointerArray or self.pointer.is_TempFunction: ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) + ret.extend(flatten(i.free_symbols + for i in self.pointee.symbolic_shape[1:])) ret.extend(self.pointer.free_symbols) else: ret.append(self.pointee._C_symbol) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 6f35883423..1fe98edfd8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -16,7 +16,7 @@ def lower_scalar_half(iet, lang, sregistry): """ if lang.get('half_types') is None: return iet, {} - + # dtype mappings for float16 half, half_p = lang['half_types'] diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 57e7864c11..572d4a86cd 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,11 @@ -from ctypes import c_float import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p +from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, + c_float16, c_float16_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index c207b793c2..30e5ab689a 100644 --- 
a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -2,7 +2,8 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p +from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, + c_float16, c_float16_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -88,6 +89,7 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, np.float16: CXXHalf}, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, + np.float16: CXXHalf}, 'half_types': (CXXHalf, CXXHalfP), } From fc102191f453b28eecdd6fd253d727ff3f18e8b5 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 17:57:50 +0100 Subject: [PATCH 17/58] compiler: refactor float16 and lower_dtypes --- devito/passes/iet/definitions.py | 5 +- devito/passes/iet/dtypes.py | 96 +++++++++++++---------------- devito/passes/iet/languages/C.py | 7 ++- devito/passes/iet/languages/CXX.py | 4 +- devito/symbolics/extended_dtypes.py | 13 +--- devito/symbolics/printer.py | 4 +- 6 files changed, 56 insertions(+), 73 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 9f79791c67..6dd6ec0d97 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,7 +13,7 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_complex, lower_scalar_half +from devito.passes.iet.dtypes import lower_dtypes from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -467,8 +467,7 @@ def 
place_casts(self, iet, **kwargs): @iet_pass def make_langtypes(self, iet): - iet, _ = lower_scalar_half(iet, self.lang, self.sregistry) - iet, metadata = lower_complex(iet, self.lang, self.compiler) + iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) return iet, metadata def process(self, graph): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1fe98edfd8..a2d0899224 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -3,46 +3,52 @@ from devito.ir import FindSymbols, Uxreplace from devito.ir.iet.nodes import Dereference -from devito.tools.utils import as_tuple +from devito.tools.utils import as_list from devito.types.basic import Symbol -__all__ = ['lower_scalar_half', 'lower_complex'] +__all__ = ['lower_dtypes'] -def lower_scalar_half(iet, lang, sregistry): +def lower_dtypes(iet, lang, compiler, sregistry): """ - Lower half float scalars to pointers (special case, since we can't - pass them directly for lack of a ctypes equivalent) + Lower language-specific dtypes and add headers for complex arithmetic """ - if lang.get('half_types') is None: - return iet, {} + # Include complex headers if needed (before we replace complex dtypes) + metadata = _complex_includes(iet, lang, compiler) - # dtype mappings for float16 - half, half_p = lang['half_types'] - - body = [] # derefs to prepend to the body + body_prefix = [] # Derefs to prepend to the body body_mapper = {} params_mapper = {} - for s in FindSymbols('scalars').visit(iet): - if s.dtype != np.float16 or s not in iet.parameters: - continue + # Lower scalar float16s to pointers and dereference them + if lang.get('half_types') is not None: + half, half_p = lang['half_types'] # dtype mappings for half float + + for s in FindSymbols('scalars').visit(iet): + if s.dtype != np.float16 or s not in iet.parameters: + continue - ptr = s._rebuild(dtype=half_p) - val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, is_const=True) + ptr = 
s._rebuild(dtype=half_p, is_const=True) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, + is_const=s.is_const) - params_mapper[s] = ptr - body_mapper[s] = val - body.append(Dereference(val, ptr)) # val = *ptr + params_mapper[s], body_mapper[s] = ptr, val + body_prefix.append(Dereference(val, ptr)) # val = *ptr + + # Lower remaining language-specific dtypes + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): + if s.dtype in lang['types'] and s not in params_mapper: + body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - body.extend(as_tuple(Uxreplace(body_mapper).visit(iet.body))) + # Apply the dtype replacements + body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) params = Uxreplace(params_mapper).visit(iet.parameters) iet = iet._rebuild(body=body, parameters=params) - return iet, {} + return iet, metadata -def lower_complex(iet, lang, compiler): +def _complex_includes(iet, lang, compiler): """ Add headers for complex arithmetic """ @@ -50,39 +56,23 @@ def lower_complex(iet, lang, compiler): types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} - metadata = {} - if any(np.issubdtype(d, np.complexfloating) for d in types): - lib = (lang['header-complex'],) - - if lang.get('complex-namespace') is not None: - metadata['namespaces'] = lang['complex-namespace'] - - # Some languges such as c++11 need some extra arithmetic definitions - if lang.get('def-complex'): - dest = compiler.get_jit_dir() - hfile = dest.joinpath('complex_arith.h') - with open(str(hfile), 'w') as ff: - ff.write(str(lang['def-complex'])) - lib += (str(hfile),) - - metadata['includes'] = lib - - iet = _lower_dtypes(iet, lang) - return iet, metadata + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return {} + metadata = {} + lib = (lang['header-complex'],) -def _lower_dtypes(iet, lang): - """ - Lower dtypes to language specific types - """ - mapper = {} + if 
lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] - for s in FindSymbols('indexeds|basics|symbolics').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) - body = Uxreplace(mapper).visit(iet.body) - params = Uxreplace(mapper).visit(iet.parameters) - iet = iet._rebuild(body=body, parameters=params) + metadata['includes'] = lib - return iet + return metadata diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 572d4a86cd..6112c3e895 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -5,11 +5,11 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, - c_float16, c_float16_p) + c_half, c_half_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper -__all__ = ['CBB', 'CDataManager', 'COrchestrator'] +__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] class CCFloat(np.complex64): @@ -31,6 +31,9 @@ class CHalfP(np.float16): c99_complex = type('_Complex float', (c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), {}) +c_float16 = type('_Float16', (c_half,), {}) +c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) + ctypes_vector_mapper[CCFloat] = c99_complex ctypes_vector_mapper[CCDouble] = c99_double_complex ctypes_vector_mapper[CHalf] = c_float16 diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 30e5ab689a..48fdb4471b 100644 --- a/devito/passes/iet/languages/CXX.py +++ 
b/devito/passes/iet/languages/CXX.py @@ -2,8 +2,8 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, - c_float16, c_float16_p) +from devito.passes.iet.languages.C import c_float16, c_float16_p +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index e7bb595ed9..f6265e4938 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,7 +7,7 @@ __all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_float16', 'c_float16_p'] + 'c_half', 'c_half_p'] limits_mapper = { @@ -41,17 +41,14 @@ def from_param(cls, val): return cls(val.real, val.imag) -class _c_half(ct.c_uint16): +class c_half(ct.c_uint16): # Ctype for non-scalar half floats @classmethod def from_param(cls, val): return cls(np.float16(val).view(np.uint16)) -c_float16 = type('_Float16', (_c_half,), {}) - - -class _c_half_p(ct.POINTER(c_float16)): +class c_half_p(ct.POINTER(c_half)): # Ctype for half scalars; we can't directly pass _Float16 values so # we use a pointer and dereference (see `passes.iet.dtypes`) @classmethod @@ -60,10 +57,6 @@ def from_param(cls, val): return arr.ctypes.data_as(cls) -# ctypes directly parses class dict; can't inherit the _type_ attribute -c_float16_p = type('_Float16 *', (_c_half_p,), {'_type_': c_float16}) - - class CustomType(ReservedWord): pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index d21ff780c9..b5765772a7 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -265,10 +265,8 @@ def _print_Float(self, expr): def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: - if 
self.single_prec(with_f=True): + if self.single_prec(with_f=True) or self.half_prec(with_f=True): return '1if' - elif self.half_prec(with_f=True): - return '1if16' else: return '1i' else: From a9150791b88fea9d3fb61d1e323e06a8b9938486 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 18:40:05 +0100 Subject: [PATCH 18/58] compiler: add dtype_alloc_ctype helper for allocation size --- devito/data/allocators.py | 9 +++------ devito/tools/dtypes_lowering.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index c887e85de0..475fea1f32 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -11,7 +11,8 @@ from devito.logger import logger from devito.parameters import configuration -from devito.tools import dtype_to_ctype, is_integer +from devito.tools import is_integer +from devito.tools.dtypes_lowering import dtype_alloc_ctype __all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY', 'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD', @@ -92,12 +93,8 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - # For complex number, allocate double the size of its real/imaginary part - alloc_dtype = dtype(0).real.__class__ - c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 - + ctype, c_scale = dtype_alloc_ctype(dtype) datasize = int(reduce(mul, shape) * c_scale) - ctype = dtype_to_ctype(alloc_dtype) # Add padding, if any try: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 37acaa1dca..a6ce289324 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -11,8 +11,8 @@ __all__ = ['int2', 'int3', 'int4', 'float2', 'float3', 'float4', 'double2', # noqa 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', - 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', - 'ctypes_to_cstr', 'c_restrict_void_p', 
'ctypes_vector_mapper', + 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_alloc_ctype', 'dtype_to_mpitype', + 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', 'is_external_ctype', 'infer_dtype', 'CustomDtype'] @@ -147,13 +147,36 @@ def dtype_to_ctype(dtype): # Bypass np.ctypeslib's normalization rules such as # `np.ctypeslib.as_ctypes_type(ctypes.c_void_p) -> ctypes.c_ulong` return dtype - elif dtype == np.float16: - # Allocator wants a ctype before float16 gets mapped - return ctypes.c_uint16 else: return np.ctypeslib.as_ctypes_type(dtype) +def dtype_alloc_ctype(dtype): + """ + Translate numpy.dtype to (ctype, int): type and scale for correct C allocation size. + """ + if isinstance(dtype, CustomDtype): + return dtype, 1 + + try: + return ctypes_vector_mapper[dtype], 1 + except KeyError: + pass + + if issubclass(dtype, ctypes._SimpleCData): + return dtype, 1 + + if dtype == np.float16: + # Allocate half float as unsigned short + return ctypes.c_uint16, 1 + + if np.issubdtype(dtype, np.complexfloating): + # For complex float, allocate twice the size of real/imaginary part + return np.ctypeslib.as_ctypes_type(dtype(0).real.__class__), 2 + + return np.ctypeslib.as_ctypes_type(dtype), 1 + + def dtype_to_mpitype(dtype): """Map numpy types to MPI datatypes.""" From 8c6b4ef2ffeedf3eb23fdbc1602d5bf48819e3f9 Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 15 Jul 2024 13:53:44 +0100 Subject: [PATCH 19/58] misc: more float16 refactoring/formatting fixes --- devito/data/allocators.py | 3 +-- devito/ir/iet/nodes.py | 3 +-- devito/symbolics/extended_dtypes.py | 31 ++++++++++++++++++----------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index 475fea1f32..1e00be3f20 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -11,8 +11,7 @@ from devito.logger import logger from devito.parameters import configuration -from devito.tools import is_integer -from 
devito.tools.dtypes_lowering import dtype_alloc_ctype +from devito.tools import is_integer, dtype_alloc_ctype __all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY', 'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD', diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 1270afff81..268b539612 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1063,8 +1063,7 @@ def expr_symbols(self): if self.pointer.is_Symbol: assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ "Scalar dereference must have a pointer ctype" - ret.append(self.pointer._C_symbol) - ret.append(self.pointee._C_symbol) + ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) else: ret.append(self.pointer.indexed) if self.pointer.is_PointerArray or self.pointer.is_TempFunction: diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index f6265e4938..85256a3f94 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,4 +1,4 @@ -import ctypes as ct +import ctypes import numpy as np from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit @@ -18,14 +18,16 @@ } -class NoDeclStruct(ct.Structure): - # ctypes.Structure that does not generate a struct definition +class NoDeclStruct(ctypes.Structure): + """A ctypes.Structure that does not generate a struct definition""" + pass class c_complex(NoDeclStruct): - # Structure for passing complex float to C/C++ - _fields_ = [('real', ct.c_float), ('imag', ct.c_float)] + """Structure for passing complex float to C/C++""" + + _fields_ = [('real', ctypes.c_float), ('imag', ctypes.c_float)] @classmethod def from_param(cls, val): @@ -33,24 +35,29 @@ def from_param(cls, val): class c_double_complex(NoDeclStruct): - # Structure for passing complex double to C/C++ - _fields_ = [('real', ct.c_double), ('imag', ct.c_double)] + """Structure for passing complex double to C/C++""" + + _fields_ = [('real', ctypes.c_double), 
('imag', ctypes.c_double)] @classmethod def from_param(cls, val): return cls(val.real, val.imag) -class c_half(ct.c_uint16): - # Ctype for non-scalar half floats +class c_half(ctypes.c_uint16): + """Ctype for non-scalar half floats""" + @classmethod def from_param(cls, val): return cls(np.float16(val).view(np.uint16)) -class c_half_p(ct.POINTER(c_half)): - # Ctype for half scalars; we can't directly pass _Float16 values so - # we use a pointer and dereference (see `passes.iet.dtypes`) +class c_half_p(ctypes.POINTER(c_half)): + """ + Ctype for half scalars; we can't directly pass _Float16 values so + we use a pointer and dereference (see `passes.iet.dtypes`) + """ + @classmethod def from_param(cls, val): arr = np.array(val, dtype=np.float16) From 41ee036f477bf4fc705afe55d695813f6fe4b0e0 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 16 Jul 2024 13:55:28 +0100 Subject: [PATCH 20/58] Remove dtypes lowering from IET layer --- devito/passes/iet/definitions.py | 10 +++--- devito/passes/iet/dtypes.py | 54 ++++---------------------------- 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 6dd6ec0d97..cbe15a9985 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,7 +13,7 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_dtypes +from devito.passes.iet.dtypes import include_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -466,8 +466,8 @@ def place_casts(self, iet, **kwargs): return iet, {} @iet_pass - def make_langtypes(self, iet): - iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) + def include_complex(self, iet): + iet, metadata = include_complex(iet, self.lang, self.compiler) return 
iet, metadata def process(self, graph): @@ -476,7 +476,7 @@ def process(self, graph): """ self.place_definitions(graph, globs=set()) self.place_casts(graph) - self.make_langtypes(graph) + self.include_complex(graph) class DeviceAwareDataManager(DataManager): @@ -618,7 +618,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) - self.make_langtypes(graph) + self.include_complex(graph) def make_zero_init(obj, rcompile, sregistry): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index a2d0899224..789f49d5b4 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -1,63 +1,21 @@ import numpy as np import ctypes -from devito.ir import FindSymbols, Uxreplace -from devito.ir.iet.nodes import Dereference -from devito.tools.utils import as_list -from devito.types.basic import Symbol +from devito.ir import FindSymbols -__all__ = ['lower_dtypes'] +__all__ = ['include_complex'] -def lower_dtypes(iet, lang, compiler, sregistry): +def include_complex(iet, lang, compiler): """ - Lower language-specific dtypes and add headers for complex arithmetic - """ - # Include complex headers if needed (before we replace complex dtypes) - metadata = _complex_includes(iet, lang, compiler) - - body_prefix = [] # Derefs to prepend to the body - body_mapper = {} - params_mapper = {} - - # Lower scalar float16s to pointers and dereference them - if lang.get('half_types') is not None: - half, half_p = lang['half_types'] # dtype mappings for half float - - for s in FindSymbols('scalars').visit(iet): - if s.dtype != np.float16 or s not in iet.parameters: - continue - - ptr = s._rebuild(dtype=half_p, is_const=True) - val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, - is_const=s.is_const) - - params_mapper[s], body_mapper[s] = ptr, val - body_prefix.append(Dereference(val, ptr)) # val = *ptr - - # Lower remaining language-specific dtypes - for s in 
FindSymbols('indexeds|basics|symbolics').visit(iet): - if s.dtype in lang['types'] and s not in params_mapper: - body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - # Apply the dtype replacements - body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) - params = Uxreplace(params_mapper).visit(iet.parameters) - - iet = iet._rebuild(body=body, parameters=params) - return iet, metadata - - -def _complex_includes(iet, lang, compiler): - """ - Add headers for complex arithmetic + Include complex arithmetic headers for the given language, if needed. """ # Check if there is complex numbers that always take dtype precedence types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} if not any(np.issubdtype(d, np.complexfloating) for d in types): - return {} + return iet, {} metadata = {} lib = (lang['header-complex'],) @@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler): metadata['includes'] = lib - return metadata + return iet, metadata From 4602d0a4506c1ab569389e3931cc02ca1fa570c4 Mon Sep 17 00:00:00 2001 From: enwask Date: Fri, 26 Jul 2024 16:53:11 +0100 Subject: [PATCH 21/58] compiler: reimplement float16/complex lowering --- devito/operator/operator.py | 10 +++++++ devito/passes/iet/definitions.py | 10 +++---- devito/passes/iet/dtypes.py | 44 +++++++++++++++++++++++++++-- devito/passes/iet/languages/C.py | 30 ++++---------------- devito/passes/iet/languages/CXX.py | 31 ++++---------------- devito/symbolics/extended_dtypes.py | 12 +++++++- 6 files changed, 78 insertions(+), 59 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index ef77373f8c..5fbf38c05e 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -27,10 +27,12 @@ from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) +from devito.passes.iet.langbase import LangBB 
from devito.symbolics import estimate_cost, subs_op_args from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, timed_region, contains_val) +from devito.tools.dtypes_lowering import ctypes_vector_mapper from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) from devito.types.dimension import Thickness @@ -270,6 +272,9 @@ def _lower(cls, expressions, **kwargs): # expression for which a partial or complete lowering is desired kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) + # Load language-specific types into the global dtype->ctype mapper + cls._load_dtype_mappings(**kwargs) + # [Eq] -> [LoweredEq] expressions = cls._lower_exprs(expressions, **kwargs) @@ -291,6 +296,11 @@ def _lower(cls, expressions, **kwargs): def _rcompile_wrapper(cls, **kwargs0): raise NotImplementedError + @classmethod + def _load_dtype_mappings(cls, **kwargs): + lang: type[LangBB] = cls._Target.DataManager.lang + ctypes_vector_mapper.update(lang.mapper.get('types', {})) + @classmethod def _initialize_state(cls, **kwargs): return {} diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index cbe15a9985..9ba83cfe31 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,7 +13,7 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import include_complex +from devito.passes.iet.dtypes import lower_dtypes from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -466,8 +466,8 @@ def place_casts(self, iet, **kwargs): return iet, {} @iet_pass - def include_complex(self, iet): - iet, metadata = include_complex(iet, self.lang, self.compiler) + def lower_dtypes(self, iet): + iet, metadata = 
lower_dtypes(iet, self.lang, self.compiler, self.sregistry) return iet, metadata def process(self, graph): @@ -476,7 +476,7 @@ def process(self, graph): """ self.place_definitions(graph, globs=set()) self.place_casts(graph) - self.include_complex(graph) + self.lower_dtypes(graph) class DeviceAwareDataManager(DataManager): @@ -618,7 +618,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) - self.include_complex(graph) + self.lower_dtypes(graph) def make_zero_init(obj, rcompile, sregistry): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 789f49d5b4..f4f73e7663 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,11 +2,51 @@ import ctypes from devito.ir import FindSymbols +from devito.ir.iet.nodes import Dereference +from devito.ir.iet.visitors import Uxreplace +from devito.symbolics.extended_dtypes import Float16P +from devito.tools.utils import as_list +from devito.types.basic import Symbol -__all__ = ['include_complex'] +__all__ = ['lower_dtypes'] -def include_complex(iet, lang, compiler): +def lower_dtypes(iet, lang, compiler, sregistry): + """ + Lowers float16 scalar types to pointers since we can't directly pass their + value. Also includes headers for complex arithmetic if needed. 
+ """ + + iet, metadata = _complex_includes(iet, lang, compiler) + + # Lower float16 parameters to pointers and dereference + body_prefix = [] + body_mapper = {} + params_mapper = {} + + # Lower scalar float16s to pointers and dereference them + for s in FindSymbols('scalars').visit(iet): + if not np.issubdtype(s.dtype, np.float16) or s not in iet.parameters: + continue + + # Replace the parameter with a pointer; replace occurences in the IET + # body with a dereference (using the original symbol's dtype) + ptr = s._rebuild(dtype=Float16P, is_const=True) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, + is_const=s.is_const) + + params_mapper[s], body_mapper[s] = ptr, val + body_prefix.append(Dereference(val, ptr)) # val = *ptr + + # Apply the replacements + body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) + params = Uxreplace(params_mapper).visit(iet.parameters) + + iet = iet._rebuild(body=body, parameters=params) + return iet, metadata + + +def _complex_includes(iet, lang, compiler): """ Include complex arithmetic headers for the given language, if needed. 
""" diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 6112c3e895..069aa10320 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -4,41 +4,19 @@ from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, +from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex, c_half, c_half_p) -from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] -class CCFloat(np.complex64): - pass - - -class CCDouble(np.complex128): - pass - - -class CHalf(np.float16): - pass - - -class CHalfP(np.float16): - pass - - c99_complex = type('_Complex float', (c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), {}) c_float16 = type('_Float16', (c_half,), {}) c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) -ctypes_vector_mapper[CCFloat] = c99_complex -ctypes_vector_mapper[CCDouble] = c99_double_complex -ctypes_vector_mapper[CHalf] = c_float16 -ctypes_vector_mapper[CHalfP] = c_float16_p - class CBB(LangBB): @@ -56,8 +34,10 @@ class CBB(LangBB): Call('memcpy', (i, j, k)), # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf}, - 'half_types': (CHalf, CHalfP), + 'types': {np.complex128: c99_double_complex, + np.complex64: c99_complex, + np.float16: c_float16, + Float16P: c_float16_p} } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 48fdb4471b..1174a27f8d 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -3,8 +3,7 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C 
import c_float16, c_float16_p -from devito.symbolics.extended_dtypes import c_complex, c_double_complex -from devito.tools.dtypes_lowering import ctypes_vector_mapper +from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex __all__ = ['CXXBB'] @@ -46,30 +45,9 @@ """ -class CXXCFloat(np.complex64): - pass - - -class CXXCDouble(np.complex128): - pass - - -class CXXHalf(np.float16): - pass - - -class CXXHalfP(np.float16): - pass - - cxx_complex = type('std::complex', (c_complex,), {}) cxx_double_complex = type('std::complex', (c_double_complex,), {}) -ctypes_vector_mapper[CXXCFloat] = cxx_complex -ctypes_vector_mapper[CXXCDouble] = cxx_double_complex -ctypes_vector_mapper[CXXHalf] = c_float16 -ctypes_vector_mapper[CXXHalfP] = c_float16_p - class CXXBB(LangBB): @@ -89,7 +67,8 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, - np.float16: CXXHalf}, - 'half_types': (CXXHalf, CXXHalfP), + "types": {np.complex128: cxx_double_complex, + np.complex64: cxx_complex, + np.float16: c_float16, + Float16P: c_float16_p} } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 85256a3f94..0b8b1bcad1 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,7 +7,7 @@ __all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_half', 'c_half_p'] + 'c_half', 'c_half_p', 'Float16P'] limits_mapper = { @@ -64,6 +64,16 @@ def from_param(cls, val): return arr.ctypes.data_as(cls) +class Float16P(np.float16): + """ + Dummy dtype for a scalar float16 value that's been mapped to a pointer. + This is needed because we can't directly pass in the values; we map to + pointers and dereference in the kernel; see `passes.iet.dtypes`. 
+ """ + + pass + + class CustomType(ReservedWord): pass From 92b80e6ecf7e9e618d31d4d8e21fcf4fb4ed2172 Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 29 Jul 2024 12:43:36 +0100 Subject: [PATCH 22/58] misc: cleanup, docs and typing for half support --- devito/ir/iet/nodes.py | 14 ++++++-------- devito/ir/iet/visitors.py | 5 +---- devito/passes/iet/dtypes.py | 29 ++++++++++++++++------------- devito/symbolics/extended_dtypes.py | 8 +++++++- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 268b539612..4ffdb39773 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1064,15 +1064,13 @@ def expr_symbols(self): assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ "Scalar dereference must have a pointer ctype" ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) + elif self.pointer.is_PointerArray or self.pointer.is_TempFunction: + ret.extend([self.pointer.indexed, self.pointee.indexed]) + ret.extend(flatten(i.free_symbols + for i in self.pointee.symbolic_shape[1:])) + ret.extend(self.pointer.free_symbols) else: - ret.append(self.pointer.indexed) - if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols - for i in self.pointee.symbolic_shape[1:])) - ret.extend(self.pointer.free_symbols) - else: - ret.append(self.pointee._C_symbol) + ret.extend([self.pointer.indexed, self.pointee._C_symbol]) return tuple(filter_ordered(ret)) @property diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 4bae7af755..0f1466d6e2 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -454,11 +454,8 @@ def visit_Dereference(self, o): lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) - elif a1.is_Symbol: - rvalue = '*%s' % a1.name - lvalue = self._gen_value(a0, 0) else: - rvalue = '%s->%s' % (a1.name, 
a0._C_name) + rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name) lvalue = self._gen_value(a0, 0) return c.Initializer(lvalue, rvalue) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index f4f73e7663..03093c18a1 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -1,17 +1,18 @@ -import numpy as np import ctypes +import numpy as np -from devito.ir import FindSymbols -from devito.ir.iet.nodes import Dereference -from devito.ir.iet.visitors import Uxreplace +from devito.arch.compiler import Compiler +from devito.ir import Callable, Dereference, FindSymbols, SymbolRegistry, Uxreplace +from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import Float16P -from devito.tools.utils import as_list -from devito.types.basic import Symbol +from devito.tools import as_list +from devito.types import Symbol __all__ = ['lower_dtypes'] -def lower_dtypes(iet, lang, compiler, sregistry): +def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, + sregistry: SymbolRegistry) -> tuple[Callable, dict]: """ Lowers float16 scalar types to pointers since we can't directly pass their value. Also includes headers for complex arithmetic if needed. 
@@ -20,13 +21,14 @@ def lower_dtypes(iet, lang, compiler, sregistry): iet, metadata = _complex_includes(iet, lang, compiler) # Lower float16 parameters to pointers and dereference - body_prefix = [] + prefix = [] body_mapper = {} params_mapper = {} # Lower scalar float16s to pointers and dereference them + params = set(iet.parameters) for s in FindSymbols('scalars').visit(iet): - if not np.issubdtype(s.dtype, np.float16) or s not in iet.parameters: + if s.dtype != np.float16 or s not in params: continue # Replace the parameter with a pointer; replace occurences in the IET @@ -36,17 +38,18 @@ def lower_dtypes(iet, lang, compiler, sregistry): is_const=s.is_const) params_mapper[s], body_mapper[s] = ptr, val - body_prefix.append(Dereference(val, ptr)) # val = *ptr + prefix.append(Dereference(val, ptr)) # val = *ptr # Apply the replacements - body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) + prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) params = Uxreplace(params_mapper).visit(iet.parameters) - iet = iet._rebuild(body=body, parameters=params) + iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata -def _complex_includes(iet, lang, compiler): +def _complex_includes(iet: Callable, lang: type[LangBB], + compiler: Compiler) -> tuple[Callable, dict]: """ Include complex arithmetic headers for the given language, if needed. """ diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0b8b1bcad1..af2da5d353 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -19,7 +19,13 @@ class NoDeclStruct(ctypes.Structure): - """A ctypes.Structure that does not generate a struct definition""" + """ + A ctypes.Structure that does not generate a struct definition. + + Some foreign types (e.g. 
complex) need to be passed to C/C++ as a struct + that mimics an existing type, but the struct types themselves don't show + up in the kernel, so we don't need to generate their definitions. + """ pass From 6d256f97c616c80e133623fbf3642ca2574fa8bc Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 29 Jul 2024 12:46:33 +0100 Subject: [PATCH 23/58] compiler: FindSymbols 'scalars' -> 'abstractsymbols' --- devito/ir/iet/visitors.py | 5 +++-- devito/passes/iet/dtypes.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 0f1466d6e2..0432a6e77a 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -958,7 +958,7 @@ def default_retval(cls): Drive the search. Accepted: - `symbolics`: Collect all AbstractFunction objects, default - `basics`: Collect all Basic objects - - `scalars`: Collect all AbstractSymbol objects + - `abstractsymbols`: Collect all AbstractSymbol objects - `dimensions`: Collect all Dimensions - `indexeds`: Collect all Indexed objects - `indexedbases`: Collect all IndexedBase objects @@ -979,7 +979,8 @@ def _defines_aliases(n): rules = { 'symbolics': lambda n: n.functions, 'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)], - 'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)], + 'abstractsymbols': lambda n: [i for i in n.expr_symbols + if isinstance(i, AbstractSymbol)], 'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)], 'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed], 'indexedbases': lambda n: [i for i in n.expr_symbols diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 03093c18a1..216f989ad6 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -27,7 +27,7 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, # Lower scalar float16s to pointers and dereference them params = set(iet.parameters) - 
for s in FindSymbols('scalars').visit(iet): + for s in FindSymbols('abstractsymbols').visit(iet): if s.dtype != np.float16 or s not in params: continue From 2588ea6887ee7c27677301e093f6de852bed1d97 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:22:57 +0100 Subject: [PATCH 24/58] test: include scalar parameters in complex tests --- tests/test_gpu_common.py | 10 ++++++---- tests/test_operator.py | 8 +++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 4a11c12556..e2321036e0 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -81,18 +81,20 @@ def test_maxpar_option(self): def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) u = Function(name="u", grid=grid, dtype=dtype) - eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing) * c) op = Operator(eq) - op() + op(c=1.0 + 2.0j) # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) - npres = xx + 1j*yy + np.exp(1j + dx) + npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j) - assert np.allclose(u.data, npres.T, rtol=1e-6, atol=0) + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) class TestPassesOptional: diff --git a/tests/test_operator.py b/tests/test_operator.py index 283249aac1..4282165314 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -644,16 +644,18 @@ def test_tensor(self, func1): def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) u = Function(name="u", grid=grid, dtype=dtype) - eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing) * c) op = Operator(eq) - op() + op(c=1.0 + 2.0j) # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), 
np.linspace(0, 4, 5)) - npres = xx + 1j*yy + np.exp(1j + dx) + npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 1b23e920ee3e880d9357f17b494c9f4dbd3035a9 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:23:49 +0100 Subject: [PATCH 25/58] test: add test_dtypes with initial tests for float16 + complex --- tests/test_dtypes.py | 202 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tests/test_dtypes.py diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 0000000000..32f87449c7 --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,202 @@ +import numpy as np +import pytest +import sympy + +from devito import Constant, Eq, Function, Grid, Operator +from devito.ir.iet.nodes import Dereference +from devito.passes.iet.langbase import LangBB +from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.openacc import AccBB +from devito.passes.iet.languages.openmp import OmpBB +from devito.symbolics.extended_dtypes import Float16P +from devito.types.basic import Basic, Scalar, Symbol +from devito.types.dimension import Dimension, Spacing + +# Mappers for language-specific types and headers +_languages: dict[str, type[LangBB]] = { + 'C': CBB, + 'openmp': OmpBB, + 'openacc': AccBB +} + + +def _get_language(language: str, **_) -> type[LangBB]: + """ + Gets the language building block type from parametrized kwargs. + """ + + return _languages[language] + + +def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str]: + """ + Generates kwargs for Operator to test language-specific behavior. 
+ """ + + return { + 'platform': platform, + 'language': language, + 'compiler': compiler + } + + +# List of pararmetrized operator kwargs for testing language-specific behavior +_configs: list[dict[str, str]] = [ + _config_kwargs(*cfg) for cfg in [ + ('cpu64', 'C', 'gcc'), + ('cpu64', 'openmp', 'gcc'), + ('nvidiaX', 'openmp', 'nvc'), + ('nvidiaX', 'openacc', 'nvc') + ] +] + + +@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: + """ + Tests that half and complex floats' dtypes result in the correct type + strings in generated code. + """ + + # Retrieve the language-specific type mapping + lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') + + # Set up an operator + grid = Grid(shape=(3, 3)) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, c * x * y) + op = Operator(eq, **kwargs) + + # Check ctypes of the mapped parameters + params: dict[str, Basic] = {p.name: p for p in op.parameters} + _u: Function = params['u'] + _c: Constant = params['c'] + assert _u.indexed._C_ctype._type_ == lang_types[_u.dtype] + assert _c._C_ctype == lang_types[_c.dtype] + + +def test_half_params() -> None: + """ + Tests float16 input parameters: scalars should be lowered to pointers + and dereferenced; other parameters should keep the original dtype. 
+ """ + + grid = Grid(shape=(5, 5), dtype=np.float16) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=np.float16) + u = Function(name='u', grid=grid) + eq = Eq(u, x * x.spacing + c * y * y.spacing) + op = Operator(eq) + + # Check that lowered parameters have the correct dtypes + params: dict[str, Basic] = {p.name: p for p in op.parameters} + _u: Function = params['u'] + _c: Constant = params['c'] + _dx: Spacing = params['h_x'] + _dy: Spacing = params['h_y'] + + assert _u.dtype == np.float16 + assert _c.dtype == Float16P + assert _dx.dtype == Float16P + assert _dy.dtype == Float16P + + # Ensure the mapped pointer-to-half symbols are dereferenced + derefs: set[Symbol] = {n.pointer for n in op.body.body + if isinstance(n, Dereference)} + assert _c in derefs + assert _dx in derefs + assert _dy in derefs + + +@pytest.mark.parametrize('dtype', [np.float16, np.float32, + np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: + """ + Tests that the correct complex headers are included when complex dtypes + are present in the operator, and omitted otherwise. + """ + + # Set up an operator + grid = Grid(shape=(3, 3)) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, c * x * y) + op = Operator(eq, **kwargs) + + # Check that the complex header is included <=> complex dtypes are present + header: str = _get_language(**kwargs).get('header-complex') + if np.issubdtype(dtype, np.complexfloating): + assert header in op._includes + else: + assert header not in op._includes + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: + """ + Tests that the correct literal is used for the imaginary unit. 
+ """ + + # Determine the expected imaginary unit string + unit_str: str + if kwargs['compiler'] == 'gcc': + # In C we multiply by the _Complex_I macro constant + unit_str = '_Complex_I' + else: + # C++ provides an imaginary literal + unit_str = '1if' if dtype == np.complex64 else '1i' + + # Set up an operator + s = Symbol(name='s', dtype=dtype) + eq = Eq(s, 2.0 + 3.0j) + op = Operator(eq, **kwargs) + + # Check that the correct imaginary unit is used + assert unit_str in str(op) + + +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, + np.complex64, np.complex128]) +@pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), + (sympy.log, np.log), + (sympy.sin, np.sin)]) +def test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> None: + """ + Tests that the correct math functions are used, and their results cast + and assigned appropriately for different float precisions and for + complex floats/doubles. + """ + + # Get the expected function call string + call_str = str(sym) + if np.issubdtype(dtype, np.complexfloating): + # Complex functions have a 'c' prefix + call_str = 'c%s' % call_str + if dtype(0).real.itemsize <= 4: + # Single precision have an 'f' suffix (half is promoted to single) + call_str = '%sf' % call_str + + # Operator setup + a = Symbol(name='a', dtype=dtype) + b = Scalar(name='b', dtype=dtype) + + eq = Eq(a, sym(b)) + op = Operator(eq) + + # Ensure the generated function call has the correct form + assert call_str in str(op) From 5396432a6737348d5f24119445fb9131c1422bce Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:55:41 +0100 Subject: [PATCH 26/58] misc: more lower_dtypes cleanup + type hints --- devito/passes/iet/dtypes.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 216f989ad6..6f698a659d 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ 
-2,11 +2,11 @@ import numpy as np from devito.arch.compiler import Compiler -from devito.ir import Callable, Dereference, FindSymbols, SymbolRegistry, Uxreplace +from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import Float16P from devito.tools import as_list -from devito.types import Symbol +from devito.types.basic import AbstractSymbol, Basic, Symbol __all__ = ['lower_dtypes'] @@ -21,19 +21,19 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, iet, metadata = _complex_includes(iet, lang, compiler) # Lower float16 parameters to pointers and dereference - prefix = [] - body_mapper = {} - params_mapper = {} + prefix: list[Node] = [] + params_mapper: dict[AbstractSymbol, AbstractSymbol] = {} + body_mapper: dict[AbstractSymbol, Symbol] = {} - # Lower scalar float16s to pointers and dereference them - params = set(iet.parameters) + params_set = set(iet.parameters) + s: AbstractSymbol for s in FindSymbols('abstractsymbols').visit(iet): - if s.dtype != np.float16 or s not in params: + if s.dtype != np.float16 or s not in params_set: continue # Replace the parameter with a pointer; replace occurences in the IET - # body with a dereference (using the original symbol's dtype) - ptr = s._rebuild(dtype=Float16P, is_const=True) + # body with dereferenced symbol (using the original symbol's dtype) + ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True) val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, is_const=s.is_const) @@ -42,7 +42,7 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, # Apply the replacements prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) - params = Uxreplace(params_mapper).visit(iet.parameters) + params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters) iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata @@ -51,9 
+51,10 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler) -> tuple[Callable, dict]: """ - Include complex arithmetic headers for the given language, if needed. + Includes complex arithmetic headers for the given language, if needed. """ - # Check if there is complex numbers that always take dtype precedence + + # Check if there are complex numbers that always take dtype precedence types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} From 0bf20326b4422bfa6eee798f1796b65162c73afd Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 15:33:59 +0100 Subject: [PATCH 27/58] api: use grid dtype for extent and origin, add test_grid --- devito/types/grid.py | 11 +++++++---- tests/test_grid.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 tests/test_grid.py diff --git a/devito/types/grid.py b/devito/types/grid.py index 6ee7a7eee3..0904a72fe9 100644 --- a/devito/types/grid.py +++ b/devito/types/grid.py @@ -69,9 +69,10 @@ class Grid(CartesianDiscretization, ArgProvider): ---------- shape : tuple of ints Shape of the computational domain in grid points. - extent : tuple of floats, default=unit box of extent 1m in all dimensions + extent : tuple of values interpretable as dtype, default=unit box of extent 1m + in all dimensions Physical extent of the domain in m. - origin : tuple of floats, default=0.0 in all dimensions + origin : tuple of values interpretable as dtype, default=0.0 in all dimensions Physical coordinate of the origin of the domain. dimensions : tuple of SpaceDimension, optional The dimensions of the computational domain encapsulated by this Grid. 
@@ -178,7 +179,8 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None, self._distributor = Distributor(shape, dimensions, comm, self._topology) # The physical extent - self._extent = as_tuple(extent or tuple(1. for _ in self.shape)) + extent = as_tuple(extent or tuple(1. for _ in self.shape)) + self._extent = tuple(dtype(e) for e in extent) # Initialize SubDomains subdomains = tuple(i for i in (Domain(), Interior(), *as_tuple(subdomains))) @@ -186,7 +188,8 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None, i.__subdomain_finalize__(self) self._subdomains = subdomains - self._origin = as_tuple(origin or tuple(0. for _ in self.shape)) + origin = as_tuple(origin or tuple(0. for _ in self.shape)) + self._origin = tuple(dtype(o) for o in origin) self._origin_symbols = tuple(Scalar(name='o_%s' % d.name, dtype=dtype, is_const=True) for d in self.dimensions) diff --git a/tests/test_grid.py b/tests/test_grid.py new file mode 100644 index 0000000000..5753e17d7b --- /dev/null +++ b/tests/test_grid.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from devito import Grid + + +# Unsigned ints are unreasonable but not necessarily invalid +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, np.longdouble, + np.complex64, np.complex128, np.int8, np.int16, + np.int32, np.int64, np.uint8, np.uint16, np.uint32, + np.uint64]) +def test_extent_dtypes(dtype: np.dtype[np.number]) -> None: + """ + Test that grid spacings are correctly computed for different dtypes. 
+ """ + + # Construct a grid with the dtype and retrieve the spacing values + extent = (1, 1j) if np.issubdtype(dtype, np.complexfloating) else (2, 4) + grid = Grid(shape=(5, 5), extent=extent, dtype=dtype) + dx, dy = grid.spacing_map.values() + + # Check that the spacings have the correct dtype + assert dx.dtype == dy.dtype == dtype + + # Check that the spacings have the correct values + assert dx == dtype(extent[0] / 4) + assert dy == dtype(extent[1] / 4) From 442a804af833e47fa4ec8bc5014162caad20763a Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 15:57:07 +0100 Subject: [PATCH 28/58] test: clean up and add more half/complex tests --- devito/symbolics/extended_dtypes.py | 4 +- tests/test_dtypes.py | 164 ++++++++++++++++++++++++---- 2 files changed, 147 insertions(+), 21 deletions(-) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index af2da5d353..8c90e48986 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -72,9 +72,9 @@ def from_param(cls, val): class Float16P(np.float16): """ - Dummy dtype for a scalar float16 value that's been mapped to a pointer. + Dummy dtype for a scalar half value that has been mapped to a pointer. This is needed because we can't directly pass in the values; we map to - pointers and dereference in the kernel; see `passes.iet.dtypes`. + pointers and dereference in the kernel. See `passes.iet.dtypes`. 
""" pass diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 32f87449c7..a330096b36 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import re import sympy from devito import Constant, Eq, Function, Grid, Operator @@ -9,8 +10,9 @@ from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import Float16P +from devito.tools import ctypes_to_cstr from devito.types.basic import Basic, Scalar, Symbol -from devito.types.dimension import Dimension, Spacing +from devito.types.dense import TimeFunction # Mappers for language-specific types and headers _languages: dict[str, type[LangBB]] = { @@ -45,7 +47,6 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str _config_kwargs(*cfg) for cfg in [ ('cpu64', 'C', 'gcc'), ('cpu64', 'openmp', 'gcc'), - ('nvidiaX', 'openmp', 'nvc'), ('nvidiaX', 'openacc', 'nvc') ] ] @@ -53,7 +54,7 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str @pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: +def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ Tests that half and complex floats' dtypes result in the correct type strings in generated code. 
@@ -64,8 +65,6 @@ def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Set up an operator grid = Grid(shape=(3, 3)) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=dtype) @@ -75,12 +74,38 @@ def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Check ctypes of the mapped parameters params: dict[str, Basic] = {p.name: p for p in op.parameters} - _u: Function = params['u'] - _c: Constant = params['c'] + _u, _c = params['u'], params['c'] assert _u.indexed._C_ctype._type_ == lang_types[_u.dtype] assert _c._C_ctype == lang_types[_c.dtype] +@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: + """ + Tests that variables introduced by CSE have the correct type strings in + the generated code. + """ + + # Retrieve the language-specific type mapping + lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') + + # Set up an operator + grid = Grid(shape=(3, 3)) + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + # sin(c) should be CSE'd + eq = Eq(u, x * x.spacing + y * y.spacing * sympy.sin(c)) + op = Operator(eq, **kwargs) + + # Ensure the CSE'd variable has the correct type + match = re.search(r'[^\S\n\r]*(.*\S)\sr0 = ', str(op)) + assert match is not None + assert match.group(1) == ctypes_to_cstr(lang_types[dtype]) + + def test_half_params() -> None: """ Tests float16 input parameters: scalars should be lowered to pointers @@ -88,8 +113,6 @@ def test_half_params() -> None: """ grid = Grid(shape=(5, 5), dtype=np.float16) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=np.float16) @@ -99,10 +122,7 @@ def test_half_params() -> None: # Check that lowered parameters have the correct dtypes params: dict[str, Basic] = {p.name: p for p in 
op.parameters} - _u: Function = params['u'] - _c: Constant = params['c'] - _dx: Spacing = params['h_x'] - _dy: Spacing = params['h_y'] + _u, _c, _dx, _dy = params['u'], params['c'], params['h_x'], params['h_y'] assert _u.dtype == np.float16 assert _c.dtype == Float16P @@ -120,7 +140,8 @@ def test_half_params() -> None: @pytest.mark.parametrize('dtype', [np.float16, np.float32, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: +def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: + np.dtype """ Tests that the correct complex headers are included when complex dtypes are present in the operator, and omitted otherwise. @@ -128,8 +149,6 @@ def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Set up an operator grid = Grid(shape=(3, 3)) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=dtype) @@ -158,8 +177,11 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: # In C we multiply by the _Complex_I macro constant unit_str = '_Complex_I' else: - # C++ provides an imaginary literal - unit_str = '1if' if dtype == np.complex64 else '1i' + # C++ provides imaginary literals + if dtype == np.complex64: + unit_str = '1if' + else: + unit_str = '1i' # Set up an operator s = Symbol(name='s', dtype=dtype) @@ -175,7 +197,8 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: @pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), (sympy.log, np.log), (sympy.sin, np.sin)]) -def test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> None: +def test_math_functions(dtype: np.dtype[np.inexact], + sym: sympy.Function, fun: np.ufunc) -> None: """ Tests that the correct math functions are used, and their results cast and assigned appropriately for different float precisions and for @@ -200,3 +223,106 @@ def 
test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> # Ensure the generated function call has the correct form assert call_str in str(op) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests overriding complex values in op.apply(). + """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype, value=1.0 + 0.0j) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, x * x.spacing + c * y * y.spacing) + op = Operator(eq) + op.apply(c=2.0 + 1.0j) + + # Check against numpy result + dx, dy = grid.spacing_map.values() + xx, yy = np.meshgrid(np.linspace(0, 4, 5, dtype=dtype), + np.linspace(0, 4, 5, dtype=dtype)) + expected = xx * dx + yy * dy * dtype(2.0 + 1.0j) + assert np.allclose(u.data.T, expected) + + +def test_half_time_deriv() -> None: + """ + Tests taking the time derivative of a float16 function. + """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + t = grid.time_dim + + f = TimeFunction(name='f', grid=grid, space_order=2, dtype=np.float16) + g = Function(name='g', grid=grid, dtype=np.float16) + eqns = [Eq(f.forward, t * x * x.spacing + + y * y.spacing), + Eq(g, f.dt)] + op = Operator(eqns) + op.apply(time=10, dt=1.0) + + # Check against expected result + dx = grid.spacing_map[x.spacing] + xx = np.repeat(np.linspace(0, 4, 5, dtype=np.float16)[np.newaxis, :], 5, axis=0) + expected = xx * np.float16(dx) + assert np.allclose(g.data.T, expected) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests taking the time derivative of a complex-valued function. 
+ """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + t = grid.time_dim + + f = TimeFunction(name='f', grid=grid, space_order=2, dtype=dtype) + g = Function(name='g', grid=grid, dtype=dtype) + eqns = [Eq(f.forward, t * x * x.spacing * (1.0 + 0.0j) + + t * y * y.spacing * (0.0 + 1.0j)), + Eq(g, f.dt)] + op = Operator(eqns) + op.apply(time=10, dt=1.0) + + # Check against expected result + dx, dy = grid.spacing_map.values() + xx, yy = np.meshgrid(np.linspace(0, 4, 5, dtype=dtype), + np.linspace(0, 4, 5, dtype=dtype)) + expected = xx * dx + yy * dy * dtype(0.0 + 1.0j) + assert np.allclose(g.data.T, expected) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests taking the space derivative of a complex-valued function, with + respect to the real and imaginary axes. + """ + + grid = Grid(shape=(7, 7), dtype=dtype) + x, y = grid.dimensions + + # Operator setup + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + eqns = [Eq(f, x * x.spacing + y * y.spacing), + Eq(g, f.dx, subdomain=grid.interior), + Eq(h, f.dy, subdomain=grid.interior)] + op = Operator(eqns) + + dx = 1.0 + 0.0j + dy = 0.0 + 1.0j + op.apply(h_x=dx, h_y=dy) + + # Check against expected result (1 within the interior) + dfdx = g.data.T[1:-1, 1:-1] + dfdy = h.data.T[1:-1, 1:-1] + assert np.allclose(dfdx, np.ones((5, 5), dtype=dtype)) + assert np.allclose(dfdy, np.ones((5, 5), dtype=dtype)) From 249f1cda8b5e87a66120ba37506eff5f7c63f084 Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 18:47:55 +0100 Subject: [PATCH 29/58] test: fix test_grid_objs, add test_grid_dtypes --- tests/test_caching.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_caching.py b/tests/test_caching.py index 93200d3d73..93b046372d 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -453,8 
+453,16 @@ def test_grid_objs(self): assert y0 is y1 assert x0.spacing is x1.spacing assert y0.spacing is y1.spacing - assert ox0 is ox1 - assert oy0 is oy1 + + def test_grid_dtypes(self): + """ + Test that two grids with different dtypes have different hash values. + """ + + grid0 = Grid(shape=(4, 4), dtype=np.float32) + grid1 = Grid(shape=(4, 4), dtype=np.float64) + + assert hash(grid0) != hash(grid1) def test_special_symbols(self): """ From 232ddf7f910ed905b81750e92b867e8ea0cde82d Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 13 Aug 2024 10:43:14 -0400 Subject: [PATCH 30/58] api: allow side for cross derivatives, fixes #2442 --- devito/finite_differences/differentiable.py | 1 + devito/finite_differences/finite_difference.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 922acf0f84..b8fd9055f5 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -19,6 +19,7 @@ from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) from devito.types import Array, DimensionTuple, Evaluable, StencilDimension +from devito.types.basic import AbstractFunction __all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', 'Weights'] diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index c534232093..09b380b7c0 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -143,6 +143,11 @@ def first_derivative(expr, dim, fd_order, **kwargs): return generic_derivative(expr, dim, fd_order, 1, **kwargs) +# Backward compatibility +def first_derivative(expr, dim, fd_order, **kwargs): + return generic_derivative(expr, dim, fd_order, 1, **kwargs) + + def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients, expand, 
weights=None): # Always expand time derivatives to avoid issue with buffering and streaming. From 3008fd8fdeace2244d0a9dbebf14871bb09959ab Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 15 Jan 2025 16:04:01 -0500 Subject: [PATCH 31/58] compiler: process dtypes through printer --- .../finite_differences/finite_difference.py | 5 -- devito/ir/iet/visitors.py | 63 ++++++++++-------- devito/operator/operator.py | 20 +++--- devito/passes/iet/dtypes.py | 32 +-------- devito/passes/iet/languages/C.py | 29 ++++---- devito/passes/iet/languages/CXX.py | 26 ++++---- devito/passes/iet/languages/openacc.py | 7 +- devito/passes/iet/languages/targets.py | 9 ++- devito/symbolics/extended_dtypes.py | 37 ++--------- devito/symbolics/extended_sympy.py | 15 ++++- devito/symbolics/printer.py | 27 ++++++-- devito/types/basic.py | 4 +- tests/test_dtypes.py | 66 ++----------------- 13 files changed, 130 insertions(+), 210 deletions(-) diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 09b380b7c0..c534232093 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -143,11 +143,6 @@ def first_derivative(expr, dim, fd_order, **kwargs): return generic_derivative(expr, dim, fd_order, 1, **kwargs) -# Backward compatibility -def first_derivative(expr, dim, fd_order, **kwargs): - return generic_derivative(expr, dim, fd_order, 1, **kwargs) - - def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients, expand, weights=None): # Always expand time derivatives to avoid issue with buffering and streaming. 
diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 0432a6e77a..622b1ed95b 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -19,9 +19,10 @@ Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, - ListInitializer, ccode, uxreplace) + ListInitializer, uxreplace) +from devito.symbolics.printer import _DevitoPrinterBase from devito.symbolics.extended_dtypes import NoDeclStruct -from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, +from devito.tools import (GenericVisitor, as_tuple, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) from devito.types.basic import AbstractFunction, AbstractSymbol, Basic @@ -176,9 +177,10 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, compiler=None, **kwargs): + def __init__(self, *args, compiler=None, printer=None, **kwargs): super().__init__(*args, **kwargs) self._compiler = compiler or configuration['compiler'] + self._printer = printer or _DevitoPrinterBase # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -194,6 +196,9 @@ def __init__(self, *args, compiler=None, **kwargs): def compiler(self): return self._compiler + def ccode(self, expr, **settings): + return self._printer(settings=settings).doprint(expr, None) + def visit(self, o, *args, **kwargs): # Make sure the visitor always is within the generating compiler # in case the configuration is accessed @@ -233,7 +238,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = ctypes_to_cstr(ct) + cstr = self.ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -255,10 +260,10 @@ def 
_gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = obj._C_typedata - strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) + strtype = self.ccode(obj._C_typedata) + strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) else: - strtype = ctypes_to_cstr(obj._C_ctype) + strtype = self.ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -272,7 +277,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [ccode(i) for i in obj.cargs] + arguments = [self.ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -286,9 +291,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, ccode(init)) + value = c.Initializer(value, self.ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, ccode(obj.initvalue)) + value = c.Initializer(value, self.ccode(obj.initvalue)) return value @@ -322,7 +327,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(ccode(i)) + ret.append(self.ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -388,7 +393,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = i._C_typedata + cstr = self.ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -410,7 +415,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) rshape 
= '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -443,9 +448,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = i._C_typedata + cstr = self.ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -484,8 +489,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) + rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -498,8 +503,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - c_rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) + c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -518,7 +523,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, (FieldFromComposite, FieldFromPointer)): - return c.Assign(ccode(retobj), call) + return c.Assign(self.ccode(retobj), call) else: return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -532,9 +537,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return 
c.If(ccode(o.condition), then_body, else_body) + return c.If(self.ccode(o.condition), then_body, else_body) else: - return c.If(ccode(o.condition), then_body) + return c.If(self.ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -544,23 +549,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, ccode(_max)) - loop_cond = '%s >= %s' % (o.index, ccode(_min)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) + loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, ccode(_min)) - loop_cond = '%s <= %s' % (o.index, ccode(_max)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) + loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -577,7 +582,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = ccode(o.condition) + condition = self.ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 5fbf38c05e..3023461827 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -27,12 +27,10 @@ from devito.passes import 
(Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) -from devito.passes.iet.langbase import LangBB from devito.symbolics import estimate_cost, subs_op_args from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, timed_region, contains_val) -from devito.tools.dtypes_lowering import ctypes_vector_mapper from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) from devito.types.dimension import Thickness @@ -272,9 +270,6 @@ def _lower(cls, expressions, **kwargs): # expression for which a partial or complete lowering is desired kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) - # Load language-specific types into the global dtype->ctype mapper - cls._load_dtype_mappings(**kwargs) - # [Eq] -> [LoweredEq] expressions = cls._lower_exprs(expressions, **kwargs) @@ -296,11 +291,6 @@ def _lower(cls, expressions, **kwargs): def _rcompile_wrapper(cls, **kwargs0): raise NotImplementedError - @classmethod - def _load_dtype_mappings(cls, **kwargs): - lang: type[LangBB] = cls._Target.DataManager.lang - ctypes_vector_mapper.update(lang.mapper.get('types', {})) - @classmethod def _initialize_state(cls, **kwargs): return {} @@ -764,13 +754,19 @@ def _soname(self): """A unique name for the shared object resulting from JIT compilation.""" return Signer._digest(self, configuration) + @property + def printer(self): + return self._Target.Printer + @cached_property def ccode(self): try: - return self._ccode_handler(compiler=self._compiler).visit(self) + return self._ccode_handler(compiler=self._compiler, + printer=self.printer).visit(self) except (AttributeError, TypeError): from devito.ir.iet.visitors import CGen - return CGen(compiler=self._compiler).visit(self) + return CGen(compiler=self._compiler, + printer=self.printer).visit(self) def _jit_compile(self): """ diff --git 
a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 6f698a659d..f1606a73ff 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,11 +2,8 @@ import numpy as np from devito.arch.compiler import Compiler -from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace +from devito.ir import Callable, FindSymbols, SymbolRegistry from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import Float16P -from devito.tools import as_list -from devito.types.basic import AbstractSymbol, Basic, Symbol __all__ = ['lower_dtypes'] @@ -17,34 +14,9 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, Lowers float16 scalar types to pointers since we can't directly pass their value. Also includes headers for complex arithmetic if needed. """ - + # Complex numbers iet, metadata = _complex_includes(iet, lang, compiler) - # Lower float16 parameters to pointers and dereference - prefix: list[Node] = [] - params_mapper: dict[AbstractSymbol, AbstractSymbol] = {} - body_mapper: dict[AbstractSymbol, Symbol] = {} - - params_set = set(iet.parameters) - s: AbstractSymbol - for s in FindSymbols('abstractsymbols').visit(iet): - if s.dtype != np.float16 or s not in params_set: - continue - - # Replace the parameter with a pointer; replace occurences in the IET - # body with dereferenced symbol (using the original symbol's dtype) - ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True) - val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, - is_const=s.is_const) - - params_mapper[s], body_mapper[s] = ptr, val - prefix.append(Dereference(val, ptr)) # val = *ptr - - # Apply the replacements - prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) - params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters) - - iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata diff --git a/devito/passes/iet/languages/C.py 
b/devito/passes/iet/languages/C.py index 069aa10320..7efdaa44ff 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,21 +1,11 @@ -import numpy as np - from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex, - c_half, c_half_p) - - -__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] - +from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.printer import _DevitoPrinterBase -c99_complex = type('_Complex float', (c_complex,), {}) -c99_double_complex = type('_Complex double', (c_double_complex,), {}) - -c_float16 = type('_Float16', (c_half,), {}) -c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) +__all__ = ['CBB', 'CDataManager', 'COrchestrator'] class CBB(LangBB): @@ -34,10 +24,6 @@ class CBB(LangBB): Call('memcpy', (i, j, k)), # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: c99_double_complex, - np.complex64: c99_complex, - np.float16: c_float16, - Float16P: c_float16_p} } @@ -47,3 +33,12 @@ class CDataManager(DataManager): class COrchestrator(Orchestrator): lang = CBB + + +class CDevitoPrinter(_DevitoPrinterBase): + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**_DevitoPrinterBase.type_mappings, + c_complex: 'float _Complex', + c_double_complex: 'double _Complex'} diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 1174a27f8d..aa9e9118de 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,9 @@ -import numpy as np +from sympy.printing.cxx import CXX11CodePrinter from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from 
devito.passes.iet.languages.C import c_float16, c_float16_p -from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex +from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CXXBB'] @@ -45,10 +45,6 @@ """ -cxx_complex = type('std::complex', (c_complex,), {}) -cxx_double_complex = type('std::complex', (c_double_complex,), {}) - - class CXXBB(LangBB): mapper = { @@ -67,8 +63,16 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - "types": {np.complex128: cxx_double_complex, - np.complex64: cxx_complex, - np.float16: c_float16, - Float16P: c_float16_p} } + + +class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + **CXX11CodePrinter._default_settings} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {c_complex: 'std::complex', + c_double_complex: 'std::complex', + **CXX11CodePrinter.type_mappings} diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcf5660ac7..1718a5269a 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.CXX import CXXBB +from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -263,3 +263,8 @@ def place_devptr(self, iet, **kwargs): class AccOrchestrator(Orchestrator): lang = AccBB + + +class 
AccDevitoPrinter(CXXDevitoPrinter): + + pass diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 4ac8d94398..66137a53e7 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -1,9 +1,9 @@ -from devito.passes.iet.languages.C import CDataManager, COrchestrator +from devito.passes.iet.languages.C import CDataManager, COrchestrator, CDevitoPrinter from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer, OmpDataManager, DeviceOmpDataManager, OmpOrchestrator, DeviceOmpOrchestrator) from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager, - AccOrchestrator) + AccOrchestrator, AccDevitoPrinter) from devito.passes.iet.instrument import instrument __all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] @@ -13,6 +13,7 @@ class Target: Parizer = None DataManager = None Orchestrator = None + Printer = None @classmethod def lang(cls): @@ -27,21 +28,25 @@ class CTarget(Target): Parizer = SimdOmpizer DataManager = CDataManager Orchestrator = COrchestrator + Printer = CDevitoPrinter class OmpTarget(Target): Parizer = Ompizer DataManager = OmpDataManager Orchestrator = OmpOrchestrator + Printer = CDevitoPrinter class DeviceOmpTarget(Target): Parizer = DeviceOmpizer DataManager = DeviceOmpDataManager Orchestrator = DeviceOmpOrchestrator + Printer = CDevitoPrinter class DeviceAccTarget(Target): Parizer = DeviceAccizer DataManager = DeviceAccDataManager Orchestrator = AccOrchestrator + Printer = AccDevitoPrinter diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 8c90e48986..e9772ee744 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -3,11 +3,10 @@ from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa - int2, int3, 
int4) + int2, int3, int4, ctypes_vector_mapper) -__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_half', 'c_half_p', 'Float16P'] +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', # noqa + 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex'] limits_mapper = { @@ -50,34 +49,8 @@ def from_param(cls, val): return cls(val.real, val.imag) -class c_half(ctypes.c_uint16): - """Ctype for non-scalar half floats""" - - @classmethod - def from_param(cls, val): - return cls(np.float16(val).view(np.uint16)) - - -class c_half_p(ctypes.POINTER(c_half)): - """ - Ctype for half scalars; we can't directly pass _Float16 values so - we use a pointer and dereference (see `passes.iet.dtypes`) - """ - - @classmethod - def from_param(cls, val): - arr = np.array(val, dtype=np.float16) - return arr.ctypes.data_as(cls) - - -class Float16P(np.float16): - """ - Dummy dtype for a scalar half value that has been mapped to a pointer. - This is needed because we can't directly pass in the values; we map to - pointers and dereference in the kernel. See `passes.iet.dtypes`. - """ - - pass +ctypes_vector_mapper.update({np.complex64: c_complex, + np.complex128: c_double_complex}) class CustomType(ReservedWord): diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 19fcd83d4e..11a3700772 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -762,8 +762,21 @@ def __new__(cls, base=''): # Some other utility objects Null = Macro('NULL') + # DefFunction, unlike sympy.Function, generates e.g. 
`sizeof(float)`, not `sizeof(float_)` -SizeOf = lambda *args: DefFunction('sizeof', tuple(args)) +class SizeOf(DefFunction): + + __rargs__ = ('intype',) + + def __new__(cls, intype, **kwargs): + newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs) + newobj.intype = intype + + return newobj + + @property + def arguments(self): + return (self.intype,) def rfunc(func, item, *args): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index b5765772a7..21d120d377 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -11,19 +11,20 @@ from sympy.core import S from sympy.core.numbers import equal_valued, Float +from sympy.printing.c import C99CodePrinter from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence -from sympy.printing.c import C99CodePrinter from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction +from devito.tools import ctypes_to_cstr __all__ = ['ccode'] -class CodePrinter(C99CodePrinter): +class _DevitoPrinterBase(C99CodePrinter): """ Decorator for sympy.printing.ccode.CCodePrinter. 
@@ -47,6 +48,13 @@ def dtype(self): def compiler(self): return self._settings['compiler'] or configuration['compiler'] + def doprint(self, expr, assign_to=None): + """ + The sympy code printer does a lot of extra we do not need as we handle all of + it in the compiler so we directly defaults to `_print` + """ + return self._print(expr) + def single_prec(self, expr=None, with_f=False): no_f = self.compiler._cpp and not with_f if no_f and expr is not None: @@ -72,6 +80,12 @@ def parenthesize(self, item, level, strict=False): return "(%s)" % self._print(item) return super().parenthesize(item, level, strict=strict) + def _print_type(self, expr): + try: + return self.type_mappings[expr] + except KeyError: + return ctypes_to_cstr(expr) + def _print_Function(self, expr): if isinstance(expr, AbstractFunction): return str(expr) @@ -275,7 +289,7 @@ def _print_ImaginaryUnit(self, expr): def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) - _print_EvalDerivative = C99CodePrinter._print_Add + _print_EvalDerivative = _print_Add def _print_CallFromPointer(self, expr): indices = [self._print(i) for i in expr.params] @@ -351,11 +365,12 @@ def _print_Fallback(self, expr): _print_IndexSum = _print_Fallback _print_ReservedWord = _print_Fallback _print_Basic = _print_Fallback + _print_SizeOf = _print_DefFunction # Lifted from SymPy so that we go through our own `_print_math_func` for k in ('exp log sin cos tan ceiling floor').split(): - setattr(CodePrinter, '_print_%s' % k, CodePrinter._print_math_func) + setattr(_DevitoPrinterBase, '_print_%s' % k, _DevitoPrinterBase._print_math_func) # Always parenthesize IntDiv and InlineIf within expressions @@ -379,10 +394,10 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. 
""" - return CodePrinter(settings=settings).doprint(expr, None) + return _DevitoPrinterBase(settings=settings).doprint(expr, None) # Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here # to always use the correct one from our printer if Version(sympy.__version__) >= Version("1.11"): - setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add) + setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) diff --git a/devito/types/basic.py b/devito/types/basic.py index 11c9e5c535..f77adbb853 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,7 @@ from devito.data import default_allocator from devito.parameters import configuration -from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, +from devito.tools import (Pickable, as_tuple, dtype_to_ctype, frozendict, memoized_meth, sympy_mutex, CustomDtype, Reconstructable) from devito.types.args import ArgProvider @@ -95,7 +95,7 @@ def _C_typedata(self): if _type is c_char_p: _type = c_char - return ctypes_to_cstr(_type) + return _type @abc.abstractproperty def _C_ctype(self): diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index a330096b36..73bf1bbfa1 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -4,12 +4,10 @@ import sympy from devito import Constant, Eq, Function, Grid, Operator -from devito.ir.iet.nodes import Dereference from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C import CBB from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB -from devito.symbolics.extended_dtypes import Float16P from devito.tools import ctypes_to_cstr from devito.types.basic import Basic, Scalar, Symbol from devito.types.dense import TimeFunction @@ -52,7 +50,7 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str ] -@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) 
+@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ @@ -79,7 +77,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N assert _c._C_ctype == lang_types[_c.dtype] -@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ @@ -106,39 +104,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None assert match.group(1) == ctypes_to_cstr(lang_types[dtype]) -def test_half_params() -> None: - """ - Tests float16 input parameters: scalars should be lowered to pointers - and dereferenced; other parameters should keep the original dtype. - """ - - grid = Grid(shape=(5, 5), dtype=np.float16) - x, y = grid.dimensions - - c = Constant(name='c', dtype=np.float16) - u = Function(name='u', grid=grid) - eq = Eq(u, x * x.spacing + c * y * y.spacing) - op = Operator(eq) - - # Check that lowered parameters have the correct dtypes - params: dict[str, Basic] = {p.name: p for p in op.parameters} - _u, _c, _dx, _dy = params['u'], params['c'], params['h_x'], params['h_y'] - - assert _u.dtype == np.float16 - assert _c.dtype == Float16P - assert _dx.dtype == Float16P - assert _dy.dtype == Float16P - - # Ensure the mapped pointer-to-half symbols are dereferenced - derefs: set[Symbol] = {n.pointer for n in op.body.body - if isinstance(n, Dereference)} - assert _c in derefs - assert _dx in derefs - assert _dy in derefs - - -@pytest.mark.parametrize('dtype', [np.float16, np.float32, - np.complex64, np.complex128]) +@pytest.mark.parametrize('dtype', [np.float32, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_complex_headers(dtype: 
np.dtype[np.inexact], kwargs: dict[str, str]) -> None: np.dtype @@ -192,7 +158,7 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: assert unit_str in str(op) -@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128]) @pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), (sympy.log, np.log), @@ -248,30 +214,6 @@ def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: assert np.allclose(u.data.T, expected) -def test_half_time_deriv() -> None: - """ - Tests taking the time derivative of a float16 function. - """ - - grid = Grid(shape=(5, 5)) - x, y = grid.dimensions - t = grid.time_dim - - f = TimeFunction(name='f', grid=grid, space_order=2, dtype=np.float16) - g = Function(name='g', grid=grid, dtype=np.float16) - eqns = [Eq(f.forward, t * x * x.spacing + - y * y.spacing), - Eq(g, f.dt)] - op = Operator(eqns) - op.apply(time=10, dt=1.0) - - # Check against expected result - dx = grid.spacing_map[x.spacing] - xx = np.repeat(np.linspace(0, 4, 5, dtype=np.float16)[np.newaxis, :], 5, axis=0) - expected = xx * np.float16(dx) - assert np.allclose(g.data.T, expected) - - @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None: """ From 844580d11a2cdfca904f9ccc905ff2b27287ae06 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 15 Jan 2025 21:23:55 -0500 Subject: [PATCH 32/58] symbolics: specialize sizeof --- devito/__init__.py | 2 +- devito/arch/compiler.py | 17 ++++++++++++----- devito/passes/iet/definitions.py | 4 ++-- devito/symbolics/extended_sympy.py | 5 +++-- devito/symbolics/printer.py | 4 +++- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index 52918db94a..e6cf0092bf 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -65,7 +65,7 @@ def reinit_compiler(val): # Setup 
target platform and compiler configuration.add('platform', 'cpu64', list(platform_registry), callback=lambda i: platform_registry[i]()) -configuration.add('compiler', 'custom', list(compiler_registry), +configuration.add('compiler', 'custom', compiler_registry, callback=lambda i: compiler_registry[i](name=i)) # Setup language for shared-memory parallelism diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 32c563c041..6ff0825ad8 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -183,7 +183,10 @@ def __init__(self): _cpp = False def __init__(self, **kwargs): - self._name = kwargs.pop('name', self.__class__.__name__) + _name = kwargs.pop('name', self.__class__.__name__) + if isinstance(_name, Compiler): + _name = _name.name + self._name = _name super().__init__(**kwargs) @@ -988,15 +991,19 @@ class CompilerRegistry(dict): """ def __getitem__(self, key): + if isinstance(key, Compiler): + key = key.name + if key.startswith('gcc-'): i = key.split('-')[1] return partial(GNUCompiler, suffix=i) + return super().__getitem__(key) - def __contains__(self, k): - if isinstance(k, Compiler): - k = k.name - return k in self.keys() or k.startswith('gcc-') + def __contains__(self, key): + if isinstance(key, Compiler): + key = key.name + return key in self.keys() or key.startswith('gcc-') _compiler_registry = { diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 9ba83cfe31..ecfd807475 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -17,7 +17,7 @@ from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, - SizeOf, VOID, Keyword, pow_to_mul) + SizeOf, VOID, pow_to_mul) from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol) @@ -279,7 
+279,7 @@ def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage): memptr = VOID(Byref(obj._C_symbol), '**') alignment = obj._data_alignment - nbytes = SizeOf(Keyword('%s*' % obj._C_typedata))*obj.dim.symbolic_size + nbytes = SizeOf(obj._C_typedata, stars='*')*obj.dim.symbolic_size alloc0 = self.lang['host-alloc'](memptr, alignment, nbytes) free0 = self.lang['host-free'](obj._C_symbol) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 11a3700772..ec13355043 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -766,11 +766,12 @@ def __new__(cls, base=''): # DefFunction, unlike sympy.Function, generates e.g. `sizeof(float)`, not `sizeof(float_)` class SizeOf(DefFunction): - __rargs__ = ('intype',) + __rargs__ = ('intype', 'stars') - def __new__(cls, intype, **kwargs): + def __new__(cls, intype, stars=None, **kwargs): newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs) newobj.intype = intype + newobj.stars = stars or '' return newobj diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 21d120d377..a3f391b907 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -353,6 +353,9 @@ def _print_DefFunction(self, expr): template = '' return "%s%s(%s)" % (expr.name, template, ','.join(arguments)) + def _print_SizeOf(self, expr): + return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})' + _print_MathFunction = _print_DefFunction def _print_Fallback(self, expr): @@ -365,7 +368,6 @@ def _print_Fallback(self, expr): _print_IndexSum = _print_Fallback _print_ReservedWord = _print_Fallback _print_Basic = _print_Fallback - _print_SizeOf = _print_DefFunction # Lifted from SymPy so that we go through our own `_print_math_func` From d111212738ca969a066a468ce765b8fc21aba215 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 15 Jan 2025 23:20:14 -0500 Subject: [PATCH 33/58] compiler: move dtype pass to top level 
operator iet pass --- devito/operator/operator.py | 4 ++++ devito/passes/iet/definitions.py | 8 -------- devito/passes/iet/dtypes.py | 14 +++++++------- devito/passes/iet/langbase.py | 8 -------- 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 3023461827..b56ab8a561 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -27,6 +27,7 @@ from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) +from devito.passes.iet.dtypes import lower_dtypes from devito.symbolics import estimate_cost, subs_op_args from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple, flatten, filter_sorted, frozendict, is_integer, @@ -489,6 +490,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Extract the necessary macros from the symbolic objects generate_macros(graph, **kwargs) + # Add type specific metadata + lower_dtypes(graph, lang=cls._Target.lang, **kwargs) + # Target-independent optimizations minimize_symbols(graph) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index ecfd807475..608c9f3662 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -13,7 +13,6 @@ FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_dtypes from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -465,18 +464,12 @@ def place_casts(self, iet, **kwargs): return iet, {} - @iet_pass - def lower_dtypes(self, iet): - iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) - return iet, metadata - def process(self, graph): """ Apply the `place_definitions` and `place_casts` passes. 
""" self.place_definitions(graph, globs=set()) self.place_casts(graph) - self.lower_dtypes(graph) class DeviceAwareDataManager(DataManager): @@ -618,7 +611,6 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) - self.lower_dtypes(graph) def make_zero_init(obj, rcompile, sregistry): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index f1606a73ff..f775e20ea4 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -3,25 +3,25 @@ from devito.arch.compiler import Compiler from devito.ir import Callable, FindSymbols, SymbolRegistry +from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB __all__ = ['lower_dtypes'] -def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, - sregistry: SymbolRegistry) -> tuple[Callable, dict]: +def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None, + sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]: """ Lowers float16 scalar types to pointers since we can't directly pass their value. Also includes headers for complex arithmetic if needed. """ # Complex numbers - iet, metadata = _complex_includes(iet, lang, compiler) - - return iet, metadata + _complex_includes(graph, lang=lang, compiler=compiler) -def _complex_includes(iet: Callable, lang: type[LangBB], - compiler: Compiler) -> tuple[Callable, dict]: +@iet_pass +def _complex_includes(iet: Callable, lang: type[LangBB] = None, + compiler: Compiler = None) -> tuple[Callable, dict]: """ Includes complex arithmetic headers for the given language, if needed. 
""" diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index e34aa2dac3..40980286ef 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -203,14 +203,6 @@ def initialize(self, iet, options=None): """ return iet, {} - @iet_pass - def make_langtypes(self, iet): - """ - An `iet_pass` which transforms an IET such that the target language - types are introduced. - """ - return iet, {} - @property def Region(self): return self.lang.Region From 537311f9b360cd47d77fc467ac424ecd346f56d0 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 16 Jan 2025 09:31:31 -0500 Subject: [PATCH 34/58] symbolics: fix SizeOf rebuild --- devito/operator/operator.py | 2 +- devito/symbolics/extended_sympy.py | 4 ++++ tests/test_dtypes.py | 36 +++++++++++++++++++----------- tests/test_mpi.py | 5 +---- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index b56ab8a561..2dd31b9c60 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -491,7 +491,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): generate_macros(graph, **kwargs) # Add type specific metadata - lower_dtypes(graph, lang=cls._Target.lang, **kwargs) + lower_dtypes(graph, lang=cls._Target.lang(), **kwargs) # Target-independent optimizations minimize_symbols(graph) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index ec13355043..2908c8f12c 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -779,6 +779,10 @@ def __new__(cls, intype, stars=None, **kwargs): def arguments(self): return (self.intype,) + @property + def args(self): + return super().args[1] + def rfunc(func, item, *args): """ diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 73bf1bbfa1..ebf9e488d0 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,14 +1,14 @@ import numpy as np import pytest -import re import sympy from 
devito import Constant, Eq, Function, Grid, Operator from devito.passes.iet.langbase import LangBB -from devito.passes.iet.languages.C import CBB -from devito.passes.iet.languages.openacc import AccBB +from devito.passes.iet.languages.C import CBB, CDevitoPrinter +from devito.passes.iet.languages.openacc import AccBB, AccDevitoPrinter from devito.passes.iet.languages.openmp import OmpBB -from devito.tools import ctypes_to_cstr +from devito.symbolics.extended_dtypes import ctypes_vector_mapper +from devito.symbolics.printer import _DevitoPrinterBase from devito.types.basic import Basic, Scalar, Symbol from devito.types.dense import TimeFunction @@ -20,6 +20,13 @@ } +_printers: dict[str, type[_DevitoPrinterBase]] = { + 'C': CDevitoPrinter, + 'openmp': CDevitoPrinter, + 'openacc': AccDevitoPrinter +} + + def _get_language(language: str, **_) -> type[LangBB]: """ Gets the language building block type from parametrized kwargs. @@ -28,6 +35,14 @@ def _get_language(language: str, **_) -> type[LangBB]: return _languages[language] +def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: + """ + Gets the printer building block type from parametrized kwargs. + """ + + return _printers[language] + + def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str]: """ Generates kwargs for Operator to test language-specific behavior. @@ -58,9 +73,6 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N strings in generated code. 
""" - # Retrieve the language-specific type mapping - lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') - # Set up an operator grid = Grid(shape=(3, 3)) x, y = grid.dimensions @@ -73,8 +85,8 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N # Check ctypes of the mapped parameters params: dict[str, Basic] = {p.name: p for p in op.parameters} _u, _c = params['u'], params['c'] - assert _u.indexed._C_ctype._type_ == lang_types[_u.dtype] - assert _c._C_ctype == lang_types[_c.dtype] + assert type(_u.indexed._C_ctype._type_()) == ctypes_vector_mapper[dtype] + assert _c._C_ctype == ctypes_vector_mapper[dtype] @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @@ -86,7 +98,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None """ # Retrieve the language-specific type mapping - lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') + printer: type[_DevitoPrinterBase] = _get_printer(**kwargs) # Set up an operator grid = Grid(shape=(3, 3)) @@ -99,9 +111,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None op = Operator(eq, **kwargs) # Ensure the CSE'd variable has the correct type - match = re.search(r'[^\S\n\r]*(.*\S)\sr0 = ', str(op)) - assert match is not None - assert match.group(1) == ctypes_to_cstr(lang_types[dtype]) + assert f'{printer()._print(ctypes_vector_mapper[dtype])} r0' in str(op) @pytest.mark.parametrize('dtype', [np.float32, np.complex64, np.complex128]) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index d91b6c9dab..d250c72c40 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -21,10 +21,7 @@ from devito.tools import Bunch from examples.seismic.acoustic import acoustic_setup -try: - from tests.test_dse import TestTTI -except ImportError: - TestTTI = None +from test_dse import TestTTI class TestDistributor: From 3f2b6f8ff72dc0626d62ccfa569d01378ce0c9b5 Mon Sep 17 00:00:00 2001 From: mloubout Date: 
Thu, 16 Jan 2025 13:57:35 -0500 Subject: [PATCH 35/58] symbolics: use std namespace for c++ --- devito/passes/iet/languages/C.py | 8 +++ devito/passes/iet/languages/CXX.py | 7 +- devito/symbolics/extended_sympy.py | 8 ++- devito/symbolics/inspection.py | 5 +- devito/symbolics/printer.py | 102 +++++++++++------------------ tests/test_dtypes.py | 12 ++-- tests/test_symbolics.py | 3 +- 7 files changed, 67 insertions(+), 78 deletions(-) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 7efdaa44ff..4285a673e1 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,3 +1,5 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator @@ -42,3 +44,9 @@ class CDevitoPrinter(_DevitoPrinterBase): type_mappings = {**_DevitoPrinterBase.type_mappings, c_complex: 'float _Complex', c_double_complex: 'double _Complex'} + + _func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c', + np.complex128: 'c'} + + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index aa9e9118de..b261c89213 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -70,9 +70,14 @@ class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): _default_settings = {**_DevitoPrinterBase._default_settings, **CXX11CodePrinter._default_settings} + _ns = "std::" # These cannot go through _print_xxx because they are classes not # instances - type_mappings = {c_complex: 'std::complex', + type_mappings = {**_DevitoPrinterBase.type_mappings, + c_complex: 'std::complex', c_double_complex: 'std::complex', **CXX11CodePrinter.type_mappings} + + def _print_ImaginaryUnit(self, expr): + return f'1i{self.prec_literal(expr).lower()}' diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py 
index 2908c8f12c..127995d3c0 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -769,15 +769,17 @@ class SizeOf(DefFunction): __rargs__ = ('intype', 'stars') def __new__(cls, intype, stars=None, **kwargs): - newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs) + stars = stars or '' + argument = Keyword(f'{intype}{stars}') + newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs) newobj.intype = intype - newobj.stars = stars or '' + newobj.stars = stars return newobj @property def arguments(self): - return (self.intype,) + return self.args @property def args(self): diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 165a3209be..4831673554 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -312,6 +312,9 @@ def sympy_dtype(expr, base=None): # Promote if we missed complex number, i.e f + I is_im = np.issubdtype(dtype, np.complexfloating) if expr.has(ImaginaryUnit) and not is_im: - dtype = np.promote_types(dtype, np.complex64).type + if dtype is None: + dtype = np.complex64 + else: + dtype = np.promote_types(dtype, np.complex64).type return dtype diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index a3f391b907..51a636c8cd 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -24,6 +24,10 @@ __all__ = ['ccode'] +_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} +_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} + + class _DevitoPrinterBase(C99CodePrinter): """ @@ -37,6 +41,8 @@ class _DevitoPrinterBase(C99CodePrinter): _default_settings = {'compiler': None, 'dtype': np.float32, **C99CodePrinter._default_settings} + _func_prefix = {np.float32: 'f', np.float64: 'f'} + @property def dtype(self): try: @@ -55,25 +61,27 @@ def doprint(self, expr, assign_to=None): """ return self._print(expr) - def single_prec(self, expr=None, with_f=False): - no_f = 
self.compiler._cpp and not with_f - if no_f and expr is not None: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return any(issubclass(dtype, d) for d in [np.float32, np.complex64]) - - def half_prec(self, expr=None, with_f=False): - no_f = self.compiler._cpp and not with_f - if no_f and expr is not None: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return issubclass(dtype, np.float16) - - def complex_prec(self, expr=None): - if self.compiler._cpp: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return np.issubdtype(dtype, np.complexfloating) + def _prec(self, expr): + dtype = sympy_dtype(expr) if expr is not None else None + if dtype is None or np.issubdtype(dtype, np.integer): + real = any(isinstance(i, Float) for i in expr.atoms()) + stype = self.dtype if real else np.int32 + return np.result_type(dtype or stype, stype).type + else: + return dtype or self.dtype + + def prec_literal(self, expr): + return _prec_litterals.get(self._prec(expr), '') + + def func_literal(self, expr): + return _func_litterals.get(self._prec(expr), '') + + def func_prefix(self, expr, abs=False): + prefix = self._func_prefix.get(self._prec(expr), '') + if abs: + return prefix + else: + return '' if prefix == 'f' else prefix def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): @@ -147,10 +155,7 @@ def _print_math_func(self, expr, nest=False, known=None): if cname not in self._prec_funcs: return super()._print_math_func(expr, nest=nest, known=known) - if self.single_prec(expr) or self.half_prec(expr): - cname = '%sf' % cname - if self.complex_prec(expr): - cname = 'c%s' % cname + cname = f'{self.func_prefix(expr)}{cname}{self.func_literal(expr)}' if nest and len(expr.args) > 2: args = ', '.join([self._print(expr.args[0]), @@ -158,7 +163,7 @@ def _print_math_func(self, expr, nest=False, known=None): else: args = ', '.join([self._print(arg) for arg in 
expr.args]) - return f'{cname}({args})' + return f'{self._ns}{cname}({args})' def _print_Pow(self, expr): # Completely reimplement `_print_Pow` from sympy, since it doesn't @@ -166,16 +171,17 @@ def _print_Pow(self, expr): if "Pow" in self.known_functions: return self._print_Function(expr) PREC = precedence(expr) - suffix = 'f' if self.single_prec(expr) else '' + suffix = self.func_literal(expr) + base = self._print(expr.base) if equal_valued(expr.exp, -1): return self._print_Float(Float(1.0)) + '/' + \ self.parenthesize(expr.base, PREC) elif equal_valued(expr.exp, 0.5): - return f'sqrt{suffix}({self._print(expr.base)})' + return f'{self._ns}sqrt{suffix}({base})' elif expr.exp == S.One/3 and self.standard != 'C89': - return f'cbrt{suffix}({self._print(expr.base)})' + return f'{self._ns}cbrt{suffix}({base})' else: - return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})' + return f'{self._ns}pow{suffix}({base}, {self._print(expr.exp)})' def _print_SafeInv(self, expr): """Print a SafeInv as a C-like division with a check for zero.""" @@ -209,18 +215,8 @@ def _print_Abs(self, expr): # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler): return "fabs(%s)" % self._print(expr.args[0]) - # Check if argument is an integer - if has_integer_args(*expr.args[0].args): - func = "abs" - elif self.single_prec(expr): - func = "fabsf" - elif any([isinstance(a, Real) for a in expr.args[0].args]): - # The previous condition isn't sufficient to detect case with - # Python `float`s in that case, fall back to the "default" - func = "fabsf" if self.single_prec() else "fabs" - else: - func = "fabs" - return f"{func}({self._print(expr.args[0])})" + func = f'{self.func_prefix(expr, abs=True)}abs{self.func_literal(expr)}' + return f"{self._ns}{func}({self._print(expr.args[0])})" def _print_Add(self, expr, order=None): """" @@ -270,21 +266,7 @@ def _print_Float(self, expr): if 'e' not in rv: rv = rv.rstrip('0') + "0" - if self.single_prec(): - 
rv = '%sF' % rv - elif self.half_prec(): - rv = '%sF16' % rv - - return rv - - def _print_ImaginaryUnit(self, expr): - if self.compiler._cpp: - if self.single_prec(with_f=True) or self.half_prec(with_f=True): - return '1if' - else: - return '1i' - else: - return '_Complex_I' + return f'{rv}{self.prec_literal(expr)}' def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) @@ -335,16 +317,6 @@ def _print_UnaryOp(self, expr): def _print_ComponentAccess(self, expr): return "%s.%s" % (self._print(expr.base), expr.sindex) - def _print_TrigonometricFunction(self, expr): - func_name = str(expr.func) - - if self.single_prec() or self.half_prec(): - func_name = '%sf' % func_name - if self.complex_prec(): - func_name = 'c%s' % func_name - - return '%s(%s)' % (func_name, self._print(*expr.args)) - def _print_DefFunction(self, expr): arguments = [self._print(i) for i in expr.arguments] if expr.template: diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index ebf9e488d0..ffc67a4d2b 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -43,7 +43,7 @@ def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: return _printers[language] -def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str]: +def _config_kwargs(platform: str, language: str) -> dict[str, str]: """ Generates kwargs for Operator to test language-specific behavior. 
""" @@ -51,16 +51,15 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str return { 'platform': platform, 'language': language, - 'compiler': compiler } # List of pararmetrized operator kwargs for testing language-specific behavior _configs: list[dict[str, str]] = [ _config_kwargs(*cfg) for cfg in [ - ('cpu64', 'C', 'gcc'), - ('cpu64', 'openmp', 'gcc'), - ('nvidiaX', 'openacc', 'nvc') + ('cpu64', 'C'), + ('cpu64', 'openmp'), + ('nvidiaX', 'openacc') ] ] @@ -146,10 +145,9 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: """ Tests that the correct literal is used for the imaginary unit. """ - # Determine the expected imaginary unit string unit_str: str - if kwargs['compiler'] == 'gcc': + if kwargs['platform'] == 'cpu64': # In C we multiply by the _Complex_I macro constant unit_str = '_Complex_I' else: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 2bae5679c8..b647df9688 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -299,7 +299,8 @@ def test_extended_sympy_arithmetic(): def test_integer_abs(): i1 = Dimension(name="i1") assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)" - assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)" + # .5 is a standard python Float, i.e an np.float64 + assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1)" assert ccode( Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5)) ) == "fabs(i1 - half)" From 31b9141b15eae9694f64eccf073216a01b6f5307 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 16 Jan 2025 17:07:41 -0500 Subject: [PATCH 36/58] compiler: fix std math func names --- devito/passes/iet/languages/CXX.py | 3 ++- devito/symbolics/inspection.py | 7 +++++-- devito/symbolics/printer.py | 6 +++--- tests/test_gpu_common.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index b261c89213..a6e1715e33 100644 --- a/devito/passes/iet/languages/CXX.py 
+++ b/devito/passes/iet/languages/CXX.py @@ -71,12 +71,13 @@ class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): _default_settings = {**_DevitoPrinterBase._default_settings, **CXX11CodePrinter._default_settings} _ns = "std::" + _func_litterals = {} # These cannot go through _print_xxx because they are classes not # instances type_mappings = {**_DevitoPrinterBase.type_mappings, c_complex: 'std::complex<float>', - c_double_complex: 'std::complex<float>', + c_double_complex: 'std::complex<double>', **CXX11CodePrinter.type_mappings} def _print_ImaginaryUnit(self, expr): diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 4831673554..20ec163e74 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -296,10 +296,13 @@ def has_integer_args(*args): return res -def sympy_dtype(expr, base=None): +def sympy_dtype(expr, base=None, default=None): """ Infer the dtype of the expression. """ + if expr is None: + return default + dtypes = {base} - {None} for i in expr.free_symbols: try: @@ -313,7 +316,7 @@ def sympy_dtype(expr, base=None): is_im = np.issubdtype(dtype, np.complexfloating) if expr.has(ImaginaryUnit) and not is_im: if dtype is None: - dtype = np.complex64 + dtype = default or np.complex64 else: dtype = np.promote_types(dtype, np.complex64).type diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 51a636c8cd..f082edd43a 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -25,7 +25,6 @@ _prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} -_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} class _DevitoPrinterBase(C99CodePrinter): @@ -42,6 +41,7 @@ class _DevitoPrinterBase(C99CodePrinter): **C99CodePrinter._default_settings} _func_prefix = {np.float32: 'f', np.float64: 'f'} + _func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} @property def dtype(self): @@ -62,7 +62,7 @@ def doprint(self, expr, assign_to=None): return
self._print(expr) def _prec(self, expr): - dtype = sympy_dtype(expr) if expr is not None else None + dtype = sympy_dtype(expr, default=self.dtype) if dtype is None or np.issubdtype(dtype, np.integer): real = any(isinstance(i, Float) for i in expr.atoms()) stype = self.dtype if real else np.int32 @@ -74,7 +74,7 @@ def prec_literal(self, expr): return _prec_litterals.get(self._prec(expr), '') def func_literal(self, expr): - return _func_litterals.get(self._prec(expr), '') + return self._func_litterals.get(self._prec(expr), '') def func_prefix(self, expr, abs=False): prefix = self._func_prefix.get(self._prec(expr), '') diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e2321036e0..13239687bc 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -94,7 +94,7 @@ def test_complex(self, dtype): xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j) - assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + assert np.allclose(u.data, npres.T, rtol=5e-7, atol=0) class TestPassesOptional: From 4412323988bc6c680f2f6df1f08d5fe56506214a Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 16 Jan 2025 21:48:23 -0500 Subject: [PATCH 37/58] symbolics: move printers together through registry --- devito/ir/iet/visitors.py | 72 +++++++++------------ devito/operator/operator.py | 19 +++--- devito/passes/iet/definitions.py | 4 +- devito/passes/iet/dtypes.py | 24 ++++--- devito/passes/iet/languages/C.py | 19 ------ devito/passes/iet/languages/CXX.py | 22 ------ devito/passes/iet/languages/openacc.py | 7 +-- devito/passes/iet/languages/targets.py | 9 +-- devito/symbolics/printer.py | 87 +++++++++++++++++++++----- tests/test_dtypes.py | 15 ++--- tests/test_symbolics.py | 5 +- 11 files changed, 129 insertions(+), 154 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 622b1ed95b..7ab38b5191 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py
@@ -13,14 +13,13 @@ from sympy import IndexedBase from sympy.core.function import Application -from devito.parameters import configuration, switchconfig from devito.exceptions import CompilationError from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace) -from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.printer import ccode from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import (GenericVisitor, as_tuple, filter_ordered, filter_sorted, flatten, is_external_ctype, @@ -177,10 +176,8 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, compiler=None, printer=None, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler or configuration['compiler'] - self._printer = printer or _DevitoPrinterBase # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -192,19 +189,6 @@ def __init__(self, *args, compiler=None, printer=None, **kwargs): } _restrict_keyword = 'restrict' - @property - def compiler(self): - return self._compiler - - def ccode(self, expr, **settings): - return self._printer(settings=settings).doprint(expr, None) - - def visit(self, o, *args, **kwargs): - # Make sure the visitor always is within the generating compiler - # in case the configuration is accessed - with switchconfig(compiler=self.compiler.name): - return super().visit(o, *args, **kwargs) - def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. 
@@ -238,7 +222,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = self.ccode(ct) + cstr = ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -260,10 +244,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = self.ccode(obj._C_typedata) - strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) + strtype = ccode(obj._C_typedata) + strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = self.ccode(obj._C_ctype) + strtype = ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -277,7 +261,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [self.ccode(i) for i in obj.cargs] + arguments = [ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -291,9 +275,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, self.ccode(init)) + value = c.Initializer(value, ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, self.ccode(obj.initvalue)) + value = c.Initializer(value, ccode(obj.initvalue)) return value @@ -327,7 +311,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(self.ccode(i)) + ret.append(ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -393,7 +377,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed 
- cstr = self.ccode(i._C_typedata) + cstr = ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -415,7 +399,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % ccode(i) for i in o.castshape) rshape = '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -448,9 +432,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = self.ccode(i._C_typedata) + cstr = ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -489,8 +473,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + lhs = ccode(o.expr.lhs, dtype=o.dtype) + rhs = ccode(o.expr.rhs, dtype=o.dtype) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -503,8 +487,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + c_lhs = ccode(o.expr.lhs, dtype=o.dtype) + c_rhs = ccode(o.expr.rhs, dtype=o.dtype) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -523,7 +507,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, (FieldFromComposite, FieldFromPointer)): - return c.Assign(self.ccode(retobj), call) + return c.Assign(ccode(retobj), call) else: 
return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -537,9 +521,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return c.If(self.ccode(o.condition), then_body, else_body) + return c.If(ccode(o.condition), then_body, else_body) else: - return c.If(self.ccode(o.condition), then_body) + return c.If(ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -549,23 +533,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) - loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) + loop_init = 'int %s = %s' % (o.index, ccode(_max)) + loop_cond = '%s >= %s' % (o.index, ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) - loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) + loop_init = 'int %s = %s' % (o.index, ccode(_min)) + loop_cond = '%s <= %s' % (o.index, ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -582,7 +566,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = self.ccode(o.condition) + condition = ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, 
c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 2dd31b9c60..338ac9faa1 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -23,7 +23,7 @@ from devito.operator.profiling import create_profile from devito.operator.registry import operator_selector from devito.mpi import MPI -from devito.parameters import configuration +from devito.parameters import configuration, switchconfig from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) @@ -758,19 +758,14 @@ def _soname(self): """A unique name for the shared object resulting from JIT compilation.""" return Signer._digest(self, configuration) - @property - def printer(self): - return self._Target.Printer - @cached_property def ccode(self): - try: - return self._ccode_handler(compiler=self._compiler, - printer=self.printer).visit(self) - except (AttributeError, TypeError): - from devito.ir.iet.visitors import CGen - return CGen(compiler=self._compiler, - printer=self.printer).visit(self) + with switchconfig(compiler=self._compiler, language=self._language): + try: + return self._ccode_handler().visit(self) + except (AttributeError, TypeError): + from devito.ir.iet.visitors import CGen + return CGen().visit(self) def _jit_compile(self): """ diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 608c9f3662..67f0441ec4 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -75,12 +75,10 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. 
""" - def __init__(self, rcompile=None, sregistry=None, platform=None, - compiler=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform - self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index f775e20ea4..7f617b1c41 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -9,16 +9,6 @@ __all__ = ['lower_dtypes'] -def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None, - sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]: - """ - Lowers float16 scalar types to pointers since we can't directly pass their - value. Also includes headers for complex arithmetic if needed. - """ - # Complex numbers - _complex_includes(graph, lang=lang, compiler=compiler) - - @iet_pass def _complex_includes(iet: Callable, lang: type[LangBB] = None, compiler: Compiler = None) -> tuple[Callable, dict]: @@ -50,3 +40,17 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None, metadata['includes'] = lib return iet, metadata + + +dtype_passes = [_complex_includes] + + +def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None, + sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]: + """ + Lowers float16 scalar types to pointers since we can't directly pass their + value. Also includes headers for complex arithmetic if needed. 
+ """ + + for dtype_pass in dtype_passes: + dtype_pass(graph, lang=lang, compiler=compiler) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4285a673e1..ff50e54205 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,7 @@ -import numpy as np - from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex -from devito.symbolics.printer import _DevitoPrinterBase __all__ = ['CBB', 'CDataManager', 'COrchestrator'] @@ -35,18 +31,3 @@ class CDataManager(DataManager): class COrchestrator(Orchestrator): lang = CBB - - -class CDevitoPrinter(_DevitoPrinterBase): - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**_DevitoPrinterBase.type_mappings, - c_complex: 'float _Complex', - c_double_complex: 'double _Complex'} - - _func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c', - np.complex128: 'c'} - - def _print_ImaginaryUnit(self, expr): - return '_Complex_I' diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index a6e1715e33..17003c0d8f 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,5 @@ -from sympy.printing.cxx import CXX11CodePrinter - from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.printer import _DevitoPrinterBase -from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CXXBB'] @@ -64,21 +60,3 @@ class CXXBB(LangBB): 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, } - - -class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): - - _default_settings = {**_DevitoPrinterBase._default_settings, - 
**CXX11CodePrinter._default_settings} - _ns = "std::" - _func_litterals = {} - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**_DevitoPrinterBase.type_mappings, - c_complex: 'std::complex<float>', - c_double_complex: 'std::complex<double>', - **CXX11CodePrinter.type_mappings} - - def _print_ImaginaryUnit(self, expr): - return f'1i{self.prec_literal(expr).lower()}' diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 1718a5269a..bcf5660ac7 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -263,8 +263,3 @@ def place_devptr(self, iet, **kwargs): class AccOrchestrator(Orchestrator): lang = AccBB - - -class AccDevitoPrinter(CXXDevitoPrinter): - - pass diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 66137a53e7..4ac8d94398 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -1,9 +1,9 @@ -from devito.passes.iet.languages.C import CDataManager, COrchestrator, CDevitoPrinter +from devito.passes.iet.languages.C import CDataManager, COrchestrator from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer, OmpDataManager, DeviceOmpDataManager, OmpOrchestrator, DeviceOmpOrchestrator) from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager, - AccOrchestrator, AccDevitoPrinter) + AccOrchestrator) from
devito.passes.iet.instrument import instrument __all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] @@ -13,7 +13,6 @@ class Target: Parizer = None DataManager = None Orchestrator = None - Printer = None @classmethod def lang(cls): @@ -28,25 +27,21 @@ class CTarget(Target): Parizer = SimdOmpizer DataManager = CDataManager Orchestrator = COrchestrator - Printer = CDevitoPrinter class OmpTarget(Target): Parizer = Ompizer DataManager = OmpDataManager Orchestrator = OmpOrchestrator - Printer = CDevitoPrinter class DeviceOmpTarget(Target): Parizer = DeviceOmpizer DataManager = DeviceOmpDataManager Orchestrator = DeviceOmpOrchestrator - Printer = CDevitoPrinter class DeviceAccTarget(Target): Parizer = DeviceAccizer DataManager = DeviceAccDataManager Orchestrator = AccOrchestrator - Printer = AccDevitoPrinter diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index f082edd43a..ca0d9e98e3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -7,17 +7,19 @@ from mpmath.libmp import prec_to_dps, to_str from packaging.version import Version -from numbers import Real from sympy.core import S from sympy.core.numbers import equal_valued, Float +from sympy.printing.codeprinter import CodePrinter from sympy.printing.c import C99CodePrinter +from sympy.printing.cxx import CXX11CodePrinter from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.types.basic import AbstractFunction from devito.tools import ctypes_to_cstr @@ -27,7 +29,7 @@ _prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} -class _DevitoPrinterBase(C99CodePrinter): +class _DevitoPrinterBase(CodePrinter): """ Decorator for 
sympy.printing.ccode.CCodePrinter. @@ -38,10 +40,10 @@ class _DevitoPrinterBase(C99CodePrinter): Options for code printing. """ _default_settings = {'compiler': None, 'dtype': np.float32, - **C99CodePrinter._default_settings} + **CodePrinter._default_settings} - _func_prefix = {np.float32: 'f', np.float64: 'f'} - _func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} + _func_prefix = {} + _func_litterals = {} @property def dtype(self): @@ -65,8 +67,10 @@ def _prec(self, expr): dtype = sympy_dtype(expr, default=self.dtype) if dtype is None or np.issubdtype(dtype, np.integer): real = any(isinstance(i, Float) for i in expr.atoms()) - stype = self.dtype if real else np.int32 - return np.result_type(dtype or stype, stype).type + if real: + return self.dtype + else: + return dtype or self.dtype else: return dtype or self.dtype @@ -212,11 +216,13 @@ def _print_Max(self, expr): def _print_Abs(self, expr): """Print an absolute value. Use `abs` if can infer it is an Integer""" + # Unary function, single argument + arg = expr.args[0] # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler): - return "fabs(%s)" % self._print(expr.args[0]) - func = f'{self.func_prefix(expr, abs=True)}abs{self.func_literal(expr)}' - return f"{self._ns}{func}({self._print(expr.args[0])})" + return "fabs(%s)" % self._print(arg) + func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}' + return f"{self._ns}{func}({self._print(arg)})" def _print_Add(self, expr, order=None): """" @@ -352,6 +358,58 @@ def _print_Fallback(self, expr): PRECEDENCE_VALUES['InlineIf'] = 1 +# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here +# to always use the correct one from our printer +if Version(sympy.__version__) >= Version("1.11"): + setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) + + +class CDevitoPrinter(_DevitoPrinterBase, C99CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + 
**C99CodePrinter._default_settings} + _func_litterals = {np.float32: 'f', np.complex64: 'f'} + _func_prefix = {np.float32: 'f', np.float64: 'f', + np.complex64: 'c', np.complex128: 'c'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**C99CodePrinter.type_mappings, + c_complex: 'float _Complex', + c_double_complex: 'double _Complex'} + + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + + +class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + **CXX11CodePrinter._default_settings} + _ns = "std::" + _func_litterals = {} + _func_prefix = {np.float32: 'f', np.float64: 'f'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**CXX11CodePrinter.type_mappings, + c_complex: 'std::complex<float>', + c_double_complex: 'std::complex<double>'} + + def _print_ImaginaryUnit(self, expr): + return f'1i{self.prec_literal(expr).lower()}' + + +class AccDevitoPrinter(CXXDevitoPrinter): + + pass + + +printer_registry: dict[str, type[_DevitoPrinterBase]] = { + 'C': CDevitoPrinter, 'openmp': CDevitoPrinter, + 'openacc': AccDevitoPrinter} + + def ccode(expr, **settings): """Generate C++ code from an expression. @@ -368,10 +426,5 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself.
""" - return _DevitoPrinterBase(settings=settings).doprint(expr, None) - - -# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here -# to always use the correct one from our printer -if Version(sympy.__version__) >= Version("1.11"): - setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) + printer = printer_registry.get(configuration['language'], CDevitoPrinter) + return printer(settings=settings).doprint(expr, None) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index ffc67a4d2b..88d9299ee5 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -4,11 +4,11 @@ from devito import Constant, Eq, Function, Grid, Operator from devito.passes.iet.langbase import LangBB -from devito.passes.iet.languages.C import CBB, CDevitoPrinter -from devito.passes.iet.languages.openacc import AccBB, AccDevitoPrinter +from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import ctypes_vector_mapper -from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.printer import printer_registry, _DevitoPrinterBase from devito.types.basic import Basic, Scalar, Symbol from devito.types.dense import TimeFunction @@ -20,13 +20,6 @@ } -_printers: dict[str, type[_DevitoPrinterBase]] = { - 'C': CDevitoPrinter, - 'openmp': CDevitoPrinter, - 'openacc': AccDevitoPrinter -} - - def _get_language(language: str, **_) -> type[LangBB]: """ Gets the language building block type from parametrized kwargs. @@ -40,7 +33,7 @@ def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: Gets the printer building block type from parametrized kwargs. 
""" - return _printers[language] + return printer_registry[language] def _config_kwargs(platform: str, language: str) -> dict[str, str]: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index b647df9688..19cc0703ce 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -13,7 +13,7 @@ from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, - ReservedWord, ListInitializer, ccode, uxreplace, + ReservedWord, ListInitializer, uxreplace, ccode, retrieve_derivatives) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, @@ -299,8 +299,7 @@ def test_extended_sympy_arithmetic(): def test_integer_abs(): i1 = Dimension(name="i1") assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)" - # .5 is a standard python Float, i.e an np.float64 - assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1)" + assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)" assert ccode( Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5)) ) == "fabs(i1 - half)" From b00d198fd203979e374ea3b3bf4e025d87260516 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 17 Jan 2025 10:11:19 -0500 Subject: [PATCH 38/58] symbolics: rework Cast --- devito/mpi/routines.py | 8 +- devito/passes/iet/dtypes.py | 9 ++- devito/passes/iet/languages/openacc.py | 4 +- devito/symbolics/extended_dtypes.py | 101 ++++++------------------- devito/symbolics/extended_sympy.py | 41 +++++++--- devito/symbolics/printer.py | 15 ++-- tests/test_dtypes.py | 8 +- tests/test_symbolics.py | 6 +- 8 files changed, 81 insertions(+), 111 deletions(-) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 8b4987c8bb..8da418bfde 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -605,7 +605,7 @@ def _make_msg(self, f, hse, key): return MPIMsg('msg%d' % key, f, halos) def _make_sendrecv(self, f, hse, key, msg=None): - 
cast = cast_mapper[(f.c0.dtype, '*')] + cast = cast_mapper((f.c0.dtype, '*')) comm = f.grid.distributor._obj_comm bufg = FieldFromPointer(msg._C_field_bufg, msg) @@ -671,7 +671,7 @@ def _call_compute(self, hs, compute, *args): return compute.make_call(dynamic_args_mapper=hs.omapper.core) def _make_wait(self, f, hse, key, msg=None): - cast = cast_mapper[(f.c0.dtype, '*')] + cast = cast_mapper((f.c0.dtype, '*')) bufs = FieldFromPointer(msg._C_field_bufs, msg) @@ -772,7 +772,7 @@ def _call_sendrecv(self, *args): return def _make_haloupdate(self, f, hse, key, *args, msg=None): - cast = cast_mapper[(f.c0.dtype, '*')] + cast = cast_mapper((f.c0.dtype, '*')) comm = f.grid.distributor._obj_comm fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} @@ -819,7 +819,7 @@ def _call_haloupdate(self, name, f, hse, msg): return HaloUpdateCall(name, args) def _make_halowait(self, f, hse, key, *args, msg=None): - cast = cast_mapper[(f.c0.dtype, '*')] + cast = cast_mapper((f.c0.dtype, '*')) fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 7f617b1c41..ea64434837 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -17,8 +17,13 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None, """ # Check if there are complex numbers that always take dtype precedence - types = {f.dtype for f in FindSymbols().visit(iet) - if not issubclass(f.dtype, ctypes._Pointer)} + types = set() + for f in FindSymbols().visit(iet): + try: + if not issubclass(f.dtype, ctypes._Pointer): + types.add(f.dtype) + except TypeError: + pass if not any(np.issubdtype(d, np.complexfloating) for d in types): return iet, {} diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcf5660ac7..e67dc7eee2 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -236,11 +236,11 @@ def place_devptr(self, 
iet, **kwargs): dpf = List(body=[ self.lang.mapper['map-serial-present'](hp, tdp), - Block(body=DummyExpr(tdp, cast_mapper[tdp.dtype](hp))) + Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp))) ]) ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol) - ctdp = cast_mapper[(hp.dtype, '*')](tdp) + ctdp = cast_mapper((hp.dtype, '*'))(tdp) cast = DummyExpr(ffp, ctdp) ret = Return(ctdp) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index e9772ee744..23fb084ba7 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -5,7 +5,7 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4, ctypes_vector_mapper) -__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex'] @@ -68,94 +68,37 @@ class CustomType(ReservedWord): globals()[clsp.__name__] = clsp -class CHAR(Cast): - _base_typ = 'char' +def no_dtype(kwargs): + return {k: v for k, v in kwargs.items() if k != 'dtype'} -class SHORT(Cast): - _base_typ = 'short' +def cast_mapper(arg): + try: + assert len(arg) == 2 and arg[1] == '*' + return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw)) + except TypeError: + return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw)) -class USHORT(Cast): - _base_typ = 'unsigned short' +FLOAT = cast_mapper(np.float32) +DOUBLE = cast_mapper(np.float64) +ULONG = cast_mapper(np.uint64) +UINTP = cast_mapper(np.uint32) -class UCHAR(Cast): - _base_typ = 'unsigned char' +# Standard ones, needed as class for e.g. 
single dispatch +class BaseCast(Cast): + def __new__(cls, base, stars=None, **kwargs): + kwargs['dtype'] = cls._dtype + return super().__new__(cls, base, stars=stars, **kwargs) -class LONG(Cast): - _base_typ = 'long' +class VOID(BaseCast): -class ULONG(Cast): - _base_typ = 'unsigned long' + _dtype = type('void', (ctypes.c_int,), {}) -class CFLOAT(Cast): - _base_typ = 'float' +class INT(BaseCast): - -class CDOUBLE(Cast): - _base_typ = 'double' - - -class VOID(Cast): - _base_typ = 'void' - - -class CHARP(CastStar): - base = CHAR - - -class UCHARP(CastStar): - base = UCHAR - - -class SHORTP(CastStar): - base = SHORT - - -class USHORTP(CastStar): - base = USHORT - - -class CFLOATP(CastStar): - base = CFLOAT - - -class CDOUBLEP(CastStar): - base = CDOUBLE - - -cast_mapper = { - np.int8: CHAR, - np.uint8: UCHAR, - np.int16: SHORT, # noqa - np.uint16: USHORT, # noqa - int: INT, # noqa - np.int32: INT, # noqa - np.int64: LONG, - np.uint64: ULONG, - np.float32: FLOAT, # noqa - float: DOUBLE, # noqa - np.float64: DOUBLE, # noqa - - (np.int8, '*'): CHARP, - (np.uint8, '*'): UCHARP, - (int, '*'): INTP, # noqa - (np.uint16, '*'): USHORTP, # noqa - (np.int16, '*'): SHORTP, # noqa - (np.int32, '*'): INTP, # noqa - (np.int64, '*'): INTP, # noqa - (np.float32, '*'): FLOATP, # noqa - (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP, # noqa -} - -for base_name in ['int', 'float', 'double']: - for i in [2, 3, 4]: - v = '%s%d' % (base_name, i) - cls = locals()[v] - cast_mapper[cls] = locals()[v.upper()] - cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] + _dtype = np.int32 diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 127995d3c0..68ff67f894 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -10,7 +10,7 @@ from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, 
int2, int3, - int4) + int4, dtype_to_ctype, ctypes_to_cstr, ctypes_vector_mapper) from devito.types import Symbol from devito.types.basic import Basic @@ -382,15 +382,17 @@ class Cast(UnaryOp): Symbolic representation of the C notation `(type)expr`. """ - _base_typ = '' + __rargs__ = ('base', ) + __rkwargs__ = ('stars', 'dtype') - __rkwargs__ = ('stars',) - - def __new__(cls, base, stars=None, **kwargs): + def __new__(cls, base, dtype=None, stars=None, **kwargs): # Attempt simplifcation # E.g., `FLOAT(32) -> 32.0` of type `sympy.Float` try: - return sympify(eval(cls._base_typ)(base)) + if isinstance(dtype, str): + return sympify(eval(dtype)(base)) + else: + return sympify(dtype(base)) except (NameError, SyntaxError): # E.g., `_base_typ` is "char" or "unsigned long" pass @@ -399,9 +401,22 @@ def __new__(cls, base, stars=None, **kwargs): pass obj = super().__new__(cls, base) - obj._stars = stars + obj._stars = stars or '' + obj._dtype = cls.__process_dtype__(dtype) return obj + @classmethod + def __process_dtype__(cls, dtype): + if isinstance(dtype, str): + return dtype + dtype = ctypes_vector_mapper.get(dtype, dtype) + try: + dtype = ctypes_to_cstr(dtype_to_ctype(dtype)) + except: + pass + + return dtype + def _hashable_content(self): return super()._hashable_content() + (self._stars,) @@ -411,9 +426,13 @@ def _hashable_content(self): def stars(self): return self._stars + @property + def dtype(self): + return self._dtype + @property def typ(self): - return '%s%s' % (self._base_typ, self.stars or '') + return '%s%s' % (self.dtype, self.stars or '') @property def _op(self): @@ -753,10 +772,8 @@ def __str__(self): class CastStar: - base = None - - def __new__(cls, base=''): - return cls.base(base, '*') + def __new__(cls, base, dtype=None, ase=''): + return Cast(base, dtype=dtype, stars='*') # Some other utility objects diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index ca0d9e98e3..84c0201f23 100644 --- a/devito/symbolics/printer.py +++ 
b/devito/symbolics/printer.py @@ -314,11 +314,16 @@ def _print_InlineIf(self, expr): PREC = precedence(expr) return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC) - def _print_UnaryOp(self, expr): - if expr.base.is_Symbol: - return "%s%s" % (expr._op, self._print(expr.base)) - else: - return "%s(%s)" % (expr._op, self._print(expr.base)) + def _print_UnaryOp(self, expr, op=None): + op = op or expr._op + base = self._print(expr.base) + if not expr.base.is_Symbol: + base = f'({base})' + return f'{op}{base}' + + def _print_Cast(self, expr): + cast = f'({self._print(expr.dtype)}{self._print(expr.stars)})' + return self._print_UnaryOp(expr, op=cast) def _print_ComponentAccess(self, expr): return "%s.%s" % (self._print(expr.base), expr.sindex) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 88d9299ee5..44d94dbe6a 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -2,7 +2,7 @@ import pytest import sympy -from devito import Constant, Eq, Function, Grid, Operator +from devito import Constant, Eq, Function, Grid, Operator, exp, log, sin from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C import CBB from devito.passes.iet.languages.openacc import AccBB @@ -161,9 +161,9 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128]) -@pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), - (sympy.log, np.log), - (sympy.sin, np.sin)]) +@pytest.mark.parametrize(['sym', 'fun'], [(exp, np.exp), + (log, np.log), + (sin, np.sin)]) def test_math_functions(dtype: np.dtype[np.inexact], sym: sympy.Function, fun: np.ufunc) -> None: """ diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 19cc0703ce..a4d40a72f6 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -14,7 +14,7 @@ CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, 
IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, uxreplace, ccode, - retrieve_derivatives) + retrieve_derivatives, BaseCast) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, ComponentAccess, StencilDimension, Symbol as dSymbol) @@ -409,8 +409,8 @@ def test_rvalue(): def test_cast(): s = Symbol(name='s', dtype=np.float32) - class BarCast(Cast): - _base_typ = 'bar' + class BarCast(BaseCast): + _dtype = 'bar' v = BarCast(s, '**') assert ccode(v) == '(bar**)s' From fd47c1abca0d42e2ec5b94a71d191fe64904b6dd Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 17 Jan 2025 13:10:26 -0500 Subject: [PATCH 39/58] compiler: fix complex headers --- devito/passes/iet/dtypes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index ea64434837..2a7cfbd6c4 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -5,6 +5,7 @@ from devito.ir import Callable, FindSymbols, SymbolRegistry from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB +from devito.tools import as_tuple __all__ = ['lower_dtypes'] @@ -29,7 +30,7 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None, return iet, {} metadata = {} - lib = (lang['header-complex'],) + lib = as_tuple(lang['header-complex']) if lang.get('complex-namespace') is not None: metadata['namespaces'] = lang['complex-namespace'] From 5940b603555691c2d1a814e64ff20fda548c258d Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 17 Jan 2025 13:19:28 -0500 Subject: [PATCH 40/58] api: remove un-needed dtype reconstruction mode --- devito/passes/iet/dtypes.py | 12 ++++++------ devito/symbolics/extended_dtypes.py | 2 +- devito/symbolics/extended_sympy.py | 2 +- devito/types/basic.py | 30 +++++++---------------------- devito/types/misc.py | 2 +- 5 files changed, 16 insertions(+), 32 deletions(-) diff --git a/devito/passes/iet/dtypes.py 
b/devito/passes/iet/dtypes.py index 2a7cfbd6c4..7848e9c1dc 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -1,4 +1,3 @@ -import ctypes import numpy as np from devito.arch.compiler import Compiler @@ -18,15 +17,16 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None, """ # Check if there are complex numbers that always take dtype precedence - types = set() + is_complex = False for f in FindSymbols().visit(iet): try: - if not issubclass(f.dtype, ctypes._Pointer): - types.add(f.dtype) + if np.issubdtype(f.dtype, np.complexfloating): + is_complex = True + break except TypeError: - pass + continue - if not any(np.issubdtype(d, np.complexfloating) for d in types): + if not is_complex: return iet, {} metadata = {} diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 23fb084ba7..0789c7b947 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -83,7 +83,7 @@ def cast_mapper(arg): FLOAT = cast_mapper(np.float32) DOUBLE = cast_mapper(np.float64) ULONG = cast_mapper(np.uint64) -UINTP = cast_mapper(np.uint32) +UINTP = cast_mapper((np.uint32, '*')) # Standard ones, needed as class for e.g. 
single dispatch diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 68ff67f894..1b01c6db3f 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -396,7 +396,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs): except (NameError, SyntaxError): # E.g., `_base_typ` is "char" or "unsigned long" pass - except TypeError: + except (ValueError, TypeError): # `base` ain't a number pass diff --git a/devito/types/basic.py b/devito/types/basic.py index f77adbb853..1a0c2fc97a 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -14,8 +14,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex, CustomDtype, - Reconstructable) + frozendict, memoized_meth, sympy_mutex, CustomDtype) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -881,7 +880,6 @@ def __new__(cls, *args, **kwargs): name = kwargs.get('name') alias = kwargs.get('alias') function = kwargs.get('function') - dtype = kwargs.get('dtype') if alias is True or (function and function.name != name): function = kwargs['function'] = None @@ -889,8 +887,7 @@ def __new__(cls, *args, **kwargs): # definitely a reconstruction if function is not None and \ function.name == name and \ - function.indices == indices and \ - function.dtype == dtype: + function.indices == indices: # Special case: a syntactically identical alias of `function`, so # let's just return `function` itself return function @@ -1240,8 +1237,7 @@ def bound_symbols(self): @cached_property def indexed(self): """The wrapped IndexedData object.""" - return IndexedData(self.name, shape=self._shape, function=self.function, - dtype=self.dtype) + return IndexedData(self.name, shape=self._shape, function=self.function) @cached_property def dmap(self): @@ 
-1523,14 +1519,13 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable): __rargs__ = ('label', 'shape') __rkwargs__ = ('function',) - def __new__(cls, label, shape, function=None, dtype=None): + def __new__(cls, label, shape, function=None): # Make sure `label` is a devito.Symbol, not a sympy.Symbol if isinstance(label, str): label = Symbol(name=label, dtype=None) with sympy_mutex: obj = sympy.IndexedBase.__new__(cls, label, shape) obj.function = function - obj._dtype = dtype or function.dtype return obj func = Pickable._rebuild @@ -1570,7 +1565,7 @@ def indices(self): @property def dtype(self): - return self._dtype + return self.function.dtype @cached_property def free_symbols(self): @@ -1632,7 +1627,7 @@ def _C_ctype(self): return self.function._C_ctype -class Indexed(sympy.Indexed, Reconstructable): +class Indexed(sympy.Indexed): # The two type flags have changed in upstream sympy as of version 1.1, # but the below interpretation is used throughout the compiler to @@ -1644,17 +1639,6 @@ class Indexed(sympy.Indexed, Reconstructable): is_Dimension = False - __rargs__ = ('base', 'indices') - __rkwargs__ = ('dtype',) - - def __new__(cls, base, *indices, dtype=None, **kwargs): - if len(indices) == 1: - indices = as_tuple(indices[0]) - newobj = sympy.Indexed.__new__(cls, base, *indices) - newobj._dtype = dtype or base.dtype - - return newobj - @memoized_meth def __str__(self): return super().__str__() @@ -1676,7 +1660,7 @@ def function(self): @property def dtype(self): - return self._dtype + return self.function.dtype @property def name(self): diff --git a/devito/types/misc.py b/devito/types/misc.py index 38beeaee53..29514bb99a 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -83,7 +83,7 @@ class FIndexed(Indexed, Pickable): __rkwargs__ = ('strides_map', 'accessor') def __new__(cls, base, *args, strides_map=None, accessor=None): - obj = super().__new__(cls, base, args) + obj = super().__new__(cls, base, *args) obj.strides_map = frozendict(strides_map 
or {}) obj.accessor = accessor From 64704f31f4a4bd54b7e5d3defb0df369059970d7 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 17 Jan 2025 15:47:07 -0500 Subject: [PATCH 41/58] compiler: fix dtype for mpi routines --- devito/mpi/routines.py | 6 +++--- devito/operator/operator.py | 4 ++-- devito/passes/clusters/derivatives.py | 3 ++- devito/symbolics/extended_dtypes.py | 6 +++--- devito/symbolics/extended_sympy.py | 10 +--------- devito/symbolics/manipulation.py | 2 +- devito/symbolics/printer.py | 3 ++- devito/tools/dtypes_lowering.py | 6 +++++- devito/types/array.py | 2 ++ devito/types/dense.py | 6 +++++- tests/test_pickle.py | 4 ++-- 11 files changed, 28 insertions(+), 24 deletions(-) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 8da418bfde..67158a621c 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -16,7 +16,7 @@ from devito.mpi import MPI from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, IndexedPointer, Macro, cast_mapper, subs_op_args) -from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype, +from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype, flatten, generator, is_integer, split) from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None): entry.sizes = (c_int*len(shape))(*shape) # Allocate the send/recv buffers - size = reduce(mul, shape)*dtype_len(self.target.dtype) - ctype = dtype_to_ctype(f.dtype) + ctype, c_scale = dtype_alloc_ctype(f.dtype) + size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype) entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype) entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 338ac9faa1..3966da4378 100644 --- a/devito/operator/operator.py +++ 
b/devito/operator/operator.py @@ -1121,7 +1121,7 @@ def __setstate__(self, state): self._lib.name = soname self._allocator = default_allocator( - '%s.%s.%s' % (self._compiler.name, self._language, self._platform) + '%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform) ) @@ -1407,7 +1407,7 @@ def parse_kwargs(**kwargs): # `allocator` kwargs['allocator'] = default_allocator( - '%s.%s.%s' % (kwargs['compiler'].name, + '%s.%s.%s' % (kwargs['compiler'].__class__.__name__, kwargs['language'], kwargs['platform']) ) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index f8f339aa1e..5af92a3208 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -1,6 +1,7 @@ from functools import singledispatch from sympy import S +import numpy as np from devito.finite_differences import IndexDerivative from devito.ir import Backward, Forward, Interval, IterationSpace, Queue @@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert s.dtype is dtype + assert np.can_cast(s.dtype, dtype) except KeyError: name = sregistry.make_name(prefix='r') s = Symbol(name=name, dtype=dtype) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0789c7b947..089b454b72 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,7 +1,7 @@ import ctypes import numpy as np -from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4, ctypes_vector_mapper) @@ -64,7 +64,7 @@ class CustomType(ReservedWord): cls = type(v.upper(), (Cast,), {'_base_typ': v}) globals()[cls.__name__] = cls - clsp = 
type('%sP' % v.upper(), (CastStar,), {'base': cls}) + clsp = type('%sP' % v.upper(), (Cast,), {'base': cls}) globals()[clsp.__name__] = clsp @@ -75,7 +75,7 @@ def no_dtype(kwargs): def cast_mapper(arg): try: assert len(arg) == 2 and arg[1] == '*' - return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw)) + return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw)) except TypeError: return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw)) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 1b01c6db3f..9ea9611f63 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -768,14 +768,6 @@ def __str__(self): __repr__ = __str__ -# *** Casting - -class CastStar: - - def __new__(cls, base, dtype=None, ase=''): - return Cast(base, dtype=dtype, stars='*') - - # Some other utility objects Null = Macro('NULL') @@ -789,7 +781,7 @@ def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' argument = Keyword(f'{intype}{stars}') newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs) - newobj.intype = intype + newobj.intype = Cast.__process_dtype__(intype) newobj.stars = stars return newobj diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index f5992ac8be..80389ead08 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -393,7 +393,7 @@ def normalize_args(args): for k, v in args.items(): try: retval[k] = sympify(v, strict=True) - except SympifyError: + except (TypeError, SympifyError): continue return retval diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 84c0201f23..0a60b7f2cc 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -219,7 +219,8 @@ def _print_Abs(self, expr): # Unary function, single argument arg = expr.args[0] # AOMPCC errors with abs, always use fabs - if isinstance(self.compiler, AOMPCompiler): + if 
isinstance(self.compiler, AOMPCompiler) and \ + not np.issubdtype(self._prec(expr), np.integer): return "fabs(%s)" % self._print(arg) func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}' return f"{self._ns}{func}({self._print(arg)})" diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index a6ce289324..d8e2b0723b 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -21,6 +21,7 @@ # NOTE: the following is inspired by pyopencl.cltypes mapper = { + "half": np.float16, "int": np.int32, "float": np.float32, "double": np.float64 @@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype): np.int32: 'MPI_INT', np.float32: 'MPI_FLOAT', np.int64: 'MPI_LONG', - np.float64: 'MPI_DOUBLE' + np.float64: 'MPI_DOUBLE', + np.float16: 'MPI_UNSIGNED_SHORT' }[dtype] @@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p): ctypes_vector_mapper = {} for base_name, base_dtype in mapper.items(): + if base_dtype is np.float16: + continue base_ctype = dtype_to_ctype(base_dtype) for count in counts: diff --git a/devito/types/array.py b/devito/types/array.py index 3ebd2fdcbb..b72bf74faa 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -204,6 +204,8 @@ class ArrayMapped(Array): (_C_field_dmap, c_void_p), (_C_field_size, c_uint64)]})) + _C_typedata = 'struct ' + _C_structname + class ArrayObject(ArrayBasic): diff --git a/devito/types/dense.py b/devito/types/dense.py index 29d7d24d1d..adfc64a094 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -796,17 +796,21 @@ def _halo_exchange(self): # Gather send data data = self._data_in_region(OWNED, d, i) sendbuf = np.ascontiguousarray(data) + if self.dtype == np.float16: + sendbuf = sendbuf.view(np.uint16) # Setup recv buffer shape = self._data_in_region(HALO, d, i.flip()).shape recvbuf = np.ndarray(shape=shape, dtype=self.dtype) + if self.dtype == np.float16: + recvbuf = recvbuf.view(np.uint16) # Communication comm.Sendrecv(sendbuf, dest=dest, 
recvbuf=recvbuf, source=source) # Scatter received data if recvbuf is not None and source != MPI.PROC_NULL: - self._data_in_region(HALO, d, i.flip())[:] = recvbuf + self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype) self._is_halo_dirty = False diff --git a/tests/test_pickle.py b/tests/test_pickle.py index ef47e917fb..fc33e98965 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -567,8 +567,8 @@ def test_equation(self, pickle): eq = Eq(f, f+1, implicit_dims=xs) - pkl_eq = pickle0.dumps(eq) - new_eq = pickle0.loads(pkl_eq) + pkl_eq = pickle.dumps(eq) + new_eq = pickle.loads(pkl_eq) assert new_eq.lhs.name == f.name assert str(new_eq.rhs) == 'f(x) + 1' From 0a809dd16259bffb73e83288e8b4411b4d6217c3 Mon Sep 17 00:00:00 2001 From: mloubout Date: Sat, 18 Jan 2025 18:16:03 -0500 Subject: [PATCH 42/58] compiler: fix missing algorithm include for min/max --- devito/operator/operator.py | 5 +++-- devito/passes/iet/languages/CXX.py | 1 + devito/passes/iet/misc.py | 31 ++++++++++++++++-------------- tests/test_gpu_openacc.py | 13 ++++++++++++- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 3966da4378..a5fef9b84f 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -470,6 +470,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): * Finalize (e.g., symbol definitions, array casts) """ name = kwargs.get("name", "Kernel") + lang = cls._Target.lang() # Wrap the IET with an EntryFunction (a special Callable representing # the entry point of the generated library) @@ -488,10 +489,10 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): cls._Target.instrument(graph, profiler=profiler, **kwargs) # Extract the necessary macros from the symbolic objects - generate_macros(graph, **kwargs) + generate_macros(graph, lang=lang, **kwargs) # Add type specific metadata - lower_dtypes(graph, lang=cls._Target.lang(), **kwargs) + lower_dtypes(graph, lang=lang, 
**kwargs) # Target-independent optimizations minimize_symbols(graph) diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 17003c0d8f..bfa6bebe35 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -45,6 +45,7 @@ class CXXBB(LangBB): mapper = { 'header-memcpy': 'string.h', + 'header-algorithm': 'algorithm', 'host-alloc': lambda i, j, k: Call('posix_memalign', (i, j, k)), 'host-alloc-pin': lambda i, j, k: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 28e1cc4f7b..d978936053 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -144,7 +144,7 @@ def generate_macros(graph, **kwargs): @iet_pass -def _generate_macros(iet, tracker=None, **kwargs): +def _generate_macros(iet, tracker=None, lang=None, **kwargs): # Derive the Macros necessary for the FIndexeds iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs) @@ -152,7 +152,8 @@ def _generate_macros(iet, tracker=None, **kwargs): headers = sorted((ccode(define), ccode(expr)) for define, expr in headers) # Generate Macros from higher-level SymPy objects - headers.extend(_generate_macros_math(iet)) + mheaders, includes = _generate_macros_math(iet, lang=lang) + headers.extend(mheaders) # Remove redundancies while preserving the order headers = filter_ordered(headers) @@ -160,7 +161,6 @@ def _generate_macros(iet, tracker=None, **kwargs): # Some special Symbols may represent Macros defined in standard libraries, # so we need to include the respective includes limits = FindApplications(ValueLimit).visit(iet) - includes = set() if limits & (set(limits_mapper[np.int32]) | set(limits_mapper[np.int64])): includes.add('limits.h') elif limits & (set(limits_mapper[np.float32]) | set(limits_mapper[np.float64])): @@ -195,35 +195,38 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs): return iet -def _generate_macros_math(iet): +def _generate_macros_math(iet, lang=None): headers = 
[] + includes = [] for i in FindApplications().visit(iet): - headers.extend(_lower_macro_math(i)) + header, include = _lower_macro_math(i, lang) + headers.extend(header) + includes.extend(include) - return headers + return headers, set(includes) - {None} @singledispatch -def _lower_macro_math(expr): - return () +def _lower_macro_math(expr, lang): + return (), {} @_lower_macro_math.register(Min) @_lower_macro_math.register(sympy.Min) -def _(expr): +def _(expr, lang): if has_integer_args(*expr.args) and len(expr.args) == 2: - return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),) + return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {} else: - return () + return (), (lang.get('header-algorithm'),) @_lower_macro_math.register(Max) @_lower_macro_math.register(sympy.Max) -def _(expr): +def _(expr, lang): if has_integer_args(*expr.args) and len(expr.args) == 2: - return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),) + return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),), {} else: - return () + return (), (lang.get('header-algorithm'),) @_lower_macro_math.register(SafeInv) diff --git a/tests/test_gpu_openacc.py b/tests/test_gpu_openacc.py index bdf732a12d..8c4813db0b 100644 --- a/tests/test_gpu_openacc.py +++ b/tests/test_gpu_openacc.py @@ -2,7 +2,7 @@ import numpy as np from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Eq, Operator, - norm, solve) + norm, solve, Max) from conftest import skipif, assert_blocking, opts_device_tiling from devito.data import LEFT from devito.exceptions import InvalidOperator @@ -171,6 +171,17 @@ def test_multi_tile_blocking_structure(self): assert len(iters) == len(v) assert all(i.step == j for i, j in zip(iters, v)) + def test_std_max(self): + grid = Grid(shape=(3, 3, 3)) + x, y, z = grid.dimensions + + u = Function(name='u', grid=grid) + + op = Operator(Eq(u, Max(1.2 * x / y, 2.3 * y / x)), + platform='nvidiaX', language='openacc') + + assert '' in str(op) + class TestOperator: From 
7632f07ba4ef368508b64e15550755dcc0a283b1 Mon Sep 17 00:00:00 2001 From: mloubout Date: Sat, 18 Jan 2025 18:47:06 -0500 Subject: [PATCH 43/58] arch: switch sycl error to warning for no-compile codegen --- devito/arch/compiler.py | 4 ++-- devito/passes/iet/dtypes.py | 6 +++--- devito/passes/iet/misc.py | 4 ++-- devito/symbolics/extended_sympy.py | 8 ++------ 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 6ff0825ad8..2276ca1d6e 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -838,7 +838,7 @@ def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) if language == 'sycl': - raise ValueError("Use SyclCompiler to jit-compile sycl") + warning("Use SyclCompiler to jit-compile sycl") elif language == 'openmp': # Earlier versions to OneAPI 2023.2.0 (clang17 underneath), have an @@ -894,7 +894,7 @@ def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) if language != 'sycl': - raise ValueError("Expected language sycl with SyclCompiler") + warning("Expected language sycl with SyclCompiler") self.cflags.remove('-std=c99') self.cflags.append('-fsycl') diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 7848e9c1dc..95b944c2b1 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -10,8 +10,8 @@ @iet_pass -def _complex_includes(iet: Callable, lang: type[LangBB] = None, - compiler: Compiler = None) -> tuple[Callable, dict]: +def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler, + sregistry: SymbolRegistry) -> tuple[Callable, dict]: """ Includes complex arithmetic headers for the given language, if needed. 
""" @@ -59,4 +59,4 @@ def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler """ for dtype_pass in dtype_passes: - dtype_pass(graph, lang=lang, compiler=compiler) + dtype_pass(graph, lang=lang, compiler=compiler, sregistry=sregistry) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index d978936053..c30a7d6664 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -230,10 +230,10 @@ def _(expr, lang): @_lower_macro_math.register(SafeInv) -def _(expr): +def _(expr, lang): eps = np.finfo(np.float32).resolution**2 return (('SAFEINV(a, b)', - f'(((a) < {eps} || (b) < {eps}) ? (0.0F) : (1.0F / (a)))'),) + f'(((a) < {eps} || (b) < {eps}) ? (0.0F) : (1.0F / (a)))'),), {} @iet_pass diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 9ea9611f63..1213cdead5 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -386,15 +386,11 @@ class Cast(UnaryOp): __rkwargs__ = ('stars', 'dtype') def __new__(cls, base, dtype=None, stars=None, **kwargs): - # Attempt simplifcation # E.g., `FLOAT(32) -> 32.0` of type `sympy.Float` try: - if isinstance(dtype, str): - return sympify(eval(dtype)(base)) - else: - return sympify(dtype(base)) + base = sympify(dtype(base)) except (NameError, SyntaxError): - # E.g., `_base_typ` is "char" or "unsigned long" + # E.g., `dtype` is "char" or "unsigned long" pass except (ValueError, TypeError): # `base` ain't a number From b8c1122f1226de10a30836842cca15d79c0ca093 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 Jan 2025 14:11:11 -0500 Subject: [PATCH 44/58] symbolics: rework cast/sizeof for pickling --- devito/symbolics/extended_dtypes.py | 10 ++--- devito/symbolics/extended_sympy.py | 57 +++++++++++++---------------- devito/symbolics/printer.py | 13 +++++-- devito/types/array.py | 2 - devito/types/basic.py | 6 ++- tests/test_pickle.py | 22 ++++++++++- 6 files changed, 65 insertions(+), 45 deletions(-) diff --git 
a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 089b454b72..1a69aefec1 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -68,16 +68,12 @@ class CustomType(ReservedWord): globals()[clsp.__name__] = clsp -def no_dtype(kwargs): - return {k: v for k, v in kwargs.items() if k != 'dtype'} - - def cast_mapper(arg): try: assert len(arg) == 2 and arg[1] == '*' - return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw)) + return lambda v, dtype=None, **kw: Cast(v, dtype=arg[0], stars=arg[1], **kw) except TypeError: - return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw)) + return lambda v, dtype=None, **kw: Cast(v, dtype=arg, **kw) FLOAT = cast_mapper(np.float32) @@ -96,7 +92,7 @@ def __new__(cls, base, stars=None, **kwargs): class VOID(BaseCast): - _dtype = type('void', (ctypes.c_int,), {}) + _dtype = 'void' class INT(BaseCast): diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 1213cdead5..aba854c4cf 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -383,36 +383,21 @@ class Cast(UnaryOp): """ __rargs__ = ('base', ) - __rkwargs__ = ('stars', 'dtype') + __rkwargs__ = ('dtype', 'stars') def __new__(cls, base, dtype=None, stars=None, **kwargs): - # E.g., `FLOAT(32) -> 32.0` of type `sympy.Float` try: - base = sympify(dtype(base)) - except (NameError, SyntaxError): - # E.g., `dtype` is "char" or "unsigned long" - pass - except (ValueError, TypeError): - # `base` ain't a number + if issubclass(dtype, np.generic) and sympify(base).is_Number: + base = sympify(dtype(base)) + except TypeError: + # E.g. 
void pass obj = super().__new__(cls, base) obj._stars = stars or '' - obj._dtype = cls.__process_dtype__(dtype) + obj._dtype = dtype return obj - @classmethod - def __process_dtype__(cls, dtype): - if isinstance(dtype, str): - return dtype - dtype = ctypes_vector_mapper.get(dtype, dtype) - try: - dtype = ctypes_to_cstr(dtype_to_ctype(dtype)) - except: - pass - - return dtype - def _hashable_content(self): return super()._hashable_content() + (self._stars,) @@ -427,12 +412,18 @@ def dtype(self): return self._dtype @property - def typ(self): - return '%s%s' % (self.dtype, self.stars or '') + def _C_ctype(self): + ctype = ctypes_vector_mapper.get(self.dtype, self.dtype) + try: + ctype = dtype_to_ctype(ctype) + except: + pass + + return ctype @property def _op(self): - return '(%s)' % self.typ + return '(%s)' % self._C_ctype class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin): @@ -775,17 +766,21 @@ class SizeOf(DefFunction): def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' - argument = Keyword(f'{intype}{stars}') - newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs) - newobj.intype = Cast.__process_dtype__(intype) + + if not isinstance(intype, (str, ReservedWord)): + intype = dtype_to_ctype(intype) + if intype in ctypes_vector_mapper.values(): + idx = list(ctypes_vector_mapper.values()).index(intype) + intype = list(ctypes_vector_mapper.keys())[idx] + else: + intype = ctypes_to_cstr(intype) + + newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs) newobj.stars = stars + newobj.intype = intype return newobj - @property - def arguments(self): - return self.args - @property def args(self): return super().args[1] diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 0a60b7f2cc..eb6f96ec77 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -64,7 +64,10 @@ def doprint(self, expr, assign_to=None): return self._print(expr) def _prec(self, expr): - dtype 
= sympy_dtype(expr, default=self.dtype) + try: + dtype = sympy_dtype(expr, default=self.dtype) + except TypeError: + return self.dtype if dtype is None or np.issubdtype(dtype, np.integer): real = any(isinstance(i, Float) for i in expr.atoms()) if real: @@ -200,7 +203,11 @@ def _print_Mod(self, expr): def _print_Mul(self, expr): term = super()._print_Mul(expr) - return term.replace("(-1)*", "-") + # avoid (-1)*... + term = term.replace("(-1)*", "-") + # Avoid (-1) / ... + term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/") + return term def _print_Min(self, expr): if has_integer_args(*expr.args) and len(expr.args) == 2: @@ -323,7 +330,7 @@ def _print_UnaryOp(self, expr, op=None): return f'{op}{base}' def _print_Cast(self, expr): - cast = f'({self._print(expr.dtype)}{self._print(expr.stars)})' + cast = f'({self._print(expr._C_ctype)}{self._print(expr.stars)})' return self._print_UnaryOp(expr, op=cast) def _print_ComponentAccess(self, expr): diff --git a/devito/types/array.py b/devito/types/array.py index b72bf74faa..3ebd2fdcbb 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -204,8 +204,6 @@ class ArrayMapped(Array): (_C_field_dmap, c_void_p), (_C_field_size, c_uint64)]})) - _C_typedata = 'struct ' + _C_structname - class ArrayObject(ArrayBasic): diff --git a/devito/types/basic.py b/devito/types/basic.py index 1a0c2fc97a..877f12b8c5 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1,7 +1,7 @@ import abc import inspect from collections import namedtuple -from ctypes import POINTER, _Pointer, c_char_p, c_char +from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure from functools import reduce, cached_property from operator import mul @@ -94,6 +94,10 @@ def _C_typedata(self): if _type is c_char_p: _type = c_char + # ctypes Structures such as dense/array use directly the struct name + if issubclass(_type, Structure): + _type = f'struct {_type.__name__}' + return _type @abc.abstractproperty diff --git 
a/tests/test_pickle.py b/tests/test_pickle.py index fc33e98965..e1d48060a1 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -1,3 +1,4 @@ +import ctypes import pickle as pickle0 import cloudpickle as pickle1 @@ -22,7 +23,7 @@ from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, - CallFromPointer, DefFunction) + CallFromPointer, DefFunction, Cast, SizeOf) from examples.seismic import (demo_model, AcquisitionGeometry, TimeAxis, RickerSource, Receiver) @@ -575,6 +576,25 @@ def test_equation(self, pickle): assert new_eq.implicit_dims[0].name == 'xs' assert new_eq.implicit_dims[0].factor.data == 4 + @pytest.mark.parametrize('typ', [ctypes.c_float, 'struct truct']) + def test_Cast(self, pickle, typ): + a = Symbol('a') + un = Cast(a, dtype=typ) + + pkl_un = pickle.dumps(un) + new_un = pickle.loads(pkl_un) + + assert un == new_un + + @pytest.mark.parametrize('typ', [ctypes.c_float, 'struct truct']) + def test_SizeOf(self, pickle, typ): + un = SizeOf(typ) + + pkl_un = pickle.dumps(un) + new_un = pickle.loads(pkl_un) + + assert un == new_un + class TestAdvanced: From e029104c056df4ff215aced2aee449103d2cd36c Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 Jan 2025 16:11:00 -0500 Subject: [PATCH 45/58] api: fix c_datatype hack --- devito/ir/iet/visitors.py | 2 +- devito/symbolics/extended_sympy.py | 1 - devito/symbolics/printer.py | 1 - devito/tools/dtypes_lowering.py | 2 +- devito/types/basic.py | 8 ++------ 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 7ab38b5191..ac6fa8806c 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -466,7 +466,7 @@ def visit_Break(self, o): def visit_Return(self, o): v = 'return' if o.value is not None: - v += ' %s' % o.value + v += f' {ccode(o.value)}' return c.Statement(v) def visit_Definition(self, o): diff 
--git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index aba854c4cf..f803117343 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -418,7 +418,6 @@ def _C_ctype(self): ctype = dtype_to_ctype(ctype) except: pass - return ctype @property diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index eb6f96ec77..76980ae557 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -1,7 +1,6 @@ """ Utilities to turn SymPy objects into C strings. """ - import numpy as np import sympy diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index d8e2b0723b..e1bfe481a0 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -144,7 +144,7 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype - elif issubclass(dtype, ctypes._SimpleCData): + elif issubclass(dtype, (ctypes._Pointer, ctypes.Structure, ctypes._SimpleCData)): # Bypass np.ctypeslib's normalization rules such as # `np.ctypeslib.as_ctypes_type(ctypes.c_void_p) -> ctypes.c_ulong` return dtype diff --git a/devito/types/basic.py b/devito/types/basic.py index 877f12b8c5..f0bd493855 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1,7 +1,7 @@ import abc import inspect from collections import namedtuple -from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure +from ctypes import POINTER, _Pointer, c_char_p, c_char from functools import reduce, cached_property from operator import mul @@ -81,7 +81,7 @@ def _C_name(self): @property def _C_typedata(self): """ - The type of the object in the generated code as a `str`. + The type of the object's data in the generated code. 
""" _type = self._C_ctype if isinstance(_type, CustomDtype): @@ -94,10 +94,6 @@ def _C_typedata(self): if _type is c_char_p: _type = c_char - # ctypes Structures such as dense/array use directly the struct name - if issubclass(_type, Structure): - _type = f'struct {_type.__name__}' - return _type @abc.abstractproperty From 77359899c87f52e273c3a76e349e9a9731a32303 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 Jan 2025 09:58:09 -0500 Subject: [PATCH 46/58] compiler: make visitor language parametric --- devito/ir/iet/nodes.py | 2 +- devito/ir/iet/visitors.py | 59 ++++++++++++++++-------------- devito/operator/operator.py | 18 ++++----- devito/passes/iet/engine.py | 7 ++-- devito/symbolics/extended_sympy.py | 9 ++--- devito/symbolics/printer.py | 11 ++++-- devito/types/basic.py | 7 +++- tests/test_dtypes.py | 5 ++- 8 files changed, 65 insertions(+), 53 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 4ffdb39773..d98a835a09 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -153,7 +153,7 @@ def writes(self): return () def _signature_items(self): - return (str(self.ccode),) + return (str(self),) class ExprStmt: diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index ac6fa8806c..e50b7dd674 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -176,8 +176,9 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, language=None, **kwargs): super().__init__(*args, **kwargs) + self.language = language # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -189,6 +190,9 @@ def __init__(self, *args, **kwargs): } _restrict_keyword = 'restrict' + def ccode(self, expr, **kwargs): + return ccode(expr, language=self.language, **kwargs) + def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. 
@@ -222,7 +226,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = ccode(ct) + cstr = self.ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -244,10 +248,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = ccode(obj._C_typedata) - strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) + strtype = self.ccode(obj._C_typedata) + strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) else: - strtype = ccode(obj._C_ctype) + strtype = self.ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -261,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [ccode(i) for i in obj.cargs] + arguments = [self.ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -275,9 +279,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, ccode(init)) + value = c.Initializer(value, self.ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, ccode(obj.initvalue)) + value = c.Initializer(value, self.ccode(obj.initvalue)) return value @@ -311,7 +315,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(ccode(i)) + ret.append(self.ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -324,6 +328,7 @@ def _gen_signature(self, o, is_declaration=False): signature = TemplateDecl(tparams, 
signature) else: signature = c.Template(tparams, signature) + return signature def _blankline_logic(self, children): @@ -377,7 +382,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = ccode(i._C_typedata) + cstr = self.ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -399,7 +404,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) rshape = '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -432,9 +437,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = ccode(i._C_typedata) + cstr = self.ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -473,8 +478,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = ccode(o.expr.lhs, dtype=o.dtype) - rhs = ccode(o.expr.rhs, dtype=o.dtype) + lhs = self.ccode(o.expr.lhs, dtype=o.dtype) + rhs = self.ccode(o.expr.rhs, dtype=o.dtype) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -487,8 +492,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = ccode(o.expr.lhs, dtype=o.dtype) - c_rhs = ccode(o.expr.rhs, dtype=o.dtype) + c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype) + c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -507,7 +512,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, 
(FieldFromComposite, FieldFromPointer)): - return c.Assign(ccode(retobj), call) + return c.Assign(self.ccode(retobj), call) else: return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -521,9 +526,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return c.If(ccode(o.condition), then_body, else_body) + return c.If(self.ccode(o.condition), then_body, else_body) else: - return c.If(ccode(o.condition), then_body) + return c.If(self.ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -533,23 +538,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, ccode(_max)) - loop_cond = '%s >= %s' % (o.index, ccode(_min)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) + loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, ccode(_min)) - loop_cond = '%s <= %s' % (o.index, ccode(_max)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) + loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -566,7 +571,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = ccode(o.condition) + 
condition = self.ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index a5fef9b84f..990200f3e8 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -23,7 +23,7 @@ from devito.operator.profiling import create_profile from devito.operator.registry import operator_selector from devito.mpi import MPI -from devito.parameters import configuration, switchconfig +from devito.parameters import configuration from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) @@ -479,8 +479,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - - # Specialize graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling @@ -761,12 +759,11 @@ def _soname(self): @cached_property def ccode(self): - with switchconfig(compiler=self._compiler, language=self._language): - try: - return self._ccode_handler().visit(self) - except (AttributeError, TypeError): - from devito.ir.iet.visitors import CGen - return CGen().visit(self) + try: + return self._ccode_handler(language=self._language).visit(self) + except (AttributeError, TypeError): + from devito.ir.iet.visitors import CGen + return CGen(language=self._language).visit(self) def _jit_compile(self): """ @@ -904,7 +901,8 @@ def apply(self, **kwargs): """ # Compile the operator before building the arguments list # to avoid out of memory with greedy compilers - cfunction = self.cfunction + with self._profiler.timer_on('jit-compile'): + cfunction = self.cfunction # Build the arguments list to invoke the kernel function with self._profiler.timer_on('arguments'): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index b9b5bf15d4..859b55c065 100644 --- 
a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -40,10 +40,11 @@ class Graph: The `visit` method collects info about the nodes in the Graph. """ - def __init__(self, iet, options=None, sregistry=None, **kwargs): + def __init__(self, iet, options=None, sregistry=None, language=None, **kwargs): self.efuncs = OrderedDict([(iet.name, iet)]) self.sregistry = sregistry + self.language = language self.includes = [] self.headers = [] @@ -147,7 +148,7 @@ def apply(self, func, **kwargs): # Minimize code size if len(efuncs) > len(self.efuncs): efuncs = reuse_compounds(efuncs, self.sregistry) - efuncs = reuse_efuncs(self.root, efuncs, self.sregistry) + efuncs = reuse_efuncs(self.root, efuncs, self.sregistry, self.language) self.efuncs = efuncs @@ -316,7 +317,7 @@ def _(i, sregistry=None): return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None) -def reuse_efuncs(root, efuncs, sregistry=None): +def reuse_efuncs(root, efuncs, sregistry=None, language=None): """ Generalise `efuncs` so that syntactically identical Callables may be dropped, thus maximizing code reuse. 
diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index f803117343..1e7a41c9b9 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -765,14 +765,11 @@ class SizeOf(DefFunction): def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' - if not isinstance(intype, (str, ReservedWord)): - intype = dtype_to_ctype(intype) - if intype in ctypes_vector_mapper.values(): - idx = list(ctypes_vector_mapper.values()).index(intype) + ctype = dtype_to_ctype(intype) + if ctype in ctypes_vector_mapper.values(): + idx = list(ctypes_vector_mapper.values()).index(ctype) intype = list(ctypes_vector_mapper.keys())[idx] - else: - intype = ctypes_to_cstr(intype) newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs) newobj.stars = stars diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 76980ae557..d167b89c73 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -20,7 +20,7 @@ from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.types.basic import AbstractFunction -from devito.tools import ctypes_to_cstr +from devito.tools import ctypes_to_cstr, dtype_to_ctype __all__ = ['ccode'] @@ -95,6 +95,10 @@ def parenthesize(self, item, level, strict=False): return super().parenthesize(item, level, strict=strict) def _print_type(self, expr): + try: + expr = dtype_to_ctype(expr) + except TypeError: + pass try: return self.type_mappings[expr] except KeyError: @@ -422,7 +426,7 @@ class AccDevitoPrinter(CXXDevitoPrinter): 'openacc': AccDevitoPrinter} -def ccode(expr, **settings): +def ccode(expr, language=None, **settings): """Generate C++ code from an expression. Parameters @@ -438,5 +442,6 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. 
""" - printer = printer_registry.get(configuration['language'], CDevitoPrinter) + lang = language or configuration['language'] + printer = printer_registry.get(lang, CDevitoPrinter) return printer(settings=settings).doprint(expr, None) diff --git a/devito/types/basic.py b/devito/types/basic.py index f0bd493855..b2bb6a71b5 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1,7 +1,7 @@ import abc import inspect from collections import namedtuple -from ctypes import POINTER, _Pointer, c_char_p, c_char +from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure from functools import reduce, cached_property from operator import mul @@ -87,13 +87,18 @@ def _C_typedata(self): if isinstance(_type, CustomDtype): return _type + _pointer = False while issubclass(_type, _Pointer): + _pointer = True _type = _type._type_ # `ctypes` treats C strings specially if _type is c_char_p: _type = c_char + if issubclass(_type, Structure) and _pointer: + _type = f'struct {_type.__name__}' + return _type @abc.abstractproperty diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 44d94dbe6a..a0dddec575 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -59,7 +59,8 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: +def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str], + expected=None) -> None: """ Tests that half and complex floats' dtypes result in the correct type strings in generated code. 
@@ -78,7 +79,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N params: dict[str, Basic] = {p.name: p for p in op.parameters} _u, _c = params['u'], params['c'] assert type(_u.indexed._C_ctype._type_()) == ctypes_vector_mapper[dtype] - assert _c._C_ctype == ctypes_vector_mapper[dtype] + assert _c._C_ctype == expected or ctypes_vector_mapper[dtype] @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) From 62f48623fc09708ff8ef76cc1d1a5b4340fff460 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 Jan 2025 13:05:10 -0500 Subject: [PATCH 47/58] compiler: make sure complex ctype is handled properly for typedata --- devito/symbolics/extended_dtypes.py | 4 ++++ devito/types/basic.py | 11 +++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 1a69aefec1..e0e0309944 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -34,6 +34,8 @@ class c_complex(NoDeclStruct): _fields_ = [('real', ctypes.c_float), ('imag', ctypes.c_float)] + _base_dtype = True + @classmethod def from_param(cls, val): return cls(val.real, val.imag) @@ -44,6 +46,8 @@ class c_double_complex(NoDeclStruct): _fields_ = [('real', ctypes.c_double), ('imag', ctypes.c_double)] + _base_dtype = True + @classmethod def from_param(cls, val): return cls(val.real, val.imag) diff --git a/devito/types/basic.py b/devito/types/basic.py index b2bb6a71b5..0a217a6dbf 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -87,17 +87,20 @@ def _C_typedata(self): if isinstance(_type, CustomDtype): return _type - _pointer = False while issubclass(_type, _Pointer): - _pointer = True _type = _type._type_ # `ctypes` treats C strings specially if _type is c_char_p: _type = c_char - if issubclass(_type, Structure) and _pointer: - _type = f'struct {_type.__name__}' + try: + # We have internal types such as c_complex that are + # Structure 
too but should be treated as plain c_type + _type._base_dtype + except AttributeError: + if issubclass(_type, Structure): + _type = f'struct {_type.__name__}' return _type From 4193b674fb0d6ba9bf5f410f36c6d92c7023bf7b Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 Jan 2025 16:09:25 -0500 Subject: [PATCH 48/58] symbolics: cleaner repr of Cast --- devito/symbolics/extended_sympy.py | 8 ++++++-- examples/userapi/06_sparse_operations.ipynb | 8 ++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 1e7a41c9b9..cb1d835c08 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -10,7 +10,8 @@ from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, - int4, dtype_to_ctype, ctypes_to_cstr, ctypes_vector_mapper) + int4, dtype_to_ctype, ctypes_to_cstr, ctypes_vector_mapper, + ctypes_to_cstr) from devito.types import Symbol from devito.types.basic import Basic @@ -422,7 +423,10 @@ def _C_ctype(self): @property def _op(self): - return '(%s)' % self._C_ctype + return '(%s)' % ctypes_to_cstr(self._C_ctype) + + def __str__(self): + return "%s%s" % (self._op, self.base) class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin): diff --git a/examples/userapi/06_sparse_operations.ipynb b/examples/userapi/06_sparse_operations.ipynb index 20b55e455b..3b9085dc82 100644 --- a/examples/userapi/06_sparse_operations.ipynb +++ b/examples/userapi/06_sparse_operations.ipynb @@ -277,8 +277,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eq(posx, (int)(floor((-o_x + s_coords(p_s, 0))/h_x)))\n", - "Eq(posy, (int)(floor((-o_y + s_coords(p_s, 1))/h_y)))\n", + "Eq(posx, (int)floor((-o_x + s_coords(p_s, 0))/h_x))\n", + "Eq(posy, (int)floor((-o_y + s_coords(p_s, 1))/h_y))\n", "Eq(px, -floor((-o_x + s_coords(p_s, 
0))/h_x) + (-o_x + s_coords(p_s, 0))/h_x)\n", "Eq(py, -floor((-o_y + s_coords(p_s, 1))/h_y) + (-o_y + s_coords(p_s, 1))/h_y)\n", "Eq(sum, 0.0)\n", @@ -484,8 +484,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eq(posx, (int)(floor((-o_x + s_coords(p_s, 0))/h_x)))\n", - "Eq(posy, (int)(floor((-o_y + s_coords(p_s, 1))/h_y)))\n", + "Eq(posx, (int)floor((-o_x + s_coords(p_s, 0))/h_x))\n", + "Eq(posy, (int)floor((-o_y + s_coords(p_s, 1))/h_y))\n", "Eq(sum, 0.0)\n", "Inc(sum, wsincrsx(p_s, rsx + 3)*wsincrsy(p_s, rsy + 3)*f(t, rsx + posx, rsy + posy))\n", "Eq(s(time, p_s), sum)\n" From 8fa25c167ded20905878698296316c285828b2de Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 24 Jan 2025 08:29:03 -0500 Subject: [PATCH 49/58] test: improve dtype tests log --- tests/test_dtypes.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index a0dddec575..5423056d40 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -24,7 +24,6 @@ def _get_language(language: str, **_) -> type[LangBB]: """ Gets the language building block type from parametrized kwargs. """ - return _languages[language] @@ -32,7 +31,6 @@ def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: """ Gets the printer building block type from parametrized kwargs. """ - return printer_registry[language] @@ -40,7 +38,6 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]: """ Generates kwargs for Operator to test language-specific behavior. 
""" - return { 'platform': platform, 'language': language, @@ -57,15 +54,19 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]: ] +def kw_id(kwargs): + # For more readable log + return "-".join(f'{k}' for k in kwargs.values()) + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -@pytest.mark.parametrize('kwargs', _configs) +@pytest.mark.parametrize('kwargs', _configs, ids=kw_id) def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str], expected=None) -> None: """ Tests that half and complex floats' dtypes result in the correct type strings in generated code. """ - # Set up an operator grid = Grid(shape=(3, 3)) x, y = grid.dimensions @@ -83,13 +84,12 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str], @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -@pytest.mark.parametrize('kwargs', _configs) +@pytest.mark.parametrize('kwargs', _configs, ids=kw_id) def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ Tests that variables introduced by CSE have the correct type strings in the generated code. """ - # Retrieve the language-specific type mapping printer: type[_DevitoPrinterBase] = _get_printer(**kwargs) @@ -108,14 +108,13 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None @pytest.mark.parametrize('dtype', [np.float32, np.complex64, np.complex128]) -@pytest.mark.parametrize('kwargs', _configs) +@pytest.mark.parametrize('kwargs', _configs, ids=kw_id) def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: np.dtype """ Tests that the correct complex headers are included when complex dtypes are present in the operator, and omitted otherwise. 
""" - # Set up an operator grid = Grid(shape=(3, 3)) x, y = grid.dimensions @@ -134,7 +133,7 @@ def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -@pytest.mark.parametrize('kwargs', _configs) +@pytest.mark.parametrize('kwargs', _configs, ids=kw_id) def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: """ Tests that the correct literal is used for the imaginary unit. @@ -172,7 +171,6 @@ def test_math_functions(dtype: np.dtype[np.inexact], and assigned appropriately for different float precisions and for complex floats/doubles. """ - # Get the expected function call string call_str = str(sym) if np.issubdtype(dtype, np.complexfloating): @@ -198,7 +196,6 @@ def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: """ Tests overriding complex values in op.apply(). """ - grid = Grid(shape=(5, 5)) x, y = grid.dimensions @@ -221,7 +218,6 @@ def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None: """ Tests taking the time derivative of a complex-valued function. """ - grid = Grid(shape=(5, 5)) x, y = grid.dimensions t = grid.time_dim @@ -248,7 +244,6 @@ def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None: Tests taking the space derivative of a complex-valued function, with respect to the real and imaginary axes. 
""" - grid = Grid(shape=(7, 7), dtype=dtype) x, y = grid.dimensions From aa908fe6088d7033367b6750b666a92eae23fcec Mon Sep 17 00:00:00 2001 From: mloubout Date: Sat, 25 Jan 2025 23:48:06 -0500 Subject: [PATCH 50/58] compiler: make sure cpp is used for c++ compilers --- devito/symbolics/printer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index d167b89c73..18fde09ac6 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -422,8 +422,8 @@ class AccDevitoPrinter(CXXDevitoPrinter): printer_registry: dict[str, type[_DevitoPrinterBase]] = { - 'C': CDevitoPrinter, 'openmp': CDevitoPrinter, - 'openacc': AccDevitoPrinter} + 'C': CDevitoPrinter, 'CXX': CXXDevitoPrinter, + 'openmp': CDevitoPrinter, 'openacc': AccDevitoPrinter} def ccode(expr, language=None, **settings): @@ -443,5 +443,8 @@ def ccode(expr, language=None, **settings): the input ``expr`` itself. """ lang = language or configuration['language'] + cpp = settings.get('compiler', configuration['compiler'])._cpp + if lang in ['C', 'openmp'] and cpp: + lang = 'CXX' printer = printer_registry.get(lang, CDevitoPrinter) return printer(settings=settings).doprint(expr, None) From 64ee30cad70a18b66c633b7c8a23f496bf816011 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 27 Jan 2025 09:56:11 -0500 Subject: [PATCH 51/58] compiler: make printer part of the target and differentiate C and CXX --- devito/core/cpu.py | 11 +++-- devito/ir/__init__.py | 1 + devito/ir/cgen/__init__.py | 1 + devito/{symbolics => ir/cgen}/printer.py | 61 +++--------------------- devito/operator/registry.py | 2 +- devito/passes/iet/languages/C.py | 29 ++++++++++- devito/passes/iet/languages/CXX.py | 28 ++++++++++- devito/passes/iet/languages/openacc.py | 13 ++++- devito/passes/iet/languages/targets.py | 41 ++++++++++++++-- tests/test_dtypes.py | 6 +-- 10 files changed, 121 insertions(+), 72 deletions(-) create mode 100644 
devito/ir/cgen/__init__.py rename devito/{symbolics => ir/cgen}/printer.py (85%) diff --git a/devito/core/cpu.py b/devito/core/cpu.py index b9baedb237..e618ac2826 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -8,7 +8,8 @@ from devito.passes.clusters import (Lift, blocking, buffering, cire, cse, factorize, fission, fuse, optimize_pows, optimize_hyperplanes) -from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, +from devito.passes.iet import (CTarget, CXXTarget, COmpTarget, CXXOmpTarget, + avoid_denormals, linearize, mpiize, hoist_prodders, relax_incr_dimensions, check_stability) from devito.tools import timed_pass @@ -244,7 +245,7 @@ def _normalize_kwargs(cls, **kwargs): class Cpu64CustomOperator(Cpu64OperatorMixin, CustomOperator): - _Target = OmpTarget + _Target = COmpTarget @classmethod def _make_dsl_passes_mapper(cls, **kwargs): @@ -325,7 +326,7 @@ class Cpu64NoopCOperator(Cpu64NoopOperator): class Cpu64NoopOmpOperator(Cpu64NoopOperator): - _Target = OmpTarget + _Target = COmpTarget class Cpu64AdvCOperator(Cpu64AdvOperator): @@ -333,7 +334,7 @@ class Cpu64AdvCOperator(Cpu64AdvOperator): class Cpu64AdvOmpOperator(Cpu64AdvOperator): - _Target = OmpTarget + _Target = COmpTarget class Cpu64FsgCOperator(Cpu64FsgOperator): @@ -341,4 +342,4 @@ class Cpu64FsgCOperator(Cpu64FsgOperator): class Cpu64FsgOmpOperator(Cpu64FsgOperator): - _Target = OmpTarget + _Target = COmpTarget diff --git a/devito/ir/__init__.py b/devito/ir/__init__.py index 0cfa5730eb..d700e6573d 100644 --- a/devito/ir/__init__.py +++ b/devito/ir/__init__.py @@ -2,3 +2,4 @@ from devito.ir.equations import * # noqa from devito.ir.clusters import * # noqa from devito.ir.iet import * # noqa +from devito.ir.printer import * # noqa \ No newline at end of file diff --git a/devito/ir/cgen/__init__.py b/devito/ir/cgen/__init__.py new file mode 100644 index 0000000000..b9d0e6ed2c --- /dev/null +++ b/devito/ir/cgen/__init__.py @@ -0,0 +1 @@ +from devito.ir.cgen.printer import 
* # noqa \ No newline at end of file diff --git a/devito/symbolics/printer.py b/devito/ir/cgen/printer.py similarity index 85% rename from devito/symbolics/printer.py rename to devito/ir/cgen/printer.py index 18fde09ac6..51ea4809d3 100644 --- a/devito/symbolics/printer.py +++ b/devito/ir/cgen/printer.py @@ -10,25 +10,22 @@ from sympy.core import S from sympy.core.numbers import equal_valued, Float from sympy.printing.codeprinter import CodePrinter -from sympy.printing.c import C99CodePrinter -from sympy.printing.cxx import CXX11CodePrinter from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype -from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.types.basic import AbstractFunction from devito.tools import ctypes_to_cstr, dtype_to_ctype -__all__ = ['ccode'] +__all__ = ['BasePrinter', 'printer_registry', 'ccode'] _prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} -class _DevitoPrinterBase(CodePrinter): +class BasePrinter(CodePrinter): """ Decorator for sympy.printing.ccode.CCodePrinter. 
@@ -366,7 +363,7 @@ def _print_Fallback(self, expr): # Lifted from SymPy so that we go through our own `_print_math_func` for k in ('exp log sin cos tan ceiling floor').split(): - setattr(_DevitoPrinterBase, '_print_%s' % k, _DevitoPrinterBase._print_math_func) + setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func) # Always parenthesize IntDiv and InlineIf within expressions @@ -377,53 +374,10 @@ def _print_Fallback(self, expr): # Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here # to always use the correct one from our printer if Version(sympy.__version__) >= Version("1.11"): - setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) + setattr(sympy.printing.str.StrPrinter, '_print_Add', BasePrinter._print_Add) -class CDevitoPrinter(_DevitoPrinterBase, C99CodePrinter): - - _default_settings = {**_DevitoPrinterBase._default_settings, - **C99CodePrinter._default_settings} - _func_litterals = {np.float32: 'f', np.complex64: 'f'} - _func_prefix = {np.float32: 'f', np.float64: 'f', - np.complex64: 'c', np.complex128: 'c'} - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**C99CodePrinter.type_mappings, - c_complex: 'float _Complex', - c_double_complex: 'double _Complex'} - - def _print_ImaginaryUnit(self, expr): - return '_Complex_I' - - -class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): - - _default_settings = {**_DevitoPrinterBase._default_settings, - **CXX11CodePrinter._default_settings} - _ns = "std::" - _func_litterals = {} - _func_prefix = {np.float32: 'f', np.float64: 'f'} - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**CXX11CodePrinter.type_mappings, - c_complex: 'std::complex', - c_double_complex: 'std::complex'} - - def _print_ImaginaryUnit(self, expr): - return f'1i{self.prec_literal(expr).lower()}' - - -class AccDevitoPrinter(CXXDevitoPrinter): - - pass - - -printer_registry: 
dict[str, type[_DevitoPrinterBase]] = { - 'C': CDevitoPrinter, 'CXX': CXXDevitoPrinter, - 'openmp': CDevitoPrinter, 'openacc': AccDevitoPrinter} +printer_registry: dict[str, type[BasePrinter]] = {'default': BasePrinter} def ccode(expr, language=None, **settings): @@ -443,8 +397,5 @@ def ccode(expr, language=None, **settings): the input ``expr`` itself. """ lang = language or configuration['language'] - cpp = settings.get('compiler', configuration['compiler'])._cpp - if lang in ['C', 'openmp'] and cpp: - lang = 'CXX' - printer = printer_registry.get(lang, CDevitoPrinter) + printer = printer_registry.get(lang, BasePrinter) return printer(settings=settings).doprint(expr, None) diff --git a/devito/operator/registry.py b/devito/operator/registry.py index 04c1000866..963cf792fe 100644 --- a/devito/operator/registry.py +++ b/devito/operator/registry.py @@ -26,7 +26,7 @@ class OperatorRegistry(OrderedDict, metaclass=Singleton): """ _modes = ('noop', 'advanced', 'advanced-fsg') - _languages = ('C', 'openmp', 'openacc', 'cuda', 'hip', 'sycl') + _languages = ('C', 'CXX', 'openmp', 'openacc', 'cuda', 'hip', 'sycl') _accepted = _modes + tuple(product(_modes, _languages)) def add(self, operator, platform, mode, language='C'): diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index ff50e54205..75aa1991ff 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,7 +1,11 @@ -from devito.ir import Call +import numpy as np +from sympy.printing.c import C99CodePrinter + +from devito.ir import Call, BasePrinter, printer_registry from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB +from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CBB', 'CDataManager', 'COrchestrator'] @@ -31,3 +35,26 @@ class CDataManager(DataManager): class COrchestrator(Orchestrator): lang = CBB + + +class 
CPrinter(BasePrinter, C99CodePrinter): + + _default_settings = {**BasePrinter._default_settings, + **C99CodePrinter._default_settings} + _func_litterals = {np.float32: 'f', np.complex64: 'f'} + _func_prefix = {np.float32: 'f', np.float64: 'f', + np.complex64: 'c', np.complex128: 'c'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**C99CodePrinter.type_mappings, + c_complex: 'float _Complex', + c_double_complex: 'double _Complex'} + + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + + +printer_registry['C'] = CPrinter +printer_registry['openmp'] = CPrinter +printer_registry['Copenmp'] = CPrinter diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index bfa6bebe35..71aaeba578 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,5 +1,9 @@ -from devito.ir import Call, UsingNamespace +import numpy as np +from sympy.printing.cxx import CXX11CodePrinter + +from devito.ir import Call, UsingNamespace, BasePrinter, printer_registry from devito.passes.iet.langbase import LangBB +from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CXXBB'] @@ -61,3 +65,25 @@ class CXXBB(LangBB): 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, } + + +class CXXPrinter(BasePrinter, CXX11CodePrinter): + + _default_settings = {**BasePrinter._default_settings, + **CXX11CodePrinter._default_settings} + _ns = "std::" + _func_litterals = {} + _func_prefix = {np.float32: 'f', np.float64: 'f'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**CXX11CodePrinter.type_mappings, + c_complex: 'std::complex', + c_double_complex: 'std::complex'} + + def _print_ImaginaryUnit(self, expr): + return f'1i{self.prec_literal(expr).lower()}' + + +printer_registry['CXX'] = CXXPrinter +printer_registry['CXXopenmp'] = CXXPrinter diff --git 
a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index e67dc7eee2..6b85ecc626 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -2,14 +2,15 @@ from devito.arch import AMDGPUX, NVIDIAX from devito.ir import (Call, DeviceCall, DummyExpr, EntryFunction, List, Block, - ParallelTree, Pragma, Return, FindSymbols, make_callable) + ParallelTree, Pragma, Return, FindSymbols, make_callable, + printer_registry) from devito.passes import needs_transfer, is_on_device from devito.passes.iet.definitions import DeviceAwareDataManager from devito.passes.iet.engine import iet_pass from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.CXX import CXXBB +from devito.passes.iet.languages.CXX import CXXBB, CXXPrinter from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -263,3 +264,11 @@ def place_devptr(self, iet, **kwargs): class AccOrchestrator(Orchestrator): lang = AccBB + + +class AccPrinter(CXXPrinter): + + pass + + +printer_registry['openacc'] = AccPrinter diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 4ac8d94398..248b61f960 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -1,9 +1,10 @@ -from devito.passes.iet.languages.C import CDataManager, COrchestrator +from devito.passes.iet.languages.C import CDataManager, COrchestrator, CPrinter +from devito.passes.iet.languages.CXX import CXXPrinter from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer, OmpDataManager, DeviceOmpDataManager, OmpOrchestrator, DeviceOmpOrchestrator) from devito.passes.iet.languages.openacc import 
(DeviceAccizer, DeviceAccDataManager, - AccOrchestrator) + AccOrchestrator, AccPrinter) from devito.passes.iet.instrument import instrument __all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] @@ -13,6 +14,7 @@ class Target: Parizer = None DataManager = None Orchestrator = None + Printer = None @classmethod def lang(cls): @@ -27,21 +29,52 @@ class CTarget(Target): Parizer = SimdOmpizer DataManager = CDataManager Orchestrator = COrchestrator + Printer = CPrinter -class OmpTarget(Target): +class CXXTarget(Target): + Parizer = SimdOmpizer + DataManager = CDataManager + Orchestrator = COrchestrator + Printer = CXXPrinter + + +class COmpTarget(Target): Parizer = Ompizer DataManager = OmpDataManager Orchestrator = OmpOrchestrator + Printer = CPrinter + + +OmpTarget = COmpTarget + + +class CXXOmpTarget(Target): + Parizer = Ompizer + DataManager = OmpDataManager + Orchestrator = OmpOrchestrator + Printer = CXXPrinter + + +class DeviceCOmpTarget(Target): + Parizer = DeviceOmpizer + DataManager = DeviceOmpDataManager + Orchestrator = DeviceOmpOrchestrator + Printer = CPrinter + + +DeviceOmpTarget = DeviceCOmpTarget -class DeviceOmpTarget(Target): +class DeviceCXXOmpTarget(Target): Parizer = DeviceOmpizer DataManager = DeviceOmpDataManager Orchestrator = DeviceOmpOrchestrator + Printer = CXXPrinter class DeviceAccTarget(Target): Parizer = DeviceAccizer DataManager = DeviceAccDataManager Orchestrator = AccOrchestrator + Printer = AccPrinter diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 5423056d40..597271a172 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -8,7 +8,7 @@ from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import ctypes_vector_mapper -from devito.symbolics.printer import printer_registry, _DevitoPrinterBase +from devito.symbolics.printer import printer_registry, BasePrinter from devito.types.basic import Basic, Scalar, 
Symbol from devito.types.dense import TimeFunction @@ -27,7 +27,7 @@ def _get_language(language: str, **_) -> type[LangBB]: return _languages[language] -def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: +def _get_printer(language: str, **_) -> type[BasePrinter]: """ Gets the printer building block type from parametrized kwargs. """ @@ -91,7 +91,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None the generated code. """ # Retrieve the language-specific type mapping - printer: type[_DevitoPrinterBase] = _get_printer(**kwargs) + printer: type[BasePrinter] = _get_printer(**kwargs) # Set up an operator grid = Grid(shape=(3, 3)) From 5d1fe7ab6961f1f842cbc60fb28786ae407efeba Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 27 Jan 2025 10:41:47 -0500 Subject: [PATCH 52/58] compiler: add all cxx target to operator registry --- devito/arch/compiler.py | 45 +++++++++++---------- devito/core/__init__.py | 33 ++++++++++++++-- devito/core/arm.py | 25 +++++++----- devito/core/cpu.py | 37 +++++++++++++++++- devito/core/intel.py | 11 +++++- devito/core/operator.py | 5 +++ devito/core/power.py | 8 +++- devito/ir/__init__.py | 2 +- devito/ir/cgen/printer.py | 54 ++++++++++++++++++-------- devito/ir/iet/nodes.py | 9 +---- devito/ir/iet/visitors.py | 38 +++++++++--------- devito/operator/operator.py | 25 ++++++++---- devito/operator/registry.py | 3 +- devito/passes/iet/languages/C.py | 9 ++--- devito/passes/iet/languages/CXX.py | 8 ++-- devito/passes/iet/languages/openacc.py | 6 +-- devito/passes/iet/languages/targets.py | 3 +- devito/passes/iet/misc.py | 12 +++--- devito/passes/iet/parpragma.py | 6 +-- devito/symbolics/__init__.py | 1 - devito/symbolics/inspection.py | 4 +- tests/test_dtypes.py | 14 +++++-- tests/test_ir.py | 3 +- tests/test_symbolics.py | 4 +- 24 files changed, 240 insertions(+), 125 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 2276ca1d6e..0f74107af2 100644 --- 
a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -180,7 +180,7 @@ def __init__(self): """ fields = {'cc', 'ld'} - _cpp = False + default_cpp = False def __init__(self, **kwargs): _name = kwargs.pop('name', self.__class__.__name__) @@ -191,6 +191,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.__lookup_cmds__() + self._cpp = kwargs.get('cpp', self.default_cpp) self.suffix = kwargs.get('suffix') if not kwargs.get('mpi'): @@ -201,7 +202,7 @@ def __init__(self, **kwargs): self.cc = self.MPICC if self._cpp is False else self.MPICXX self.ld = self.cc # Wanted by the superclass - self.cflags = ['-O3', '-g', '-fPIC', '-Wall', '-std=c99'] + self.cflags = ['-O3', '-g', '-fPIC', '-Wall', f'-std={self.std}'] self.ldflags = ['-shared'] self.include_dirs = [] @@ -253,6 +254,10 @@ def version(self): return version + @property + def std(self): + return 'c++14' if self._cpp else 'c99' + def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -492,8 +497,8 @@ def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) - if isinstance(platform, NvidiaDevice): - self.cflags.remove('-std=c99') + if platform is NVIDIAX: + self.cflags.remove(f'-std={self.std}') # Add flags for OpenMP offloading if language in ['C', 'openmp']: cc = get_nvidia_cc() @@ -501,7 +506,7 @@ def __init_finalize__(self, **kwargs): self.cflags += ['-Xopenmp-target', '-march=sm_%s' % cc] self.ldflags += ['-fopenmp', '-fopenmp-targets=nvptx64-nvidia-cuda'] elif platform is AMDGPUX: - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') # Add flags for OpenMP offloading if language in ['C', 'openmp']: self.ldflags += ['-target', 'x86_64-pc-linux-gnu'] @@ -560,10 +565,10 @@ def __init_finalize__(self, **kwargs): if not configuration['safe-math']: self.cflags.append('-ffast-math') - if isinstance(platform, 
NvidiaDevice): - self.cflags.remove('-std=c99') + if platform is NVIDIAX: + self.cflags.remove(f'-std={self.std}') elif platform is AMDGPUX: - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') # Add flags for OpenMP offloading if language in ['C', 'openmp']: self.ldflags += ['-target', 'x86_64-pc-linux-gnu'] @@ -599,15 +604,15 @@ def __lookup_cmds__(self): class PGICompiler(Compiler): - _cpp = True + default_cpp = True def __init_finalize__(self, **kwargs): - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') self.cflags.remove('-O3') self.cflags.remove('-Wall') - self.cflags.append('-std=c++14') + self.cflags.append(f'-std={self.std}') language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) @@ -650,14 +655,14 @@ def __lookup_cmds__(self): class CudaCompiler(Compiler): - _cpp = True + default_cpp = True def __init_finalize__(self, **kwargs): - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') self.cflags.remove('-Wall') self.cflags.remove('-fPIC') - self.cflags.extend(['-std=c++14', '-Xcompiler', '-fPIC']) + self.cflags.extend([f'-std={self.std}', '-Xcompiler', '-fPIC']) if configuration['mpi']: # We rather use `nvcc` to compile MPI, but for this we have to @@ -724,14 +729,14 @@ def __lookup_cmds__(self): class HipCompiler(Compiler): - _cpp = True + default_cpp = True def __init_finalize__(self, **kwargs): - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') self.cflags.remove('-Wall') self.cflags.remove('-fPIC') - self.cflags.extend(['-std=c++14', '-fPIC']) + self.cflags.extend([f'-std={self.std}', '-fPIC']) if configuration['mpi']: # We rather use `hipcc` to compile MPI, but for this we have to @@ -885,7 +890,7 @@ def __lookup_cmds__(self): class SyclCompiler(OneapiCompiler): - _cpp = True + default_cpp = True def __init_finalize__(self, **kwargs): IntelCompiler.__init_finalize__(self, **kwargs) @@ -896,7 +901,7 
@@ def __init_finalize__(self, **kwargs): if language != 'sycl': warning("Expected language sycl with SyclCompiler") - self.cflags.remove('-std=c99') + self.cflags.remove(f'-std={self.std}') self.cflags.append('-fsycl') self.cflags.remove('-g') # -g disables some optimizations in IGC @@ -952,7 +957,7 @@ def __new__(cls, *args, **kwargs): obj = super().__new__(cls) # Keep base to initialize accordingly obj._base = kwargs.pop('base', _base) - obj._cpp = obj._base._cpp + obj.default_cpp = obj._base.default_cpp return obj diff --git a/devito/core/__init__.py b/devito/core/__init__.py index 507c5edf15..47771ce657 100644 --- a/devito/core/__init__.py +++ b/devito/core/__init__.py @@ -2,11 +2,19 @@ from devito.core.cpu import (Cpu64NoopCOperator, Cpu64NoopOmpOperator, Cpu64AdvCOperator, Cpu64AdvOmpOperator, Cpu64FsgCOperator, Cpu64FsgOmpOperator, - Cpu64CustomOperator) + Cpu64CustomOperator, Cpu64CXXCustomOperator, + Cpu64CXXNoopCOperator, Cpu64CXXNoopOmpOperator, + Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator, + Cpu64CXXFsgCOperator, Cpu64CXXFsgOmpOperator) + from devito.core.intel import (Intel64AdvCOperator, Intel64AdvOmpOperator, - Intel64FsgCOperator, Intel64FsgOmpOperator) -from devito.core.arm import ArmAdvCOperator, ArmAdvOmpOperator -from devito.core.power import PowerAdvCOperator, PowerAdvOmpOperator + Intel64FsgCOperator, Intel64FsgOmpOperator, + Intel64CXXAdvCOperator, Intel64CXXAdvOmpOperator, + Intel64CXXFsgCOperator, Intel64CXXFsgOmpOperator) +from devito.core.arm import (ArmAdvCOperator, ArmAdvOmpOperator, + ArmCXXAdvCOperator, ArmCXXAdvOmpOperator) +from devito.core.power import (PowerAdvCOperator, PowerAdvOmpOperator, + PowerCXXAdvCOperator, PowerCXXAdvOmpOperator) from devito.core.gpu import (DeviceNoopOmpOperator, DeviceNoopAccOperator, DeviceAdvOmpOperator, DeviceAdvAccOperator, DeviceFsgOmpOperator, DeviceFsgAccOperator, @@ -16,26 +24,43 @@ # Register CPU Operators operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'C') 
operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'openmp') +operator_registry.add(Cpu64CXXCustomOperator, Cpu64, 'custom', 'CXX') +operator_registry.add(Cpu64CXXCustomOperator, Cpu64, 'custom', 'CXXopenmp') operator_registry.add(Cpu64NoopCOperator, Cpu64, 'noop', 'C') operator_registry.add(Cpu64NoopOmpOperator, Cpu64, 'noop', 'openmp') +operator_registry.add(Cpu64CXXNoopCOperator, Cpu64, 'noop', 'CXX') +operator_registry.add(Cpu64CXXNoopOmpOperator, Cpu64, 'noop', 'CXXopenmp') operator_registry.add(Cpu64AdvCOperator, Cpu64, 'advanced', 'C') operator_registry.add(Cpu64AdvOmpOperator, Cpu64, 'advanced', 'openmp') +operator_registry.add(Cpu64AdvCXXOperator, Cpu64, 'advanced', 'CXX') +operator_registry.add(Cpu64CXXAdvOmpOperator, Cpu64, 'advanced', 'CXXopenmp') operator_registry.add(Cpu64FsgCOperator, Cpu64, 'advanced-fsg', 'C') operator_registry.add(Cpu64FsgOmpOperator, Cpu64, 'advanced-fsg', 'openmp') +operator_registry.add(Cpu64CXXFsgCOperator, Cpu64, 'advanced-fsg', 'CXX') +operator_registry.add(Cpu64CXXFsgOmpOperator, Cpu64, 'advanced-fsg', 'CXXopenmp') operator_registry.add(Intel64AdvCOperator, Intel64, 'advanced', 'C') operator_registry.add(Intel64AdvOmpOperator, Intel64, 'advanced', 'openmp') +operator_registry.add(Intel64CXXAdvCOperator, Intel64, 'advanced', 'CXX') +operator_registry.add(Intel64CXXAdvOmpOperator, Intel64, 'advanced', 'CXXopenmp') + operator_registry.add(Intel64FsgCOperator, Intel64, 'advanced-fsg', 'C') operator_registry.add(Intel64FsgOmpOperator, Intel64, 'advanced-fsg', 'openmp') +operator_registry.add(Intel64CXXFsgCOperator, Intel64, 'advanced-fsg', 'CXX') +operator_registry.add(Intel64CXXFsgOmpOperator, Intel64, 'advanced-fsg', 'CXXopenmp') operator_registry.add(ArmAdvCOperator, Arm, 'advanced', 'C') operator_registry.add(ArmAdvOmpOperator, Arm, 'advanced', 'openmp') +operator_registry.add(ArmCXXAdvCOperator, Arm, 'advanced', 'CXX') +operator_registry.add(ArmCXXAdvOmpOperator, Arm, 'advanced', 'CXXopenmp') 
operator_registry.add(PowerAdvCOperator, Power, 'advanced', 'C') operator_registry.add(PowerAdvOmpOperator, Power, 'advanced', 'openmp') +operator_registry.add(PowerCXXAdvCOperator, Power, 'advanced', 'CXX') +operator_registry.add(PowerCXXAdvOmpOperator, Power, 'advanced', 'CXXopenmp') # Register Device Operators operator_registry.add(DeviceCustomOmpOperator, Device, 'custom', 'C') diff --git a/devito/core/arm.py b/devito/core/arm.py index 0b765c1b52..a0581aaf01 100644 --- a/devito/core/arm.py +++ b/devito/core/arm.py @@ -1,19 +1,24 @@ -from devito.core.cpu import Cpu64AdvOperator -from devito.passes.iet import CTarget, OmpTarget +from devito.core.cpu import (Cpu64AdvOperator, Cpu64AdvCXXOperator, + Cpu64AdvCOperator) +from devito.passes.iet import OmpTarget, CXXOmpTarget -__all__ = ['ArmAdvCOperator', 'ArmAdvOmpOperator'] +__all__ = ['ArmAdvCOperator', 'ArmAdvOmpOperator', 'ArmCXXAdvCOperator', + 'ArmCXXAdvOmpOperator'] -class ArmAdvOperator(Cpu64AdvOperator): - pass +ArmAdvOperator = Cpu64AdvOperator +ArmAdvCOperator = Cpu64AdvCOperator +ArmCXXAdvOperator = Cpu64AdvCXXOperator +ArmCXXAdvCOperator = Cpu64AdvCXXOperator -class ArmAdvCOperator(ArmAdvOperator): - _Target = CTarget - - -class ArmAdvOmpOperator(ArmAdvOperator): +class ArmAdvOmpOperator(ArmAdvCOperator): _Target = OmpTarget # Avoid nested parallelism on ThunderX2 PAR_NESTED = 4 + + +class ArmCXXAdvOmpOperator(ArmAdvOmpOperator): + _Target = CXXOmpTarget + LINEARIZE = True diff --git a/devito/core/cpu.py b/devito/core/cpu.py index e618ac2826..716f0cd902 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -78,7 +78,7 @@ def _normalize_kwargs(cls, **kwargs): # Misc o['opt-comms'] = oo.pop('opt-comms', True) - o['linearize'] = oo.pop('linearize', False) + o['linearize'] = oo.pop('linearize', cls.LINEARIZE) o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE) o['place-transfers'] = oo.pop('place-transfers', True) @@ -318,6 +318,11 @@ 
def _make_iet_passes_mapper(cls, **kwargs): assert not (set(_known_passes) & set(_known_passes_disabled)) +class Cpu64CXXCustomOperator(Cpu64CustomOperator): + + _Target = CXXTarget + LINEARIZE = True + # Language level @@ -325,21 +330,51 @@ class Cpu64NoopCOperator(Cpu64NoopOperator): _Target = CTarget +class Cpu64CXXNoopCOperator(Cpu64NoopOperator): + _Target = CXXTarget + LINEARIZE = True + + class Cpu64NoopOmpOperator(Cpu64NoopOperator): _Target = COmpTarget +class Cpu64CXXNoopOmpOperator(Cpu64NoopOperator): + _Target = CXXOmpTarget + LINEARIZE = True + + class Cpu64AdvCOperator(Cpu64AdvOperator): _Target = CTarget +class Cpu64AdvCXXOperator(Cpu64AdvOperator): + _Target = CXXTarget + LINEARIZE = True + + class Cpu64AdvOmpOperator(Cpu64AdvOperator): _Target = COmpTarget +class Cpu64CXXAdvOmpOperator(Cpu64AdvOperator): + _Target = CXXOmpTarget + LINEARIZE = True + + class Cpu64FsgCOperator(Cpu64FsgOperator): _Target = CTarget +class Cpu64CXXFsgCOperator(Cpu64FsgOperator): + _Target = CXXTarget + LINEARIZE = True + + class Cpu64FsgOmpOperator(Cpu64FsgOperator): _Target = COmpTarget + + +class Cpu64CXXFsgOmpOperator(Cpu64FsgOperator): + _Target = CXXOmpTarget + LINEARIZE = True diff --git a/devito/core/intel.py b/devito/core/intel.py index 3b8f8b0208..84b840f086 100644 --- a/devito/core/intel.py +++ b/devito/core/intel.py @@ -1,11 +1,18 @@ from devito.core.cpu import (Cpu64AdvCOperator, Cpu64AdvOmpOperator, - Cpu64FsgCOperator, Cpu64FsgOmpOperator) + Cpu64FsgCOperator, Cpu64FsgOmpOperator, + Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator, + Cpu64CXXFsgCOperator, Cpu64CXXFsgOmpOperator) __all__ = ['Intel64AdvCOperator', 'Intel64AdvOmpOperator', 'Intel64FsgCOperator', - 'Intel64FsgOmpOperator'] + 'Intel64FsgOmpOperator', 'Intel64CXXAdvCOperator', 'Intel64CXXAdvOmpOperator', + 'Intel64CXXFsgCOperator', 'Intel64CXXFsgOmpOperator'] Intel64AdvCOperator = Cpu64AdvCOperator Intel64AdvOmpOperator = Cpu64AdvOmpOperator Intel64FsgCOperator = Cpu64FsgCOperator 
Intel64FsgOmpOperator = Cpu64FsgOmpOperator +Intel64CXXAdvCOperator = Cpu64AdvCXXOperator +Intel64CXXAdvOmpOperator = Cpu64CXXAdvOmpOperator +Intel64CXXFsgCOperator = Cpu64CXXFsgCOperator +Intel64CXXFsgOmpOperator = Cpu64CXXFsgOmpOperator diff --git a/devito/core/operator.py b/devito/core/operator.py index e6bfd18916..30adb06982 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -143,6 +143,11 @@ class BasicOperator(Operator): The target language constructor, to be specified by subclasses. """ + LINEARIZE = False + """ + Linearize n-dimensional Indexeds. + """ + @classmethod def _normalize_kwargs(cls, **kwargs): # Will be populated with dummy values; this method is actually overriden diff --git a/devito/core/power.py b/devito/core/power.py index ab651a1910..5868df0819 100644 --- a/devito/core/power.py +++ b/devito/core/power.py @@ -1,6 +1,10 @@ -from devito.core.cpu import Cpu64AdvCOperator, Cpu64AdvOmpOperator +from devito.core.cpu import (Cpu64AdvCOperator, Cpu64AdvOmpOperator, + Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator) -__all__ = ['PowerAdvCOperator', 'PowerAdvOmpOperator'] +__all__ = ['PowerAdvCOperator', 'PowerAdvOmpOperator', + 'PowerCXXAdvCOperator', 'PowerCXXAdvOmpOperator'] PowerAdvCOperator = Cpu64AdvCOperator PowerAdvOmpOperator = Cpu64AdvOmpOperator +PowerCXXAdvCOperator = Cpu64AdvCXXOperator +PowerCXXAdvOmpOperator = Cpu64CXXAdvOmpOperator diff --git a/devito/ir/__init__.py b/devito/ir/__init__.py index d700e6573d..78a10d0a8c 100644 --- a/devito/ir/__init__.py +++ b/devito/ir/__init__.py @@ -2,4 +2,4 @@ from devito.ir.equations import * # noqa from devito.ir.clusters import * # noqa from devito.ir.iet import * # noqa -from devito.ir.printer import * # noqa \ No newline at end of file +from devito.ir.cgen import * # noqa \ No newline at end of file diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 51ea4809d3..20fc3007c0 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -19,10 +19,7 @@ 
from devito.types.basic import AbstractFunction from devito.tools import ctypes_to_cstr, dtype_to_ctype -__all__ = ['BasePrinter', 'printer_registry', 'ccode'] - - -_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} +__all__ = ['BasePrinter', 'ccode'] class BasePrinter(CodePrinter): @@ -41,6 +38,19 @@ class BasePrinter(CodePrinter): _func_prefix = {} _func_litterals = {} + _qualifiers_mapper = { + 'is_const': 'const', + 'is_volatile': 'volatile', + '_mem_constant': 'static', + '_mem_shared': '', + } + + _prec_litterals = {np.float32: 'F', np.complex64: 'F'} + + _restrict_keyword = 'restrict' + + _default_includes = [] + @property def dtype(self): try: @@ -74,7 +84,7 @@ def _prec(self, expr): return dtype or self.dtype def prec_literal(self, expr): - return _prec_litterals.get(self._prec(expr), '') + return self._prec_litterals.get(self._prec(expr), '') def func_literal(self, expr): return self._func_litterals.get(self._prec(expr), '') @@ -209,17 +219,30 @@ def _print_Mul(self, expr): term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/") return term + def _print_fmath_func(self, name, expr): + args = ",".join([self._print(i) for i in expr.args]) + func = f'{self.func_prefix(expr, abs=True)}{name}{self.func_literal(expr)}' + return f"{self._ns}{func}({args})" + def _print_Min(self, expr): - if has_integer_args(*expr.args) and len(expr.args) == 2: + if len(expr.args) > 2: + return self._print_Min(expr.func(expr.args[0], + expr.func(*expr.args[1:]), + evaluate=False)) + elif has_integer_args(*expr.args) and len(expr.args) == 2: return "MIN(%s)" % self._print(expr.args)[1:-1] else: - return super()._print_Min(expr) + return self._print_fmath_func('min', expr) def _print_Max(self, expr): - if has_integer_args(*expr.args) and len(expr.args) == 2: + if len(expr.args) > 2: + return self._print_Max(expr.func(expr.args[0], + expr.func(*expr.args[1:]), + evaluate=False)) + elif has_integer_args(*expr.args) and len(expr.args) == 2: return "MAX(%s)" % 
self._print(expr.args)[1:-1] else: - return super()._print_Max(expr) + return self._print_fmath_func('max', expr) def _print_Abs(self, expr): """Print an absolute value. Use `abs` if can infer it is an Integer""" @@ -229,8 +252,7 @@ def _print_Abs(self, expr): if isinstance(self.compiler, AOMPCompiler) and \ not np.issubdtype(self._prec(expr), np.integer): return "fabs(%s)" % self._print(arg) - func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}' - return f"{self._ns}{func}({self._print(arg)})" + return self._print_fmath_func('abs', expr) def _print_Add(self, expr, order=None): """" @@ -377,10 +399,7 @@ def _print_Fallback(self, expr): setattr(sympy.printing.str.StrPrinter, '_print_Add', BasePrinter._print_Add) -printer_registry: dict[str, type[BasePrinter]] = {'default': BasePrinter} - - -def ccode(expr, language=None, **settings): +def ccode(expr, printer=None, **settings): """Generate C++ code from an expression. Parameters @@ -396,6 +415,7 @@ def ccode(expr, language=None, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. 
""" - lang = language or configuration['language'] - printer = printer_registry.get(lang, BasePrinter) + if printer is None: + from devito.passes.iet.languages.C import CPrinter + printer = CPrinter return printer(settings=settings).doprint(expr, None) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index d98a835a09..2e5620c644 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -11,11 +11,12 @@ from sympy import IndexedBase, sympify from devito.data import FULL +from devito.ir.cgen import ccode from devito.ir.equations import DummyEq, OpInc, OpMin, OpMax from devito.ir.support import (INBOUND, SEQUENTIAL, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, VECTORIZED, AFFINE, Property, Forward, WithLock, PrefetchUpdate, detect_io) -from devito.symbolics import ListInitializer, CallFromPointer, ccode +from devito.symbolics import ListInitializer, CallFromPointer from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten, ctypes_to_cstr) from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed, @@ -63,12 +64,6 @@ class Node(Signer): appears in this list are treated as traversable fields. """ - _ccode_handler = None - """ - Customizable by subclasses, in particular Operator subclasses which define - backend-specific nodes and, as such, require node-specific handlers. 
- """ - def __new__(cls, *args, **kwargs): obj = super().__new__(cls) argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index e50b7dd674..2f1987aca9 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -19,7 +19,6 @@ from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace) -from devito.symbolics.printer import ccode from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import (GenericVisitor, as_tuple, filter_ordered, filter_sorted, flatten, is_external_ctype, @@ -176,22 +175,23 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, language=None, **kwargs): + def __init__(self, *args, printer=None, **kwargs): super().__init__(*args, **kwargs) - self.language = language - - # The following mappers may be customized by subclasses (that is, - # backend-specific CGen-erators) - _qualifiers_mapper = { - 'is_const': 'const', - 'is_volatile': 'volatile', - '_mem_constant': 'static', - '_mem_shared': '', - } - _restrict_keyword = 'restrict' + if printer is None: + from devito.passes.iet.languages.C import CPrinter + printer = CPrinter + self.printer = printer def ccode(self, expr, **kwargs): - return ccode(expr, language=self.language, **kwargs) + return self.printer(settings=kwargs).doprint(expr, None) + + @property + def _qualifiers_mapper(self): + return self.printer._qualifiers_mapper + + @property + def _restrict_keyword(self): + return self.printer._restrict_keyword def _gen_struct_decl(self, obj, masked=()): """ @@ -228,7 +228,7 @@ def _gen_struct_decl(self, obj, masked=()): except AttributeError: cstr = self.ccode(ct) if ct is c_restrict_void_p: - cstr = '%srestrict' % cstr + cstr = f'{cstr}{self._restrict_keyword}' entries.append(c.Value(cstr, n)) return 
c.Struct(ctype.__name__, entries) @@ -406,7 +406,7 @@ def visit_PointerCast(self, o): if o.flat is None: shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) rshape = '(*)%s' % shape - lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) + lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {v}){shape}') else: rshape = '*' lvalue = c.Value(cstr, '*%s' % v) @@ -442,10 +442,10 @@ def visit_Dereference(self, o): shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) - lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) + lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}') else: rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name) - lvalue = c.Value(cstr, '*restrict %s' % a0.name) + lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}') if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) else: @@ -471,7 +471,7 @@ def visit_Break(self, o): def visit_Return(self, o): v = 'return' if o.value is not None: - v += f' {ccode(o.value)}' + v += f' {self.ccode(o.value)}' return c.Statement(v) def visit_Definition(self, o): diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 990200f3e8..2dd6eace02 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -145,7 +145,6 @@ class Operator(Callable): """ _default_headers = [('_POSIX_C_SOURCE', '200809L')] - _default_includes = ['stdlib.h', 'math.h', 'sys/time.h'] _default_globals = [] _default_namespaces = [] @@ -195,6 +194,11 @@ def _sanitize_exprs(cls, expressions, **kwargs): return expressions + @classmethod + @property + def _default_includes(cls): + return cls._Target.Printer._default_includes + @classmethod def _build(cls, expressions, **kwargs): # Python- (i.e., compile-) and C-level (i.e., run-time) performance @@ -487,7 +491,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): cls._Target.instrument(graph, 
profiler=profiler, **kwargs) # Extract the necessary macros from the symbolic objects - generate_macros(graph, lang=lang, **kwargs) + generate_macros(graph, lang=lang, printer=cls._Target.Printer, **kwargs) # Add type specific metadata lower_dtypes(graph, lang=lang, **kwargs) @@ -757,13 +761,14 @@ def _soname(self): """A unique name for the shared object resulting from JIT compilation.""" return Signer._digest(self, configuration) + @cached_property + def _printer(self): + return self._Target.Printer + @cached_property def ccode(self): - try: - return self._ccode_handler(language=self._language).visit(self) - except (AttributeError, TypeError): - from devito.ir.iet.visitors import CGen - return CGen(language=self._language).visit(self) + from devito.ir.iet.visitors import CGen + return CGen(printer=self._printer).visit(self) def _jit_compile(self): """ @@ -1404,6 +1409,12 @@ def parse_kwargs(**kwargs): else: kwargs['compiler'] = configuration['compiler'].__new_with__() + # Make sure compiler and language are compatible + if kwargs['compiler']._cpp and kwargs['language'] in ['C', 'openmp']: + kwargs['language'] = 'CXX' if kwargs['language'] == 'C' else 'CXXopenmp' + if 'CXX' in kwargs['language'] and not kwargs['compiler']._cpp: + kwargs['compiler'] = kwargs['compiler'].__new_with__(cpp=True) + # `allocator` kwargs['allocator'] = default_allocator( '%s.%s.%s' % (kwargs['compiler'].__class__.__name__, diff --git a/devito/operator/registry.py b/devito/operator/registry.py index 963cf792fe..c8aac315b7 100644 --- a/devito/operator/registry.py +++ b/devito/operator/registry.py @@ -26,7 +26,8 @@ class OperatorRegistry(OrderedDict, metaclass=Singleton): """ _modes = ('noop', 'advanced', 'advanced-fsg') - _languages = ('C', 'CXX', 'openmp', 'openacc', 'cuda', 'hip', 'sycl') + _languages = ('C', 'CXX', 'openmp', 'Copenmp', 'CXXopenmp', + 'openacc', 'cuda', 'hip', 'sycl') _accepted = _modes + tuple(product(_modes, _languages)) def add(self, operator, platform, mode, 
language='C'): diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 75aa1991ff..13044cd8b5 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,7 +1,7 @@ import numpy as np from sympy.printing.c import C99CodePrinter -from devito.ir import Call, BasePrinter, printer_registry +from devito.ir import Call, BasePrinter from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB @@ -44,6 +44,8 @@ class CPrinter(BasePrinter, C99CodePrinter): _func_litterals = {np.float32: 'f', np.complex64: 'f'} _func_prefix = {np.float32: 'f', np.float64: 'f', np.complex64: 'c', np.complex128: 'c'} + _prec_litterals = {**BasePrinter._prec_litterals, np.float16: 'F16'} + _default_includes = ['stdlib.h', 'math.h', 'sys/time.h'] # These cannot go through _print_xxx because they are classes not # instances @@ -53,8 +55,3 @@ class CPrinter(BasePrinter, C99CodePrinter): def _print_ImaginaryUnit(self, expr): return '_Complex_I' - - -printer_registry['C'] = CPrinter -printer_registry['openmp'] = CPrinter -printer_registry['Copenmp'] = CPrinter diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 71aaeba578..b031786bed 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,7 +1,7 @@ import numpy as np from sympy.printing.cxx import CXX11CodePrinter -from devito.ir import Call, UsingNamespace, BasePrinter, printer_registry +from devito.ir import Call, UsingNamespace, BasePrinter from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import c_complex, c_double_complex @@ -74,6 +74,8 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter): _ns = "std::" _func_litterals = {} _func_prefix = {np.float32: 'f', np.float64: 'f'} + _restrict_keyword = '__restrict' + _default_includes = ['stdlib.h', 'cmath', 'sys/time.h'] # These cannot go 
through _print_xxx because they are classes not # instances @@ -83,7 +85,3 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter): def _print_ImaginaryUnit(self, expr): return f'1i{self.prec_literal(expr).lower()}' - - -printer_registry['CXX'] = CXXPrinter -printer_registry['CXXopenmp'] = CXXPrinter diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 6b85ecc626..25d8e6e478 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -2,8 +2,7 @@ from devito.arch import AMDGPUX, NVIDIAX from devito.ir import (Call, DeviceCall, DummyExpr, EntryFunction, List, Block, - ParallelTree, Pragma, Return, FindSymbols, make_callable, - printer_registry) + ParallelTree, Pragma, Return, FindSymbols, make_callable) from devito.passes import needs_transfer, is_on_device from devito.passes.iet.definitions import DeviceAwareDataManager from devito.passes.iet.engine import iet_pass @@ -269,6 +268,3 @@ class AccOrchestrator(Orchestrator): class AccPrinter(CXXPrinter): pass - - -printer_registry['openacc'] = AccPrinter diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 248b61f960..d0070c8632 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -7,7 +7,8 @@ AccOrchestrator, AccPrinter) from devito.passes.iet.instrument import instrument -__all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] +__all__ = ['CTarget', 'OmpTarget', 'COmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget', + 'CXXTarget', 'CXXOmpTarget', 'DeviceCXXOmpTarget'] class Target: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index c30a7d6664..e435205ec1 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -11,9 +11,9 @@ Uxreplace, filter_iterations, retrieve_iteration_tree, pull_dims) from devito.passes.iet.engine import iet_pass +from devito.passes.iet.languages.C import CPrinter from 
devito.ir.iet.efunc import DeviceFunction, EntryFunction -from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, - ccode) +from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper) from devito.tools import Bunch, as_mapper, filter_ordered, split from devito.types import FIndexed @@ -149,7 +149,9 @@ def _generate_macros(iet, tracker=None, lang=None, **kwargs): iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs) headers = [i.header for i in tracker.values()] - headers = sorted((ccode(define), ccode(expr)) for define, expr in headers) + printer = kwargs.get('printer', CPrinter) + headers = sorted((printer()._print(define), printer()._print(expr)) + for define, expr in headers) # Generate Macros from higher-level SymPy objects mheaders, includes = _generate_macros_math(iet, lang=lang) @@ -214,7 +216,7 @@ def _lower_macro_math(expr, lang): @_lower_macro_math.register(Min) @_lower_macro_math.register(sympy.Min) def _(expr, lang): - if has_integer_args(*expr.args) and len(expr.args) == 2: + if has_integer_args(*expr.args): return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {} else: return (), (lang.get('header-algorithm'),) @@ -223,7 +225,7 @@ def _(expr, lang): @_lower_macro_math.register(Max) @_lower_macro_math.register(sympy.Max) def _(expr, lang): - if has_integer_args(*expr.args) and len(expr.args) == 2: + if has_integer_args(*expr.args): return (('MAX(a,b)', ('(((a) > (b)) ? 
(a) : (b))')),), {} else: return (), (lang.get('header-algorithm'),) diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index c3ed016a94..5d756d1d93 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -9,12 +9,12 @@ from devito.ir import (Conditional, DummyEq, Dereference, Expression, ExpressionBundle, FindSymbols, FindNodes, ParallelIteration, ParallelTree, Pragma, Prodder, Transfer, List, Transformer, - IsPerfectIteration, OpInc, filter_iterations, + IsPerfectIteration, OpInc, filter_iterations, ccode, retrieve_iteration_tree, IMask, VECTORIZED) from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import (LangBB, LangTransformer, DeviceAwareMixin, ShmTransformer, make_sections_from_imask) -from devito.symbolics import INT, ccode +from devito.symbolics import INT from devito.tools import as_tuple, flatten, is_integer, prod from devito.types import Symbol @@ -292,7 +292,7 @@ def _make_partree(self, candidates, nthreads=None): **root.args) niters = prod([root.symbolic_size] + [j.symbolic_size for j in collapsable]) - value = INT(Max(niters / (nthreads*self.chunk_nonaffine), 1)) + value = INT(Max(INT(niters / (nthreads*self.chunk_nonaffine)), 1)) prefix = [Expression(DummyEq(chunk_size, value, dtype=np.int32))] # Create a ParallelTree diff --git a/devito/symbolics/__init__.py b/devito/symbolics/__init__.py index 9d7bee01b8..3f1525297a 100644 --- a/devito/symbolics/__init__.py +++ b/devito/symbolics/__init__.py @@ -2,6 +2,5 @@ from devito.symbolics.extended_dtypes import * # noqa from devito.symbolics.queries import * # noqa from devito.symbolics.search import * # noqa -from devito.symbolics.printer import * # noqa from devito.symbolics.inspection import * # noqa from devito.symbolics.manipulation import * # noqa diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 20ec163e74..7aa02f5e02 100644 --- a/devito/symbolics/inspection.py +++ 
b/devito/symbolics/inspection.py @@ -287,7 +287,9 @@ def has_integer_args(*args): res = True for a in args: try: - if len(a.args) > 0: + if isinstance(a, INT): + res = res and True + elif len(a.args) > 0: res = res and has_integer_args(*a.args) else: res = res and has_integer_args(a) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 597271a172..5d31fd2a21 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -3,12 +3,12 @@ import sympy from devito import Constant, Eq, Function, Grid, Operator, exp, log, sin +from devito.ir.cgen.printer import BasePrinter from devito.passes.iet.langbase import LangBB -from devito.passes.iet.languages.C import CBB -from devito.passes.iet.languages.openacc import AccBB +from devito.passes.iet.languages.C import CBB, CPrinter +from devito.passes.iet.languages.openacc import AccBB, AccPrinter from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import ctypes_vector_mapper -from devito.symbolics.printer import printer_registry, BasePrinter from devito.types.basic import Basic, Scalar, Symbol from devito.types.dense import TimeFunction @@ -19,6 +19,12 @@ 'openacc': AccBB } +_printers: dict[str, type[BasePrinter]] = { + 'C': CPrinter, + 'openmp': CPrinter, + 'openacc': AccPrinter +} + def _get_language(language: str, **_) -> type[LangBB]: """ @@ -31,7 +37,7 @@ def _get_printer(language: str, **_) -> type[BasePrinter]: """ Gets the printer building block type from parametrized kwargs. 
""" - return printer_registry[language] + return _printers[language] def _config_kwargs(platform: str, language: str) -> dict[str, str]: diff --git a/tests/test_ir.py b/tests/test_ir.py index e6ea3f89e3..350de3ae20 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -5,6 +5,7 @@ from conftest import EVAL, skipif # noqa from devito import (Eq, Inc, Grid, Constant, Function, TimeFunction, # noqa Operator, Dimension, SubDimension, switchconfig) +from devito.ir.cgen import ccode from devito.ir.equations import LoweredEq from devito.ir.equations.algorithms import dimension_sort from devito.ir.iet import Iteration, FindNodes @@ -14,7 +15,7 @@ from devito.ir.support.space import (NullInterval, Interval, Forward, Backward, IterationSpace) from devito.ir.support.guards import GuardOverflow -from devito.symbolics import DefFunction, FieldFromPointer, ccode +from devito.symbolics import DefFunction, FieldFromPointer from devito.tools import prod from devito.types import Array, CriticalRegion, Jump, Scalar, Symbol diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index a4d40a72f6..26cad82a72 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -9,11 +9,11 @@ Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max) from devito.finite_differences.differentiable import SafeInv, Weights -from devito.ir import Expression, FindNodes +from devito.ir import Expression, FindNodes, ccode from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, - ReservedWord, ListInitializer, uxreplace, ccode, + ReservedWord, ListInitializer, uxreplace, retrieve_derivatives, BaseCast) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, From 6b72b32743fd8818f90afaf984aa42aebfd0be1f Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 Jan 2025 09:33:44 -0500 Subject: [PATCH 
53/58] compiler: cleanup operator class names --- devito/core/__init__.py | 36 ++++++++++++++++---------------- devito/core/arm.py | 9 ++++---- devito/core/cpu.py | 11 +++++----- devito/core/intel.py | 14 ++++++------- devito/core/power.py | 6 +++--- devito/passes/iet/languages/C.py | 1 - 6 files changed, 38 insertions(+), 39 deletions(-) diff --git a/devito/core/__init__.py b/devito/core/__init__.py index 47771ce657..beb4a08169 100644 --- a/devito/core/__init__.py +++ b/devito/core/__init__.py @@ -2,19 +2,19 @@ from devito.core.cpu import (Cpu64NoopCOperator, Cpu64NoopOmpOperator, Cpu64AdvCOperator, Cpu64AdvOmpOperator, Cpu64FsgCOperator, Cpu64FsgOmpOperator, - Cpu64CustomOperator, Cpu64CXXCustomOperator, + Cpu64CustomOperator, Cpu64CustomCXXOperator, Cpu64CXXNoopCOperator, Cpu64CXXNoopOmpOperator, - Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator, - Cpu64CXXFsgCOperator, Cpu64CXXFsgOmpOperator) + Cpu64AdvCXXOperator, Cpu64AdvCXXOmpOperator, + Cpu64FsgCXXOperator, Cpu64FsgCXXOmpOperator) from devito.core.intel import (Intel64AdvCOperator, Intel64AdvOmpOperator, Intel64FsgCOperator, Intel64FsgOmpOperator, - Intel64CXXAdvCOperator, Intel64CXXAdvOmpOperator, - Intel64CXXFsgCOperator, Intel64CXXFsgOmpOperator) + Intel64CXXAdvCOperator, Intel64AdvCXXOmpOperator, + Intel64FsgCXXOperator, Intel64FsgCXXOmpOperator) from devito.core.arm import (ArmAdvCOperator, ArmAdvOmpOperator, - ArmCXXAdvCOperator, ArmCXXAdvOmpOperator) + ArmAdvCXXOperator, ArmAdvCXXOmpOperator) from devito.core.power import (PowerAdvCOperator, PowerAdvOmpOperator, - PowerCXXAdvCOperator, PowerCXXAdvOmpOperator) + PowerCXXAdvCOperator, PowerAdvCXXOmpOperator) from devito.core.gpu import (DeviceNoopOmpOperator, DeviceNoopAccOperator, DeviceAdvOmpOperator, DeviceAdvAccOperator, DeviceFsgOmpOperator, DeviceFsgAccOperator, @@ -24,8 +24,8 @@ # Register CPU Operators operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'C') operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'openmp') 
-operator_registry.add(Cpu64CXXCustomOperator, Cpu64, 'custom', 'CXX') -operator_registry.add(Cpu64CXXCustomOperator, Cpu64, 'custom', 'CXXopenmp') +operator_registry.add(Cpu64CustomCXXOperator, Cpu64, 'custom', 'CXX') +operator_registry.add(Cpu64CustomCXXOperator, Cpu64, 'custom', 'CXXopenmp') operator_registry.add(Cpu64NoopCOperator, Cpu64, 'noop', 'C') operator_registry.add(Cpu64NoopOmpOperator, Cpu64, 'noop', 'openmp') @@ -35,32 +35,32 @@ operator_registry.add(Cpu64AdvCOperator, Cpu64, 'advanced', 'C') operator_registry.add(Cpu64AdvOmpOperator, Cpu64, 'advanced', 'openmp') operator_registry.add(Cpu64AdvCXXOperator, Cpu64, 'advanced', 'CXX') -operator_registry.add(Cpu64CXXAdvOmpOperator, Cpu64, 'advanced', 'CXXopenmp') +operator_registry.add(Cpu64AdvCXXOmpOperator, Cpu64, 'advanced', 'CXXopenmp') operator_registry.add(Cpu64FsgCOperator, Cpu64, 'advanced-fsg', 'C') operator_registry.add(Cpu64FsgOmpOperator, Cpu64, 'advanced-fsg', 'openmp') -operator_registry.add(Cpu64CXXFsgCOperator, Cpu64, 'advanced-fsg', 'CXX') -operator_registry.add(Cpu64CXXFsgOmpOperator, Cpu64, 'advanced-fsg', 'CXXopenmp') +operator_registry.add(Cpu64FsgCXXOperator, Cpu64, 'advanced-fsg', 'CXX') +operator_registry.add(Cpu64FsgCXXOmpOperator, Cpu64, 'advanced-fsg', 'CXXopenmp') operator_registry.add(Intel64AdvCOperator, Intel64, 'advanced', 'C') operator_registry.add(Intel64AdvOmpOperator, Intel64, 'advanced', 'openmp') operator_registry.add(Intel64CXXAdvCOperator, Intel64, 'advanced', 'CXX') -operator_registry.add(Intel64CXXAdvOmpOperator, Intel64, 'advanced', 'CXXopenmp') +operator_registry.add(Intel64AdvCXXOmpOperator, Intel64, 'advanced', 'CXXopenmp') operator_registry.add(Intel64FsgCOperator, Intel64, 'advanced-fsg', 'C') operator_registry.add(Intel64FsgOmpOperator, Intel64, 'advanced-fsg', 'openmp') -operator_registry.add(Intel64CXXFsgCOperator, Intel64, 'advanced-fsg', 'CXX') -operator_registry.add(Intel64CXXFsgOmpOperator, Intel64, 'advanced-fsg', 'CXXopenmp') 
+operator_registry.add(Intel64FsgCXXOperator, Intel64, 'advanced-fsg', 'CXX') +operator_registry.add(Intel64FsgCXXOmpOperator, Intel64, 'advanced-fsg', 'CXXopenmp') operator_registry.add(ArmAdvCOperator, Arm, 'advanced', 'C') operator_registry.add(ArmAdvOmpOperator, Arm, 'advanced', 'openmp') -operator_registry.add(ArmCXXAdvCOperator, Arm, 'advanced', 'CXX') -operator_registry.add(ArmCXXAdvOmpOperator, Arm, 'advanced', 'CXXopenmp') +operator_registry.add(ArmAdvCXXOperator, Arm, 'advanced', 'CXX') +operator_registry.add(ArmAdvCXXOmpOperator, Arm, 'advanced', 'CXXopenmp') operator_registry.add(PowerAdvCOperator, Power, 'advanced', 'C') operator_registry.add(PowerAdvOmpOperator, Power, 'advanced', 'openmp') operator_registry.add(PowerCXXAdvCOperator, Power, 'advanced', 'CXX') -operator_registry.add(PowerCXXAdvOmpOperator, Power, 'advanced', 'CXXopenmp') +operator_registry.add(PowerAdvCXXOmpOperator, Power, 'advanced', 'CXXopenmp') # Register Device Operators operator_registry.add(DeviceCustomOmpOperator, Device, 'custom', 'C') diff --git a/devito/core/arm.py b/devito/core/arm.py index a0581aaf01..f990ef31e0 100644 --- a/devito/core/arm.py +++ b/devito/core/arm.py @@ -2,14 +2,13 @@ Cpu64AdvCOperator) from devito.passes.iet import OmpTarget, CXXOmpTarget -__all__ = ['ArmAdvCOperator', 'ArmAdvOmpOperator', 'ArmCXXAdvCOperator', - 'ArmCXXAdvOmpOperator'] +__all__ = ['ArmAdvCOperator', 'ArmAdvOmpOperator', 'ArmAdvCXXOperator', + 'ArmAdvCXXOmpOperator'] ArmAdvOperator = Cpu64AdvOperator ArmAdvCOperator = Cpu64AdvCOperator -ArmCXXAdvOperator = Cpu64AdvCXXOperator -ArmCXXAdvCOperator = Cpu64AdvCXXOperator +ArmAdvCXXOperator = Cpu64AdvCXXOperator class ArmAdvOmpOperator(ArmAdvCOperator): @@ -19,6 +18,6 @@ class ArmAdvOmpOperator(ArmAdvCOperator): PAR_NESTED = 4 -class ArmCXXAdvOmpOperator(ArmAdvOmpOperator): +class ArmAdvCXXOmpOperator(ArmAdvOmpOperator): _Target = CXXOmpTarget LINEARIZE = True diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 716f0cd902..bd48ce9bb1 
100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -16,7 +16,8 @@ __all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator', 'Cpu64AdvOmpOperator', 'Cpu64FsgCOperator', 'Cpu64FsgOmpOperator', - 'Cpu64CustomOperator'] + 'Cpu64CustomOperator', 'Cpu64CustomCXXOperator', 'Cpu64AdvCXXOperator', + 'Cpu64AdvCXXOmpOperator', 'Cpu64FsgCXXOperator', 'Cpu64FsgCXXOmpOperator'] class Cpu64OperatorMixin: @@ -318,7 +319,7 @@ def _make_iet_passes_mapper(cls, **kwargs): assert not (set(_known_passes) & set(_known_passes_disabled)) -class Cpu64CXXCustomOperator(Cpu64CustomOperator): +class Cpu64CustomCXXOperator(Cpu64CustomOperator): _Target = CXXTarget LINEARIZE = True @@ -357,7 +358,7 @@ class Cpu64AdvOmpOperator(Cpu64AdvOperator): _Target = COmpTarget -class Cpu64CXXAdvOmpOperator(Cpu64AdvOperator): +class Cpu64AdvCXXOmpOperator(Cpu64AdvOperator): _Target = CXXOmpTarget LINEARIZE = True @@ -366,7 +367,7 @@ class Cpu64FsgCOperator(Cpu64FsgOperator): _Target = CTarget -class Cpu64CXXFsgCOperator(Cpu64FsgOperator): +class Cpu64FsgCXXOperator(Cpu64FsgOperator): _Target = CXXTarget LINEARIZE = True @@ -375,6 +376,6 @@ class Cpu64FsgOmpOperator(Cpu64FsgOperator): _Target = COmpTarget -class Cpu64CXXFsgOmpOperator(Cpu64FsgOperator): +class Cpu64FsgCXXOmpOperator(Cpu64FsgOperator): _Target = CXXOmpTarget LINEARIZE = True diff --git a/devito/core/intel.py b/devito/core/intel.py index 84b840f086..9e378ffc12 100644 --- a/devito/core/intel.py +++ b/devito/core/intel.py @@ -1,11 +1,11 @@ from devito.core.cpu import (Cpu64AdvCOperator, Cpu64AdvOmpOperator, Cpu64FsgCOperator, Cpu64FsgOmpOperator, - Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator, - Cpu64CXXFsgCOperator, Cpu64CXXFsgOmpOperator) + Cpu64AdvCXXOperator, Cpu64AdvCXXOmpOperator, + Cpu64FsgCXXOperator, Cpu64FsgCXXOmpOperator) __all__ = ['Intel64AdvCOperator', 'Intel64AdvOmpOperator', 'Intel64FsgCOperator', - 'Intel64FsgOmpOperator', 'Intel64CXXAdvCOperator', 'Intel64CXXAdvOmpOperator', - 
'Intel64CXXFsgCOperator', 'Intel64CXXFsgOmpOperator'] + 'Intel64FsgOmpOperator', 'Intel64CXXAdvCOperator', 'Intel64AdvCXXOmpOperator', + 'Intel64FsgCXXOperator', 'Intel64FsgCXXOmpOperator'] Intel64AdvCOperator = Cpu64AdvCOperator @@ -13,6 +13,6 @@ Intel64FsgCOperator = Cpu64FsgCOperator Intel64FsgOmpOperator = Cpu64FsgOmpOperator Intel64CXXAdvCOperator = Cpu64AdvCXXOperator -Intel64CXXAdvOmpOperator = Cpu64CXXAdvOmpOperator -Intel64CXXFsgCOperator = Cpu64CXXFsgCOperator -Intel64CXXFsgOmpOperator = Cpu64CXXFsgOmpOperator +Intel64AdvCXXOmpOperator = Cpu64AdvCXXOmpOperator +Intel64FsgCXXOperator = Cpu64FsgCXXOperator +Intel64FsgCXXOmpOperator = Cpu64FsgCXXOmpOperator diff --git a/devito/core/power.py b/devito/core/power.py index 5868df0819..65cf4c3cf3 100644 --- a/devito/core/power.py +++ b/devito/core/power.py @@ -1,10 +1,10 @@ from devito.core.cpu import (Cpu64AdvCOperator, Cpu64AdvOmpOperator, - Cpu64AdvCXXOperator, Cpu64CXXAdvOmpOperator) + Cpu64AdvCXXOperator, Cpu64AdvCXXOmpOperator) __all__ = ['PowerAdvCOperator', 'PowerAdvOmpOperator', - 'PowerCXXAdvCOperator', 'PowerCXXAdvOmpOperator'] + 'PowerCXXAdvCOperator', 'PowerAdvCXXOmpOperator'] PowerAdvCOperator = Cpu64AdvCOperator PowerAdvOmpOperator = Cpu64AdvOmpOperator PowerCXXAdvCOperator = Cpu64AdvCXXOperator -PowerCXXAdvOmpOperator = Cpu64CXXAdvOmpOperator +PowerAdvCXXOmpOperator = Cpu64AdvCXXOmpOperator diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 13044cd8b5..0453bdcad7 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -44,7 +44,6 @@ class CPrinter(BasePrinter, C99CodePrinter): _func_litterals = {np.float32: 'f', np.complex64: 'f'} _func_prefix = {np.float32: 'f', np.float64: 'f', np.complex64: 'c', np.complex128: 'c'} - _prec_litterals = {**BasePrinter._prec_litterals, np.float16: 'F16'} _default_includes = ['stdlib.h', 'math.h', 'sys/time.h'] # These cannot go through _print_xxx because they are classes not From 
99b46bef2621102d31f4cdda64b4e1a232741060 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 Jan 2025 09:45:13 -0500 Subject: [PATCH 54/58] compiler: switch cxx backend to static_cast --- devito/arch/compiler.py | 4 ++-- devito/ir/cgen/printer.py | 4 ++-- devito/passes/iet/languages/CXX.py | 8 ++++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 0f74107af2..01cf9c3261 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -497,7 +497,7 @@ def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) - if platform is NVIDIAX: + if platform is NvidiaDevice: self.cflags.remove(f'-std={self.std}') # Add flags for OpenMP offloading if language in ['C', 'openmp']: @@ -565,7 +565,7 @@ def __init_finalize__(self, **kwargs): if not configuration['safe-math']: self.cflags.append('-ffast-math') - if platform is NVIDIAX: + if platform is NvidiaDevice: self.cflags.remove(f'-std={self.std}') elif platform is AMDGPUX: self.cflags.remove(f'-std={self.std}') diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 20fc3007c0..328a219de3 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -344,10 +344,10 @@ def _print_InlineIf(self, expr): PREC = precedence(expr) return self.parenthesize("(%s) ? 
%s : %s" % (cond, true_expr, false_expr), PREC) - def _print_UnaryOp(self, expr, op=None): + def _print_UnaryOp(self, expr, op=None, parenthesize=False): op = op or expr._op base = self._print(expr.base) - if not expr.base.is_Symbol: + if not expr.base.is_Symbol or parenthesize: base = f'({base})' return f'{op}{base}' diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index b031786bed..9089962b49 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -85,3 +85,11 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter): def _print_ImaginaryUnit(self, expr): return f'1i{self.prec_literal(expr).lower()}' + + def _print_Cast(self, expr): + # The CXX recommended way to cast is to use static_cast + tstr = self._print(expr._C_ctype) + if 'void' in tstr: + return super()._print_Cast(expr) + cast = f'static_cast<{tstr}{self._print(expr.stars)}>' + return self._print_UnaryOp(expr, op=cast, parenthesize=True) From 36f2e16141b3ebb1621c9e2074dd7ed7bfcecf8f Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 Jan 2025 10:40:21 -0500 Subject: [PATCH 55/58] compiler: add switch for static_cast vs reinterpret_cast --- devito/arch/compiler.py | 8 +++++--- devito/operator/operator.py | 3 ++- devito/passes/iet/languages/CXX.py | 3 ++- devito/passes/iet/languages/openacc.py | 4 ++-- devito/symbolics/extended_sympy.py | 9 +++++++-- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 01cf9c3261..66b3c880d4 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -181,6 +181,8 @@ def __init__(self): fields = {'cc', 'ld'} default_cpp = False + _cxxstd = 'c++14' + _cstd = 'c99' def __init__(self, **kwargs): _name = kwargs.pop('name', self.__class__.__name__) @@ -256,7 +258,7 @@ def version(self): @property def std(self): - return 'c++14' if self._cpp else 'c99' + return self._cxxstd if self._cpp else self._cstd def get_version(self): result, 
stdout, stderr = call_capture_output((self.cc, "--version")) @@ -497,7 +499,7 @@ def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) - if platform is NvidiaDevice: + if isinstance(platform, NvidiaDevice): self.cflags.remove(f'-std={self.std}') # Add flags for OpenMP offloading if language in ['C', 'openmp']: @@ -565,7 +567,7 @@ def __init_finalize__(self, **kwargs): if not configuration['safe-math']: self.cflags.append('-ffast-math') - if platform is NvidiaDevice: + if isinstance(platform, NvidiaDevice): self.cflags.remove(f'-std={self.std}') elif platform is AMDGPUX: self.cflags.remove(f'-std={self.std}') diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 2dd6eace02..0dc74d2c3e 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1410,7 +1410,8 @@ def parse_kwargs(**kwargs): kwargs['compiler'] = configuration['compiler'].__new_with__() # Make sure compiler and language are compatible - if kwargs['compiler']._cpp and kwargs['language'] in ['C', 'openmp']: + if compiler is not None and kwargs['compiler']._cpp and \ + kwargs['language'] in ['C', 'openmp']: kwargs['language'] = 'CXX' if kwargs['language'] == 'C' else 'CXXopenmp' if 'CXX' in kwargs['language'] and not kwargs['compiler']._cpp: kwargs['compiler'] = kwargs['compiler'].__new_with__(cpp=True) diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 9089962b49..5554308a0c 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -91,5 +91,6 @@ def _print_Cast(self, expr): tstr = self._print(expr._C_ctype) if 'void' in tstr: return super()._print_Cast(expr) - cast = f'static_cast<{tstr}{self._print(expr.stars)}>' + caster = 'reinterpret_cast' if expr.reinterpret else 'static_cast' + cast = f'{caster}<{tstr}{self._print(expr.stars)}>' return self._print_UnaryOp(expr, op=cast, 
parenthesize=True) diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 25d8e6e478..5d80428eb6 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -236,11 +236,11 @@ def place_devptr(self, iet, **kwargs): dpf = List(body=[ self.lang.mapper['map-serial-present'](hp, tdp), - Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp))) + Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp, reinterpret=True))) ]) ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol) - ctdp = cast_mapper((hp.dtype, '*'))(tdp) + ctdp = cast_mapper((hp.dtype, '*'))(tdp, reinterpret=True) cast = DummyExpr(ffp, ctdp) ret = Return(ctdp) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index cb1d835c08..8281902e78 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -384,9 +384,9 @@ class Cast(UnaryOp): """ __rargs__ = ('base', ) - __rkwargs__ = ('dtype', 'stars') + __rkwargs__ = ('dtype', 'stars', 'reinterpret') - def __new__(cls, base, dtype=None, stars=None, **kwargs): + def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs): try: if issubclass(dtype, np.generic) and sympify(base).is_Number: base = sympify(dtype(base)) @@ -397,6 +397,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs): obj = super().__new__(cls, base) obj._stars = stars or '' obj._dtype = dtype + obj._reinterpret = reinterpret return obj def _hashable_content(self): @@ -412,6 +413,10 @@ def stars(self): def dtype(self): return self._dtype + @property + def reinterpret(self): + return self._reinterpret + @property def _C_ctype(self): ctype = ctypes_vector_mapper.get(self.dtype, self.dtype) From c5ece9c40eda1454568a9fd67c2697299b24320e Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 Jan 2025 09:09:00 -0500 Subject: [PATCH 56/58] compiler: handle plain text header --- devito/ir/iet/visitors.py | 10 +++++++++- 
devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/C.py | 2 +- devito/passes/iet/languages/CXX.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2f1987aca9..228a3e7f86 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -705,7 +705,15 @@ def visit_Operator(self, o, mode='all'): efuncs.extend([self._visit(i), blankline]) # Definitions - headers = [c.Define(*i) for i in o._headers] + [blankline] + headers = [] + for h in o._headers: + try: + headers.append(c.Define(*h)) + except TypeError: + # Plain string + headers.append(c.Line(h)) + headers = headers + [blankline] + # headers = [c.Define(*i) for i in o._headers] + [blankline] # Header files includes = self._operator_includes(o) + [blankline] diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 95b944c2b1..619d81fd8a 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -30,7 +30,7 @@ def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler, return iet, {} metadata = {} - lib = as_tuple(lang['header-complex']) + lib = as_tuple(lang['includes-complex']) if lang.get('complex-namespace') is not None: metadata['namespaces'] = lang['complex-namespace'] diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 0453bdcad7..3bcade46e2 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -25,7 +25,7 @@ class CBB(LangBB): 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), # Complex and float16 - 'header-complex': 'complex.h', + 'includes-complex': 'complex.h', } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 5554308a0c..52e4ba695b 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -61,7 +61,7 @@ class CXXBB(LangBB): 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), # 
Complex and float16 - 'header-complex': 'complex', + 'includes-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, } From 79212c706d6476af1489ad5b99b4b5d54bf33620 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 Jan 2025 18:07:32 -0500 Subject: [PATCH 57/58] compiler: convert all in visitors to f-string --- devito/arch/compiler.py | 8 +- devito/data/allocators.py | 7 +- devito/ir/iet/visitors.py | 128 +++++++++++++++++--------------- devito/mpi/routines.py | 6 +- devito/tools/dtypes_lowering.py | 19 +++-- tests/test_dtypes.py | 2 +- 6 files changed, 91 insertions(+), 79 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 66b3c880d4..377237ac0c 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -185,10 +185,10 @@ def __init__(self): _cstd = 'c99' def __init__(self, **kwargs): - _name = kwargs.pop('name', self.__class__.__name__) - if isinstance(_name, Compiler): - _name = _name.name - self._name = _name + name = kwargs.pop('name', self.__class__.__name__) + if isinstance(name, Compiler): + name = name.name + self._name = name super().__init__(**kwargs) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index 1e00be3f20..4ccd7cddfc 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -1,6 +1,4 @@ import abc -from functools import reduce -from operator import mul import ctypes from ctypes.util import find_library import mmap @@ -11,7 +9,7 @@ from devito.logger import logger from devito.parameters import configuration -from devito.tools import is_integer, dtype_alloc_ctype +from devito.tools import is_integer, infer_datasize __all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY', 'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD', @@ -92,8 +90,7 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - ctype, c_scale = dtype_alloc_ctype(dtype) - datasize = int(reduce(mul, shape) * c_scale) + 
ctype, datasize = infer_datasize(dtype, shape) # Add padding, if any try: diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 228a3e7f86..845d56f3c5 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -80,23 +80,24 @@ def indent(self): return ' ' * self._depth def visit_Node(self, o): - return self.indent + '<%s>' % o.__class__.__name__ + return self.indent + f'<{o.__class__.__name__}>' def visit_Generable(self, o): - body = ' %s' % str(o) if self.verbose else '' - return self.indent + '' % (o.__class__.__name__, body) + body = f" {str(o) if self.verbose else ''}" + return self.indent + f'' def visit_Callable(self, o): self._depth += 1 body = self._visit(o.children) self._depth -= 1 - return self.indent + '\n%s' % (o.name, body) + return self.indent + f'\n{body}' def visit_CallableBody(self, o): self._depth += 1 body = [self._visit(o.init), self._visit(o.unpacks), self._visit(o.body)] self._depth -= 1 - return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join([i for i in body if i])) + cbody = '\n'.join([i for i in body if i]) + return self.indent + f"{o.__repr__()}\n{cbody}" def visit_list(self, o): return ('\n').join([self._visit(i) for i in o]) @@ -111,43 +112,49 @@ def visit_List(self, o): else: body = [self._visit(o.body)] self._depth -= 1 - return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join(body)) + cbody = '\n'.join(body) + return self.indent + f"{o.__repr__()}\n{cbody}" def visit_TimedList(self, o): self._depth += 1 body = [self._visit(o.body)] self._depth -= 1 - return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join(body)) + cbody = '\n'.join(body) + return self.indent + f"{o.__repr__()}\n{cbody}" def visit_Iteration(self, o): self._depth += 1 body = self._visit(o.children) self._depth -= 1 if self.verbose: - detail = '::%s::%s' % (o.index, o.limits) + detail = f'::{o.index}::{o.limits}' props = [str(i) for i in o.properties] - props = '[%s] ' % ','.join(props) if props else '' + if props: + cprops = 
','.join(props) + props = f'[{cprops}] ' + else: + props = '' else: detail, props = '', '' - return self.indent + "<%sIteration %s%s>\n%s" % (props, o.dim.name, detail, body) + return self.indent + f"<{props}Iteration {o.dim.name}{detail}>\n{body}" def visit_While(self, o): self._depth += 1 body = self._visit(o.children) self._depth -= 1 - return self.indent + "\n%s" % (o.condition, body) + return self.indent + f"\n{body}" def visit_Expression(self, o): if self.verbose: - body = "%s = %s" % (o.expr.lhs, o.expr.rhs) - return self.indent + "" % body + body = f"{o.expr.lhs} = {o.expr.rhs}" + return self.indent + f"" else: return self.indent + str(o) def visit_AugmentedExpression(self, o): if self.verbose: - body = "%s %s= %s" % (o.expr.lhs, o.op, o.expr.rhs) - return self.indent + "<%s %s>" % (o.__class__.__name__, body) + body = f"{o.expr.lhs} {o.op}= {o.expr.rhs}" + return self.indent + f"<{o.__class__.__name__} {body}>" else: return self.indent + str(o) @@ -155,7 +162,7 @@ def visit_HaloSpot(self, o): self._depth += 1 body = self._visit(o.children) self._depth -= 1 - return self.indent + "%s\n%s" % (o.__repr__(), body) + return self.indent + f"{o.__repr__()}\n{body}" def visit_Conditional(self, o): self._depth += 1 @@ -163,10 +170,9 @@ def visit_Conditional(self, o): self._depth -= 1 if o.else_body: else_body = self._visit(o.else_body) - return self.indent + "\n%s\n\n%s" % (o.condition, - then_body, else_body) + return self.indent + f"\n{then_body}\n\n{else_body}" else: - return self.indent + "\n%s" % (o.condition, then_body) + return self.indent + f"\n{then_body}" class CGen(Visitor): @@ -249,20 +255,20 @@ def _gen_value(self, obj, mode=1, masked=()): if (obj._mem_stack or obj._mem_constant) and mode == 1: strtype = self.ccode(obj._C_typedata) - strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) + strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape) else: strtype = self.ccode(obj._C_ctype) strshape = '' if isinstance(obj, 
(AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: - strtype = '%s%s' % (strtype, self._restrict_keyword) + strtype = f'{strtype}{self._restrict_keyword}' strtype = ' '.join(qualifiers + [strtype]) if obj.is_LocalObject and obj._C_modifier is not None and mode == 2: strtype += obj._C_modifier strname = obj._C_name - strobj = '%s%s' % (strname, strshape) + strobj = f'{strname}{strshape}' if obj.is_LocalObject and obj.cargs and mode == 1: arguments = [self.ccode(i) for i in obj.cargs] @@ -386,16 +392,16 @@ def visit_PointerCast(self, o): if f.is_PointerArray: # lvalue - lvalue = c.Value(cstr, '**%s' % f.name) + lvalue = c.Value(cstr, f'**{f.name}') # rvalue if isinstance(o.obj, ArrayObject): - v = '%s->%s' % (o.obj.name, f._C_name) + v = f'{o.obj.name}->{f._C_name}' elif isinstance(o.obj, IndexedData): v = f._C_name else: assert False - rvalue = '(%s**) %s' % (cstr, v) + rvalue = f'({cstr}**) {v}' else: # lvalue @@ -404,12 +410,12 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) - rshape = '(*)%s' % shape + shape = ''.join(f"[{self.ccode(i)}]" for i in o.castshape) + rshape = f'(*){shape}' lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {v}){shape}') else: rshape = '*' - lvalue = c.Value(cstr, '*%s' % v) + lvalue = c.Value(cstr, f'*{v}') if o.alignment and f._data_alignment: lvalue = c.AlignedAttribute(f._data_alignment, lvalue) @@ -422,14 +428,14 @@ def visit_PointerCast(self, o): else: assert False - rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v) + rvalue = f'({cstr} {rshape}) {f._C_name}->{v}' else: if isinstance(o.obj, Pointer): v = o.obj.name else: v = f._C_name - rvalue = '(%s %s) %s' % (cstr, rshape, v) + rvalue = f'({cstr} {rshape}) {v}' return c.Initializer(lvalue, rvalue) @@ -439,17 +445,19 @@ def visit_Dereference(self, o): i = a1.indexed cstr = self.ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in 
a0.symbolic_shape[1:]) - rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, - a1.dim.name) + shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:]) + rvalue = f'({cstr} (*){shape}) {a1.name}[{a1.dim.name}]' lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}') else: - rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name) + rvalue = f'({cstr} *) {a1.name}[{a1.dim.name}]' lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}') if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) else: - rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name) + if a1.is_Symbol: + rvalue = f'*{a1.name}' + else: + rvalue = f'{a1.name}->{a0._C_name}' lvalue = self._gen_value(a0, 0) return c.Initializer(lvalue, rvalue) @@ -494,7 +502,7 @@ def visit_Expression(self, o): def visit_AugmentedExpression(self, o): c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype) c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype) - code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) + code = c.Statement(f"{c_lhs} {o.op}= {c_rhs}") if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) return code @@ -538,23 +546,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) - loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) - loop_inc = '%s -= %s' % (o.index, o.limits[2]) + loop_init = f'int {o.index} = {self.ccode(_max)}' + loop_cond = f'{o.index} >= {self.ccode(_min)}' + loop_inc = f'{o.index} -= {o.limits[2]}' else: - loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) - loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) - loop_inc = '%s += %s' % (o.index, o.limits[2]) + loop_init = f'int {o.index} = {self.ccode(_min)}' + loop_cond = f'{o.index} <= {self.ccode(_max)}' + loop_inc = f'{o.index} += {o.limits[2]}' # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, 
self.ccode(i.symbolic_min)) for i in o.uindices] + uinit = [f'{i.name} = {self.ccode(i.symbolic_min)}' for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) + ustep.append(f'{i.name} {op} {self.ccode(i.symbolic_incr)}') loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -577,7 +585,7 @@ def visit_While(self, o): return c.While(condition, c.Block(body)) else: # Hack: cgen doesn't support body-less while-loops, i.e. `while(...);` - return c.Statement('while(%s)' % condition) + return c.Statement(f'while({condition})') def visit_Callable(self, o): body = flatten(self._visit(i) for i in o.children) @@ -597,7 +605,7 @@ def visit_MultiTraversable(self, o): return c.Collection(body) def visit_UsingNamespace(self, o): - return c.Statement('using namespace %s' % str(o.namespace)) + return c.Statement(f'using namespace {str(o.namespace)}') def visit_Lambda(self, o): body = [] @@ -615,9 +623,11 @@ def visit_Lambda(self, o): extra.append(' '.join(str(i) for i in o.special)) if o.attributes: extra.append(' ') - extra.append(' '.join('[[%s]]' % i for i in o.attributes)) - top = c.Line('[%s](%s)%s' % - (', '.join(captures), ', '.join(decls), ''.join(extra))) + extra.append(' '.join(f'[[{i}]]' for i in o.attributes)) + ccapt = ', '.join(captures) + cdecls = ', '.join(decls) + cextra = ''.join(extra) + top = c.Line(f'[{ccapt}]({cdecls}){cextra}') return LambdaCollection([top, c.Block(body)]) def visit_HaloSpot(self, o): @@ -626,7 +636,8 @@ def visit_HaloSpot(self, o): def visit_KernelLaunch(self, o): if o.templates: - templates = '<%s>' % ','.join([str(i) for i in o.templates]) + ctemplates = ','.join([str(i) for i in o.templates]) + templates = f'<{ctemplates}>' else: templates = '' @@ -640,8 +651,7 @@ def visit_KernelLaunch(self, o): arguments = self._args_call(o.arguments) arguments = 
','.join(arguments) - return c.Statement('%s%s<<<%s>>>(%s)' - % (o.name, templates, launch_config, arguments)) + return c.Statement(f'{o.name}{templates}<<<{launch_config}>>>({arguments})') # Operator-handle machinery @@ -742,7 +752,7 @@ class CInterface(CGen): def _operator_includes(self, o): includes = super()._operator_includes(o) - includes.append(c.Include("%s.h" % o.name, system=False)) + includes.append(c.Include(f"{o.name}.h", system=False)) return includes @@ -754,7 +764,7 @@ def visit_Operator(self, o): typedecls = self._operator_typedecls(o, mode='public') guarded_typedecls = [] for i in typedecls: - guard = "DEVITO_%s" % i.tpname.upper() + guard = f"DEVITO_{i.tpname.upper()}" iflines = [c.Define(guard, ""), blankline, i, blankline] guarded_typedecl = c.IfNDef(guard, iflines, []) guarded_typedecls.extend([guarded_typedecl, blankline]) @@ -1383,13 +1393,15 @@ def __init__(self, name, arguments, is_expr=False, is_indirect=False, def generate(self): if self.templates: - tip = "%s<%s>" % (self.name, ", ".join(str(i) for i in self.templates)) + ctemplates = ", ".join(str(i) for i in self.templates) + tip = f"{self.name}<{ctemplates}>" else: tip = self.name if not self.is_indirect: - tip = "%s(" % tip + tip = f"{tip}(" else: - tip = "%s%s" % (tip, ',' if self.arguments else '') + cargs = ',' if self.arguments else '' + tip = f"{tip}{cargs}" processed = [] for i in self.arguments: if isinstance(i, (MultilineCall, LambdaCollection)): @@ -1411,7 +1423,7 @@ def generate(self): if not self.is_expr: tip += ";" if self.cast: - tip = '(%s)%s' % (self.cast, tip) + tip = f'({self.cast}){tip}' yield tip diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 67158a621c..39a4554296 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -16,7 +16,7 @@ from devito.mpi import MPI from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, IndexedPointer, Macro, cast_mapper, subs_op_args) -from devito.tools import (as_mapper, 
dtype_to_mpitype, dtype_len, dtype_alloc_ctype, +from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize, flatten, generator, is_integer, split) from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None): entry.sizes = (c_int*len(shape))(*shape) # Allocate the send/recv buffers - ctype, c_scale = dtype_alloc_ctype(f.dtype) - size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype) + ctype, datasize = infer_datasize(f.dtype, shape) + size = datasize * dtype_len(self.target.dtype) entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype) entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index e1bfe481a0..2e3bca96c9 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -3,6 +3,8 @@ """ import ctypes +from functools import reduce +from operator import mul import numpy as np from cgen import dtype_to_ctype as cgen_dtype_to_ctype @@ -11,7 +13,7 @@ __all__ = ['int2', 'int3', 'int4', 'float2', 'float3', 'float4', 'double2', # noqa 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', - 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_alloc_ctype', 'dtype_to_mpitype', + 'dtype_to_cstr', 'dtype_to_ctype', 'infer_datasize', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', 'is_external_ctype', 'infer_dtype', 'CustomDtype'] @@ -152,30 +154,31 @@ def dtype_to_ctype(dtype): return np.ctypeslib.as_ctypes_type(dtype) -def dtype_alloc_ctype(dtype): +def infer_datasize(dtype, shape): """ Translate numpy.dtype to (ctype, int): type and scale for correct C allocation size. 
""" + datasize = int(reduce(mul, shape)) if isinstance(dtype, CustomDtype): - return dtype, 1 + return dtype, datasize try: - return ctypes_vector_mapper[dtype], 1 + return ctypes_vector_mapper[dtype], datasize except KeyError: pass if issubclass(dtype, ctypes._SimpleCData): - return dtype, 1 + return dtype, datasize if dtype == np.float16: # Allocate half float as unsigned short - return ctypes.c_uint16, 1 + return ctypes.c_uint16, datasize if np.issubdtype(dtype, np.complexfloating): # For complex float, allocate twice the size of real/imaginary part - return np.ctypeslib.as_ctypes_type(dtype(0).real.__class__), 2 + return np.ctypeslib.as_ctypes_type(dtype(0).real.__class__), 2 * datasize - return np.ctypeslib.as_ctypes_type(dtype), 1 + return np.ctypeslib.as_ctypes_type(dtype), datasize def dtype_to_mpitype(dtype): diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 5d31fd2a21..fad1840383 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -131,7 +131,7 @@ def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> op = Operator(eq, **kwargs) # Check that the complex header is included <=> complex dtypes are present - header: str = _get_language(**kwargs).get('header-complex') + header: str = _get_language(**kwargs).get('includes-complex') if np.issubdtype(dtype, np.complexfloating): assert header in op._includes else: From 467a3231926aacb4112eb81b3a8fe6147c99760d Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 Jan 2025 12:08:07 -0500 Subject: [PATCH 58/58] compiler: convert printer to f-string --- devito/ir/cgen/printer.py | 82 +++++++++++++++++------------- devito/ir/iet/nodes.py | 15 +++++- devito/ir/iet/visitors.py | 3 ++ devito/operator/operator.py | 2 +- devito/passes/iet/errors.py | 4 +- devito/symbolics/extended_sympy.py | 9 ++-- docker/Dockerfile.cpu | 56 ++++++++++++-------- requirements-testing.txt | 2 +- 8 files changed, 111 insertions(+), 62 deletions(-) diff --git a/devito/ir/cgen/printer.py 
b/devito/ir/cgen/printer.py index 328a219de3..bd9b5547ef 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -98,9 +98,16 @@ def func_prefix(self, expr, abs=False): def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): - return "(%s)" % self._print(item) + return f"({self._print(item)})" return super().parenthesize(item, level, strict=strict) + def _print_PyCPointerType(self, expr): + ctype = f'{self._print_type(expr._type_)}' + if ctype.endswith('*'): + return f'{ctype}*' + else: + return f'{ctype} *' + def _print_type(self, expr): try: expr = dtype_to_ctype(expr) @@ -120,7 +127,7 @@ def _print_Function(self, expr): return super()._print_Function(expr) def _print_CondEq(self, expr): - return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs)) + return f"{self._print(expr.lhs)} == {self._print(expr.rhs)}" def _print_Indexed(self, expr): """ @@ -131,7 +138,7 @@ def _print_Indexed(self, expr): U[t,x,y,z] -> U[t][x][y][z] """ inds = ''.join(['[' + self._print(x) + ']' for x in expr.indices]) - return '%s%s' % (self._print(expr.base.label), inds) + return f'{self._print(expr.base.label)}{inds}' def _print_FIndexed(self, expr): """ @@ -146,7 +153,7 @@ def _print_FIndexed(self, expr): label = expr.accessor.label except AttributeError: label = expr.base.label - return '%s(%s)' % (self._print(label), inds) + return f'{self._print(label)}({inds})' def _print_Rational(self, expr): """Print a Rational as a C-like float/float division.""" @@ -155,10 +162,8 @@ def _print_Rational(self, expr): # to be 32-bit floats. 
# http://en.cppreference.com/w/cpp/language/floating_literal p, q = int(expr.p), int(expr.q) - if self.dtype == np.float64: - return '%d.0/%d.0' % (p, q) - else: - return '%d.0F/%d.0F' % (p, q) + prec = self.prec_literal(expr) + return f'{p}.0{prec}/{q}.0{prec}' def _print_math_func(self, expr, nest=False, known=None): cls = type(expr) @@ -208,16 +213,22 @@ def _print_SafeInv(self, expr): def _print_Mod(self, expr): """Print a Mod as a C-like %-based operation.""" - args = ['(%s)' % self._print(a) for a in expr.args] + args = [f'({self._print(a)})' for a in expr.args] return '%'.join(args) def _print_Mul(self, expr): - term = super()._print_Mul(expr) - # avoid (-1)*... - term = term.replace("(-1)*", "-") - # Avoid (-1) / ... - term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/") - return term + args = [a for a in expr.args if a != -1] + neg = (len(expr.args) - len(args)) % 2 + + if len(args) > 1: + term = super()._print_Mul(expr.func(*args, evaluate=False)) + else: + term = self.parenthesize(args[0], precedence(expr)) + + if neg: + return f'-{term}' + else: + return term def _print_fmath_func(self, name, expr): args = ",".join([self._print(i) for i in expr.args]) @@ -230,7 +241,7 @@ def _print_Min(self, expr): expr.func(*expr.args[1:]), evaluate=False)) elif has_integer_args(*expr.args) and len(expr.args) == 2: - return "MIN(%s)" % self._print(expr.args)[1:-1] + return f"MIN({self._print(expr.args)[1:-1]})" else: return self._print_fmath_func('min', expr) @@ -240,7 +251,7 @@ def _print_Max(self, expr): expr.func(*expr.args[1:]), evaluate=False)) elif has_integer_args(*expr.args) and len(expr.args) == 2: - return "MAX(%s)" % self._print(expr.args)[1:-1] + return f"MAX({self._print(expr.args)[1:-1]})" else: return self._print_fmath_func('max', expr) @@ -251,7 +262,7 @@ def _print_Abs(self, expr): # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler) and \ not np.issubdtype(self._prec(expr), np.integer): - return "fabs(%s)" % 
self._print(arg) + return f"fabs({self._print(arg)})" return self._print_fmath_func('abs', expr) def _print_Add(self, expr, order=None): @@ -265,7 +276,7 @@ def _print_Add(self, expr, order=None): for term in terms: t = self._print(term) if precedence(term) < PREC: - l.extend(["+", "(%s)" % t]) + l.extend(["+", f"({t})"]) elif t.startswith('-'): l.extend(["-", t[1:]]) else: @@ -305,44 +316,44 @@ def _print_Float(self, expr): return f'{rv}{self.prec_literal(expr)}' def _print_Differentiable(self, expr): - return "(%s)" % self._print(expr._expr) + return f"({self._print(expr._expr)})" _print_EvalDerivative = _print_Add def _print_CallFromPointer(self, expr): indices = [self._print(i) for i in expr.params] - return "%s->%s(%s)" % (expr.pointer, expr.call, ', '.join(indices)) + return f"{expr.pointer}->{expr.call}({', '.join(indices)})" def _print_CallFromComposite(self, expr): indices = [self._print(i) for i in expr.params] - return "%s.%s(%s)" % (expr.pointer, expr.call, ', '.join(indices)) + return f"{expr.pointer}.{expr.call}({', '.join(indices)})" def _print_FieldFromPointer(self, expr): - return "%s->%s" % (expr.pointer, expr.field) + return f"{expr.pointer}->{expr.field}" def _print_FieldFromComposite(self, expr): - return "%s.%s" % (expr.pointer, expr.field) + return f"{expr.pointer}.{expr.field}" def _print_ListInitializer(self, expr): - return "{%s}" % ', '.join([self._print(i) for i in expr.params]) + return f"{{{', '.join(self._print(i) for i in expr.params)}}}" def _print_IndexedPointer(self, expr): - return "%s%s" % (expr.base, ''.join('[%s]' % self._print(i) for i in expr.index)) + return f"{expr.base}{''.join(f'[{self._print(i)}]' for i in expr.index)}" def _print_IntDiv(self, expr): lhs = self._print(expr.lhs) if not expr.lhs.is_Atom: - lhs = '(%s)' % (lhs) + lhs = f"({lhs})" rhs = self._print(expr.rhs) PREC = precedence(expr) - return self.parenthesize("%s / %s" % (lhs, rhs), PREC) + return self.parenthesize(f"{lhs} / {rhs}", PREC) def 
_print_InlineIf(self, expr): cond = self._print(expr.cond) true_expr = self._print(expr.true_expr) false_expr = self._print(expr.false_expr) PREC = precedence(expr) - return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC) + return self.parenthesize(f"({cond}) ? {true_expr} : {false_expr}", PREC) def _print_UnaryOp(self, expr, op=None, parenthesize=False): op = op or expr._op @@ -356,20 +367,23 @@ def _print_Cast(self, expr): return self._print_UnaryOp(expr, op=cast) def _print_ComponentAccess(self, expr): - return "%s.%s" % (self._print(expr.base), expr.sindex) + return f"{self._print(expr.base)}.{expr.sindex}" def _print_DefFunction(self, expr): arguments = [self._print(i) for i in expr.arguments] if expr.template: - template = '<%s>' % ','.join([str(i) for i in expr.template]) + ctemplate = ','.join([str(i) for i in expr.template]) + template = f'<{ctemplate}>' else: template = '' - return "%s%s(%s)" % (expr.name, template, ','.join(arguments)) + args = ','.join(arguments) + return f"{expr.name}{template}({args})" def _print_SizeOf(self, expr): return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})' - _print_MathFunction = _print_DefFunction + def _print_MathFunction(self, expr): + return f"{self._ns}{self._print_DefFunction(expr)}" def _print_Fallback(self, expr): return expr.__str__() @@ -385,7 +399,7 @@ def _print_Fallback(self, expr): # Lifted from SymPy so that we go through our own `_print_math_func` for k in ('exp log sin cos tan ceiling floor').split(): - setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func) + setattr(BasePrinter, f'_print_{k}', BasePrinter._print_math_func) # Always parenthesize IntDiv and InlineIf within expressions diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 2e5620c644..ebf22a5aba 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -30,7 +30,7 @@ 'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration', 'ParallelBlock', 'Dereference', 
'Lambda', 'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace', - 'CallableBody', 'Transfer'] + 'Using', 'CallableBody', 'Transfer'] # First-class IET nodes @@ -1212,6 +1212,19 @@ def periodic(self): return self._periodic +class Using(Node): + + """ + A C++ using directive. + """ + + def __init__(self, name): + self.name = name + + def __repr__(self): + return "<Using(%s)>" % self.name + + class UsingNamespace(Node): """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 845d56f3c5..d0a9e98465 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -604,6 +604,9 @@ def visit_MultiTraversable(self, o): body.extend(as_tuple(v)) return c.Collection(body) + def visit_Using(self, o): + return c.Statement(f'using {str(o.name)}') + def visit_UsingNamespace(self, o): return c.Statement(f'using namespace {str(o.namespace)}') diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 0dc74d2c3e..0907d4c3be 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1125,7 +1125,7 @@ def __setstate__(self, state): self._lib.name = soname self._allocator = default_allocator( - '%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform) + '%s.%s.%s' % (type(self._compiler).__name__, self._language, self._platform) ) diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index 13f1101a3f..27331efef5 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -6,7 +6,7 @@ List, Break, Return, FindNodes, FindSymbols, Transformer, make_callable) from devito.passes.iet.engine import iet_pass -from devito.symbolics import CondEq, DefFunction +from devito.symbolics import CondEq, MathFunction from devito.tools import dtype_to_ctype from devito.types import Eq, Inc, LocalObject, Symbol @@ -58,7 +58,7 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): irs, byproduct = rcompile(eqns) name = 
sregistry.make_name(prefix='is_finite') - retval = Return(DefFunction('isfinite', accumulator)) + retval = Return(MathFunction('isfinite', accumulator)) body = irs.iet.body.body + (retval,) efunc = make_callable(name, body, retval='int') diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 8281902e78..953b52fa7d 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -776,9 +776,12 @@ def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' if not isinstance(intype, (str, ReservedWord)): ctype = dtype_to_ctype(intype) - if ctype in ctypes_vector_mapper.values(): - idx = list(ctypes_vector_mapper.values()).index(ctype) - intype = list(ctypes_vector_mapper.keys())[idx] + for k, v in ctypes_vector_mapper.items(): + if ctype is v: + intype = k + break + else: + intype = ctypes_to_cstr(ctype) newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs) newobj.stars = stars diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 3b2ae714c7..dca2ceb98a 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -4,45 +4,53 @@ # architectures using GCC compilers and OpenMPI. 
############################################################## -ARG OMPI_BRANCH="v4.1.4" +ARG ubuntu="22.04" # Base image -FROM ubuntu:22.04 as base +FROM ubuntu:${ubuntu} AS base + +ARG python="python3" +ARG gcc="" + +ENV DEBIAN_FRONTEND=noninteractive -ENV DEBIAN_FRONTEND noninteractive +# Add repo for other python versions +RUN apt-get update && \ + apt-get install -y software-properties-common && \ + add-apt-repository ppa:deadsnakes/ppa # Install python RUN apt-get update && \ - apt-get install -y dh-autoreconf python3-venv python3-dev python3-pip + apt-get install -y dh-autoreconf ${python}-venv ${python}-dev python3-pip + +# Set python3 to use the specified Python version by default +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/${python} 1 \ + && update-alternatives --config python3 # Install for basic base not containing it RUN apt-get install -y vim wget git flex libnuma-dev tmux \ numactl hwloc curl \ autoconf libtool build-essential procps software-properties-common + +# Install compilers +RUN if [ -n "$gcc" ]; then \ + apt-get install -y gcc-${gcc} g++-${gcc} && \ + # Update alternatives + update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-${gcc} 60 && \ + update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-${gcc} 60;\ + fi; + # Install tmpi RUN curl https://raw.githubusercontent.com/Azrael3000/tmpi/master/tmpi -o /usr/local/bin/tmpi # Install OpenGL library, necessary for the installation of GemPy -RUN apt-get install -y libgl1-mesa-glx +RUN apt-get install -y libgl1-mesa-glx | echo "Skipping libgl1-mesa-glx installation, outdated" +RUN apt-get install -y libgl1 libglx-mesa0 | echo "Skipping libgl1 libglx-mesa0 installation, not recent enough" RUN apt-get clean && apt-get autoclean && apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* -EXPOSE 8888 -CMD ["/bin/bash"] - -############################################################## -# GCC standard image 
-############################################################## -FROM base as gcc - -# Install gcc 13 for better hardware and software support -RUN add-apt-repository ppa:ubuntu-toolchain-r/test -y && apt update && \ - apt install gcc-13 g++-13 -y && \ - update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-13 100 && \ - update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-13 100 - ARG OMPI_BRANCH="v4.1.4" # Install OpenMPI RUN mkdir -p /deps && mkdir -p /opt/openmpi && cd /deps && \ @@ -57,7 +65,15 @@ RUN mkdir -p /deps && mkdir -p /opt/openmpi && cd /deps && \ # Set OpenMPI path ENV PATH=${PATH}:/opt/openmpi/bin -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/openmpi/lib +ENV LD_LIBRARY_PATH=/opt/openmpi/lib + +EXPOSE 8888 +CMD ["/bin/bash"] + +############################################################## +# GCC standard image +############################################################## +FROM base AS gcc # Env vars defaults ENV DEVITO_ARCH="gcc" diff --git a/requirements-testing.txt b/requirements-testing.txt index 0f88276721..e8ff2b0866 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -5,4 +5,4 @@ codecov flake8>=2.1.0 nbval scipy -pooch; python_version >= "3.8" +pooch \ No newline at end of file