diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 27ba1ebb51d..82b8c09dca6 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -63,9 +63,12 @@ def _normalize_kwargs(cls, **kwargs): # Distributed parallelism o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) - # Misc + # Code generation options for derivatives o['expand'] = oo.pop('expand', cls.EXPAND) o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE) + o['deriv-unroll'] = oo.pop('deriv-unroll', False) + + # Misc o['opt-comms'] = oo.pop('opt-comms', True) o['linearize'] = oo.pop('linearize', False) o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 8a61442feab..50168a8f0a1 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -78,9 +78,12 @@ def _normalize_kwargs(cls, **kwargs): # Distributed parallelism o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) - # Misc + # Code generation options for derivatives o['expand'] = oo.pop('expand', cls.EXPAND) o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE) + o['deriv-unroll'] = oo.pop('deriv-unroll', False) + + # Misc o['opt-comms'] = oo.pop('opt-comms', True) o['linearize'] = oo.pop('linearize', False) o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) diff --git a/devito/core/operator.py b/devito/core/operator.py index 0b2c2c62c7f..020c8dc618e 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from devito.core.autotuning import autotune -from devito.exceptions import InvalidOperator +from devito.exceptions import InvalidArgument, InvalidOperator from devito.logger import warning from devito.mpi.routines import mpi_registry from devito.parameters import configuration @@ -150,6 +150,11 @@ def _check_kwargs(cls, **kwargs): if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES: raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi']) + if oo['deriv-schedule'] not in ('basic', 'smart'): + raise InvalidArgument("Illegal `deriv-schedule` value") + if oo['deriv-unroll'] not in (False, 'inner', 'full'): + raise InvalidArgument("Illegal `deriv-unroll` value") + def _autotune(self, args, setup): if setup in [False, 'off']: return args diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index a4fdc43ee1d..bcecd6b4a92 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -30,7 +30,11 @@ def lower_index_derivatives(clusters, mode=None, **kwargs): def _lower_index_derivatives(clusters, sregistry=None, options=None, **kwargs): - callback = deriv_schedule_registry[options['deriv-schedule']] + try: + callback0 = deriv_schedule_registry[options['deriv-schedule']] + callback1 = deriv_unroll_registry[options['deriv-unroll']] + except KeyError: + raise ValueError("Unknown derivative lowering mode") weights = {} processed = [] @@ -51,8 +55,8 @@ def dump(exprs, c): else: reusable = set() - expr, v = _core(e, c, c.ispace, weights, reusable, mapper, callback, - sregistry) + expr, v = _core(e, c, c.ispace, weights, reusable, mapper, + callback0, callback1, sregistry) if v: dump(exprs, c) @@ -74,7 +78,8 @@ def dump(exprs, c): return processed, weights, mapper -def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry): +def _core(expr, c, ispace, weights, reusables, mapper, callback0, callback1, + sregistry): """ Recursively carry out the core of `lower_index_derivatives`. """ @@ -85,8 +90,8 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry): args = [] processed = [] for a in expr.args: - e, clusters = _core(a, c, ispace, weights, reusables, mapper, callback, - sregistry) + e, clusters = _core(a, c, ispace, weights, reusables, mapper, + callback0, callback1, sregistry) args.append(e) processed.extend(clusters) @@ -95,12 +100,7 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry): return expr, processed # Lower the IndexDerivative - init, ideriv = callback(expr) - - # TODO: When we'll support unrolling, probably all we have to check at this - # point is whether `ideriv` is actually an IndexDerivative. If it's not, then - # we can just return `init` as is, which is expected to contain the result of - # the unrolled IndexDerivative computation + init, ideriv = callback0(expr) # Create the concrete Weights array, or reuse an already existing one # if possible @@ -124,17 +124,21 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry): directions = {d: Backward if d.backward else Forward for d in dims} ispace0 = IterationSpace(intervals, directions=directions) - extra = (ispace.itdims + dims,) - ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) - # Minimize the amount of integer arithmetic to calculate the various index # access functions by enforcing start at 0, e.g. `r0[x + i0 + 2] -> r0[x + i0]` base = ideriv.base for d in dims: - ispace1 = ispace1.translate(d, -d._min) + ispace0 = ispace0.translate(d, -d._min) base = base.subs(d, d + d._min) ideriv = ideriv._subs(ideriv.base, base) + # Should the IndexDerivative be unrolled? + init, expr, ispace0 = callback1(init, ideriv, ispace0) + + # The full IterationSpace + extra = (ispace.itdims + ispace0.itdims,) + ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) + # The Symbol that will hold the result of the IndexDerivative computation # NOTE: created before recurring so that we ultimately get a sound ordering try: @@ -144,17 +148,20 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry): name = sregistry.make_name(prefix='r') s = Symbol(name=name, dtype=c.dtype) - expr0 = Eq(s, init) - processed = [c.rebuild(exprs=expr0, ispace=ispace)] + # Go inside `expr` and recursively lower any nested IndexDerivatives + expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, + callback0, callback1, sregistry) - # Go inside `ideriv` - expr, clusters = _core(ideriv.expr, c, ispace1, weights, reusables, mapper, - callback, sregistry) - processed.extend(clusters) + # Finally inject the lowered IndexDerivative + if init is not None: + expr0 = Eq(s, init) + processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace)) - # Finally append the lowered IndexDerivative - expr1 = Inc(s, expr) - processed.append(c.rebuild(exprs=expr1, ispace=ispace1)) + expr1 = Inc(s, expr) + processed.append(c.rebuild(exprs=expr1, ispace=ispace1)) + else: + expr1 = Eq(s, expr) + processed.append(c.rebuild(exprs=expr1, ispace=ispace1)) # Track the lowered IndexDerivative for subsequent optimization by the caller mapper.setdefault(expr, []).append(s) @@ -171,6 +178,11 @@ def _lower_index_derivative_base(ideriv): } +deriv_unroll_registry = { + False: lambda init, ideriv, ispace: (init, ideriv.expr, ispace) +} + + class CDE(Queue): """