Skip to content

Commit

Permalink
compiler: move dtype pass to top level operator iet pass
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 16, 2025
1 parent 3458305 commit 9768015
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 23 deletions.
4 changes: 4 additions & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.dtypes import lower_dtypes
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
Expand Down Expand Up @@ -489,6 +490,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Extract the necessary macros from the symbolic objects
generate_macros(graph, **kwargs)

# Add type specific metadata
lower_dtypes(graph, lang=cls._Target.lang, **kwargs)

# Target-independent optimizations
minimize_symbols(graph)

Expand Down
8 changes: 0 additions & 8 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
FindNodes, FindSymbols, MapExprStmts, Transformer,
make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -465,18 +464,12 @@ def place_casts(self, iet, **kwargs):

return iet, {}

@iet_pass
def lower_dtypes(self, iet):
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
return iet, metadata

def process(self, graph):
"""
Apply the `place_definitions` and `place_casts` passes.
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.lower_dtypes(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -618,7 +611,6 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.lower_dtypes(graph)


def make_zero_init(obj, rcompile, sregistry):
Expand Down
14 changes: 7 additions & 7 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@

from devito.arch.compiler import Compiler
from devito.ir import Callable, FindSymbols, SymbolRegistry
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB

__all__ = ['lower_dtypes']


def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler,
sregistry: SymbolRegistry) -> tuple[Callable, dict]:
def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None,
sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]:
"""
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""
# Complex numbers
iet, metadata = _complex_includes(iet, lang, compiler)

return iet, metadata
_complex_includes(graph, lang=lang, compiler=compiler)


def _complex_includes(iet: Callable, lang: type[LangBB],
compiler: Compiler) -> tuple[Callable, dict]:
@iet_pass
def _complex_includes(iet: Callable, lang: type[LangBB] = None,
compiler: Compiler = None) -> tuple[Callable, dict]:
"""
Includes complex arithmetic headers for the given language, if needed.
"""
Expand Down
8 changes: 0 additions & 8 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,6 @@ def initialize(self, iet, options=None):
"""
return iet, {}

@iet_pass
def make_langtypes(self, iet):
"""
An `iet_pass` which transforms an IET such that the target language
types are introduced.
"""
return iet, {}

@property
def Region(self):
return self.lang.Region
Expand Down

0 comments on commit 9768015

Please sign in to comment.