Skip to content

Commit

Permalink
compiler: Support derivatives unrolling
Browse files (browse the repository at this point in the history)
  • Loading branch information
FabioLuporini committed Feb 12, 2024
1 parent f8711a2 commit 377d189
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 28 deletions.
5 changes: 4 additions & 1 deletion devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ def _normalize_kwargs(cls, **kwargs):
# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

# Misc
o['opt-comms'] = oo.pop('opt-comms', True)
o['linearize'] = oo.pop('linearize', False)
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
Expand Down
5 changes: 4 additions & 1 deletion devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ def _normalize_kwargs(cls, **kwargs):
# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

# Misc
o['opt-comms'] = oo.pop('opt-comms', True)
o['linearize'] = oo.pop('linearize', False)
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
Expand Down
7 changes: 6 additions & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable

from devito.core.autotuning import autotune
from devito.exceptions import InvalidOperator
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.logger import warning
from devito.mpi.routines import mpi_registry
from devito.parameters import configuration
Expand Down Expand Up @@ -150,6 +150,11 @@ def _check_kwargs(cls, **kwargs):
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi'])

if oo['deriv-schedule'] not in ('basic', 'smart'):
raise InvalidArgument("Illegal `deriv-schedule` value")
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
raise InvalidArgument("Illegal `deriv-unroll` value")

def _autotune(self, args, setup):
if setup in [False, 'off']:
return args
Expand Down
62 changes: 37 additions & 25 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def lower_index_derivatives(clusters, mode=None, **kwargs):


def _lower_index_derivatives(clusters, sregistry=None, options=None, **kwargs):
callback = deriv_schedule_registry[options['deriv-schedule']]
try:
callback0 = deriv_schedule_registry[options['deriv-schedule']]
callback1 = deriv_unroll_registry[options['deriv-unroll']]
except KeyError:
raise ValueError("Unknown derivative lowering mode")

weights = {}
processed = []
Expand All @@ -51,8 +55,8 @@ def dump(exprs, c):
else:
reusable = set()

expr, v = _core(e, c, c.ispace, weights, reusable, mapper, callback,
sregistry)
expr, v = _core(e, c, c.ispace, weights, reusable, mapper,
callback0, callback1, sregistry)

if v:
dump(exprs, c)
Expand All @@ -74,7 +78,8 @@ def dump(exprs, c):
return processed, weights, mapper


def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry):
def _core(expr, c, ispace, weights, reusables, mapper, callback0, callback1,
sregistry):
"""
Recursively carry out the core of `lower_index_derivatives`.
"""
Expand All @@ -85,8 +90,8 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry):
args = []
processed = []
for a in expr.args:
e, clusters = _core(a, c, ispace, weights, reusables, mapper, callback,
sregistry)
e, clusters = _core(a, c, ispace, weights, reusables, mapper,
callback0, callback1, sregistry)
args.append(e)
processed.extend(clusters)

Expand All @@ -95,12 +100,7 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry):
return expr, processed

# Lower the IndexDerivative
init, ideriv = callback(expr)

# TODO: When we'll support unrolling, probably all we have to check at this
# point is whether `ideriv` is actually an IndexDerivative. If it's not, then
# we can just return `init` as is, which is expected to contain the result of
# the unrolled IndexDerivative computation
init, ideriv = callback0(expr)

# Create the concrete Weights array, or reuse an already existing one
# if possible
Expand All @@ -124,17 +124,21 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry):
directions = {d: Backward if d.backward else Forward for d in dims}
ispace0 = IterationSpace(intervals, directions=directions)

extra = (ispace.itdims + dims,)
ispace1 = IterationSpace.union(ispace, ispace0, relations=extra)

# Minimize the amount of integer arithmetic to calculate the various index
# access functions by enforcing start at 0, e.g. `r0[x + i0 + 2] -> r0[x + i0]`
base = ideriv.base
for d in dims:
ispace1 = ispace1.translate(d, -d._min)
ispace0 = ispace0.translate(d, -d._min)
base = base.subs(d, d + d._min)
ideriv = ideriv._subs(ideriv.base, base)

# Should the IndexDerivative be unrolled?
init, expr, ispace0 = callback1(init, ideriv, ispace0)

# The full IterationSpace
extra = (ispace.itdims + ispace0.itdims,)
ispace1 = IterationSpace.union(ispace, ispace0, relations=extra)

# The Symbol that will hold the result of the IndexDerivative computation
# NOTE: created before recurring so that we ultimately get a sound ordering
try:
Expand All @@ -144,17 +148,20 @@ def _core(expr, c, ispace, weights, reusables, mapper, callback, sregistry):
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=c.dtype)

expr0 = Eq(s, init)
processed = [c.rebuild(exprs=expr0, ispace=ispace)]
# Go inside `expr` and recursively lower any nested IndexDerivatives
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper,
callback0, callback1, sregistry)

# Go inside `ideriv`
expr, clusters = _core(ideriv.expr, c, ispace1, weights, reusables, mapper,
callback, sregistry)
processed.extend(clusters)
# Finally inject the lowered IndexDerivative
if init is not None:
expr0 = Eq(s, init)
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace))

# Finally append the lowered IndexDerivative
expr1 = Inc(s, expr)
processed.append(c.rebuild(exprs=expr1, ispace=ispace1))
expr1 = Inc(s, expr)
processed.append(c.rebuild(exprs=expr1, ispace=ispace1))
else:
expr1 = Eq(s, expr)
processed.append(c.rebuild(exprs=expr1, ispace=ispace1))

# Track the lowered IndexDerivative for subsequent optimization by the caller
mapper.setdefault(expr, []).append(s)
Expand All @@ -171,6 +178,11 @@ def _lower_index_derivative_base(ideriv):
}


# Maps each value of the `deriv-unroll` option to the callback that
# `_lower_index_derivatives` looks up and that `_core` invokes as
# `callback1(init, ideriv, ispace0)`. A callback returns the 3-tuple
# `(init, expr, ispace)` that `_core` then carries on with.
# With `False`, `init` and `ispace` are passed through untouched and `expr`
# is simply `ideriv.expr`.
# NOTE(review): only `False` is registered here, whereas the `_check_kwargs`
# hunk in this same diff also admits 'inner' and 'full' -- presumably those
# values would currently end in the "Unknown derivative lowering mode" error
# raised by `_lower_index_derivatives`; confirm.
deriv_unroll_registry = {
False: lambda init, ideriv, ispace: (init, ideriv.expr, ispace)
}


class CDE(Queue):

"""
Expand Down

0 comments on commit 377d189

Please sign in to comment.