Skip to content

Commit

Permalink
feat[next]: Add support for more datatypes (#1786)
Browse files Browse the repository at this point in the history
This builds on [PR#1708](#1708)
without the `float16` and `bfloat16` changes.

Add support for `int8, uin8, int16, uint16, uint32` and `uint64`.

Move builtin definitions from `src/gt4py/next/iterator/ir.py` to
`src/gt4py/next/iterator/builtins.py`.

Use ascending integer values in `ScalarKind`-Enum and modify tests
respectively.

Set `start: int = 1` in `tests/next_tests/integration_tests/cases.py` to
not start initialization from zero as this has the same value as
zero-initialized memory and modify tests respectively.

---------

Co-authored-by: Till Ehrengruber <[email protected]>
  • Loading branch information
SF-N and tehrengruber authored Jan 21, 2025
1 parent ae603cb commit 022a73c
Show file tree
Hide file tree
Showing 38 changed files with 304 additions and 357 deletions.
11 changes: 9 additions & 2 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast

import numpy as np
from numpy import float32, float64, int32, int64
from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64

from gt4py._core import definitions as core_defs
from gt4py.next import common
Expand All @@ -29,12 +29,19 @@
TYPE_BUILTINS = [
common.Field,
common.Dimension,
int8,
uint8,
int16,
uint16,
int32,
uint32,
int64,
uint64,
float32,
float64,
*PYTHON_TYPE_BUILTINS,
]
] # TODO(tehrengruber): validate matches iterator.builtins.TYPE_BUILTINS?

TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS]

# Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def construct_tuple_type(
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ]
>>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type))
[FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None)), FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))]
[FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None)), FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))]
"""
element_types_new = true_branch_types
for i, element in enumerate(true_branch_types):
Expand Down Expand Up @@ -111,15 +111,15 @@ def promote_to_mask_type(
>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype)
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
"""
if isinstance(input_type, ts.ScalarType) or not all(
item in input_type.dims for item in mask_type.dims
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.stages import AOT_PRG
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.otf import stages, workflow
from gt4py.next.type_system import type_info, type_specifications as ts
Expand Down Expand Up @@ -218,7 +218,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType`
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
)
for dim_idx in range(len(fields_dims[0])):
size_params.append(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType
... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
... ),
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.INT64: 64>, shape=None))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.INT64: 8>, shape=None))
"""

def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType:
Expand Down
122 changes: 85 additions & 37 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,46 @@ def int(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


@builtin_dispatch
def int8(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint8(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int16(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint16(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int32(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint32(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def int64(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def uint64(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def float(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()
Expand All @@ -368,6 +398,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
"sin",
"cos",
Expand All @@ -391,52 +422,69 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"trunc",
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"}
TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"}
MATH_BUILTINS = (
UNARY_MATH_NUMBER_BUILTINS
| UNARY_MATH_FP_BUILTINS
| UNARY_MATH_FP_PREDICATE_BUILTINS
| BINARY_MATH_NUMBER_BUILTINS
| TYPEBUILTINS
)
BINARY_MATH_NUMBER_BUILTINS = {
"plus",
"minus",
"multiplies",
"divides",
"mod",
"floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136
"minimum",
"maximum",
"fmod",
}
BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"}
BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"}


#: builtin / dtype used to construct integer indices, like domain bounds
INTEGER_INDEX_BUILTIN = "int32"
INTEGER_TYPE_BUILTINS = {
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"int64",
"uint64",
}
FLOATING_POINT_TYPE_BUILTINS = {"float32", "float64"}
TYPE_BUILTINS = {*INTEGER_TYPE_BUILTINS, *FLOATING_POINT_TYPE_BUILTINS, "bool"}

ARITHMETIC_BUILTINS = {
*UNARY_MATH_NUMBER_BUILTINS,
*UNARY_LOGICAL_BUILTINS,
*UNARY_MATH_FP_BUILTINS,
*UNARY_MATH_FP_PREDICATE_BUILTINS,
*BINARY_MATH_NUMBER_BUILTINS,
"power",
*BINARY_MATH_COMPARISON_BUILTINS,
*BINARY_LOGICAL_BUILTINS,
}

BUILTINS = {
"deref",
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
"can_deref",
"cartesian_domain",
"cast_",
"deref",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
"shift",
"neighbors",
"list_get",
"lift",
"make_const_list",
"make_tuple",
"map_",
"lift",
"named_range",
"neighbors",
"reduce",
"plus",
"minus",
"multiplies",
"divides",
"floordiv",
"mod",
"make_tuple",
"tuple_get",
"if_",
"cast_",
"greater",
"less",
"less_equal",
"greater_equal",
"eq",
"not_eq",
"not_",
"and_",
"or_",
"xor_",
"scan",
"cartesian_domain",
"tuple_get",
"unstructured_domain",
"named_range",
"as_fieldop",
"index",
*MATH_BUILTINS,
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}

__all__ = [*BUILTINS]
Loading

0 comments on commit 022a73c

Please sign in to comment.