Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Add support for more datatypes #1786

Merged
merged 27 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
002825a
Add half precision support
SF-N Oct 24, 2024
573d88f
Merge branch 'main' into half_precision
SF-N Oct 24, 2024
b4e7fe5
Extend support for further datatypes
SF-N Oct 25, 2024
db5a29a
Merge branch 'main' into half_precision
SF-N Oct 28, 2024
60f424f
Merge branch 'main' into half_precision
SF-N Nov 1, 2024
60df4e2
bfloat16 changes
SF-N Nov 5, 2024
2a5ea80
BFloat16 working
tehrengruber Nov 7, 2024
b95921c
Cleanup ml_dtypes import
SF-N Nov 8, 2024
7b39f2a
Merge branch 'main' into half_precision
SF-N Nov 8, 2024
4de666e
Merge main
SF-N Dec 30, 2024
c654ca2
Undo float16 and bfloat16 changes
SF-N Dec 30, 2024
84341b7
Minor and pre-commit
SF-N Dec 30, 2024
a4919f1
Fixed ScalarKind enum and Doctests and moved builtins from ir.py to b…
SF-N Dec 30, 2024
20fba48
Run pre-commit
SF-N Dec 30, 2024
84b51ec
Fix tests
SF-N Dec 30, 2024
52c7e6e
Use itir_to_gtfn.py:pytype_to_cpptype in cpp_interface.py:render_scal…
SF-N Dec 30, 2024
5b886ce
Renaming and update comments
SF-N Dec 30, 2024
a82bd88
Merge branch 'main' into more_datatypes
SF-N Jan 14, 2025
1f8de33
Fix type_deduction doctest
SF-N Jan 14, 2025
1eb4277
Merge branch 'main' into more_datatypes
SF-N Jan 16, 2025
23d406a
Small cleanup
tehrengruber Jan 17, 2025
2d08ed3
Small cleanup
tehrengruber Jan 17, 2025
3d41cd8
Merge branch 'main' into more_datatypes
SF-N Jan 17, 2025
cd337e5
Minor fix
SF-N Jan 17, 2025
d37fe9f
Fix test and cleanup unused builtins
SF-N Jan 20, 2025
f988ea4
Merge branch 'main' into more_datatypes
SF-N Jan 20, 2025
67d0763
Merge branch 'main' into more_datatypes
tehrengruber Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast

import numpy as np
from numpy import float32, float64, int32, int64
from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64

from gt4py._core import definitions as core_defs
from gt4py.next import common
Expand All @@ -29,12 +29,19 @@
TYPE_BUILTINS = [
common.Field,
common.Dimension,
int8,
uint8,
int16,
uint16,
int32,
uint32,
int64,
uint64,
float32,
float64,
*PYTHON_TYPE_BUILTINS,
]
] # TODO(tehrengruber): validate matches iterator.builtins.TYPE_BUILTINS?

TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS]

# Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def promote_to_mask_type(
>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None), shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None), shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
"""
if isinstance(input_type, ts.ScalarType) or not all(
item in input_type.dims for item in mask_type.dims
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.stages import AOT_PRG
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.otf import stages, workflow
from gt4py.next.type_system import type_info, type_specifications as ts
Expand Down Expand Up @@ -222,7 +222,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
),
)
)
Expand Down Expand Up @@ -359,7 +359,7 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
dim_size,
)
upper = self._visit_slice_bound(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType
... ts.ScalarType(kind=ts.ScalarKind.INT64),
... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64),
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.INT64: 64>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.INT64: 8>, shape=None))
"""

def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType:
Expand Down
130 changes: 93 additions & 37 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,46 @@ def int(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


@builtin_dispatch
def int8(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint8(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int16(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint16(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int32(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint32(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int64(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint64(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def float(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()
Expand All @@ -368,6 +398,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
"sin",
"cos",
Expand All @@ -391,52 +422,77 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"trunc",
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"}
TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"}
MATH_BUILTINS = (
UNARY_MATH_NUMBER_BUILTINS
| UNARY_MATH_FP_BUILTINS
| UNARY_MATH_FP_PREDICATE_BUILTINS
| BINARY_MATH_NUMBER_BUILTINS
| TYPEBUILTINS
)
BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod"}
BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS = {
"plus",
"minus",
"multiplies",
"divides",
"mod",
"floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136
*BINARY_MATH_NUMBER_BUILTINS,
}
BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"}
BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"}


#: builtin / dtype used to construct integer indices, like domain bounds
INTEGER_INDEX_BUILTIN = "int32"
INTEGER_BUILTINS = {
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"int64",
"uint64",
} # Todo: should we distinguish int and uint?
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FLOATING_POINT_BUILTINS = {"float32", "float64"}
TYPE_BUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"}

MATH_BUILTINS = {
*UNARY_MATH_NUMBER_BUILTINS,
*UNARY_MATH_FP_BUILTINS,
*UNARY_MATH_FP_PREDICATE_BUILTINS,
*BINARY_MATH_NUMBER_BUILTINS,
"power",
*TYPE_BUILTINS,
}

ARITHMETIC_BUILTINS = {
*UNARY_MATH_NUMBER_BUILTINS,
*UNARY_LOGICAL_BUILTINS,
*UNARY_MATH_FP_BUILTINS,
*UNARY_MATH_FP_PREDICATE_BUILTINS,
*BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS,
"power",
*BINARY_MATH_COMPARISON_BUILTINS,
*BINARY_LOGICAL_BUILTINS,
}

BUILTINS = {
"deref",
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
"can_deref",
"cartesian_domain",
"cast_",
"deref",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
"shift",
"neighbors",
"list_get",
"lift",
"make_const_list",
"make_tuple",
"map_",
"lift",
"named_range",
"neighbors",
"reduce",
"plus",
"minus",
"multiplies",
"divides",
"floordiv",
"mod",
"make_tuple",
"tuple_get",
"if_",
"cast_",
"greater",
"less",
"less_equal",
"greater_equal",
"eq",
"not_eq",
"not_",
"and_",
"or_",
"xor_",
"scan",
"cartesian_domain",
"tuple_get",
"unstructured_domain",
"named_range",
"as_fieldop",
"index",
*MATH_BUILTINS,
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}

__all__ = [*BUILTINS]
14 changes: 13 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,19 @@
TupleAxis: TypeAlias = type[None]
Axis: TypeAlias = Union[FieldAxis, TupleAxis]
Scalar: TypeAlias = (
SupportsInt | SupportsFloat | np.int32 | np.int64 | np.float32 | np.float64 | np.bool_
SupportsInt
| SupportsFloat
| np.int8
| np.uint8
| np.int16
| np.uint16
| np.int32
| np.uint32
| np.int64
| np.uint64
| np.float32
| np.float64
| np.bool_
)


Expand Down
82 changes: 1 addition & 81 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.eve.utils import noninstantiable
from gt4py.next import common
from gt4py.next.iterator.builtins import BUILTINS
from gt4py.next.type_system import type_specifications as ts


Expand Down Expand Up @@ -93,87 +94,6 @@ class FunctionDefinition(Node, SymbolTableTrait):
expr: Expr


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"sqrt",
"exp",
"log",
"gamma",
"cbrt",
"floor",
"ceil",
"trunc",
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {
"minimum",
"maximum",
"fmod",
"plus",
"minus",
"multiplies",
"divides",
"mod",
"floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136
}
BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"}
BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"}

ARITHMETIC_BUILTINS = {
*UNARY_MATH_NUMBER_BUILTINS,
*UNARY_LOGICAL_BUILTINS,
*UNARY_MATH_FP_BUILTINS,
*UNARY_MATH_FP_PREDICATE_BUILTINS,
*BINARY_MATH_NUMBER_BUILTINS,
"power",
*BINARY_MATH_COMPARISON_BUILTINS,
*BINARY_LOGICAL_BUILTINS,
}

#: builtin / dtype used to construct integer indices, like domain bounds
INTEGER_INDEX_BUILTIN = "int32"
INTEGER_BUILTINS = {"int32", "int64"}
FLOATING_POINT_BUILTINS = {"float32", "float64"}
TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"}

BUILTINS = {
"tuple_get",
"cast_",
"cartesian_domain",
"unstructured_domain",
"make_tuple",
"shift",
"neighbors",
"named_range",
"list_get",
"map_",
"make_const_list",
"lift",
"reduce",
"deref",
"can_deref",
"scan",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}


class Stmt(Node): ...


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Any, Literal, Mapping, Optional

from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding
Expand Down Expand Up @@ -127,7 +127,7 @@ def translate(
else:
# note: ugly but cheap re-computation, but should disappear
horizontal_sizes = {
k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN)
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
}

Expand All @@ -137,7 +137,7 @@ def translate(
assert new_dim not in new_ranges or old_dim == new_dim

new_range = SymbolicRange(
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
horizontal_sizes[new_dim.value],
)
new_ranges = dict(
Expand Down
Loading
Loading