From 002825ab7b2e9188ebf967b47d4d7670c64a306d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 24 Oct 2024 15:21:02 +0200 Subject: [PATCH 01/17] Add half precision support --- src/gt4py/_core/definitions.py | 3 +- src/gt4py/next/ffront/fbuiltins.py | 23 ++++++++- src/gt4py/next/iterator/builtins.py | 50 ++++++++++++++++++- src/gt4py/next/iterator/embedded.py | 15 +++++- src/gt4py/next/iterator/ir.py | 13 ++++- src/gt4py/next/otf/binding/cpp_interface.py | 16 +++++- .../compilation/build_systems/cmake_lists.py | 3 +- .../codegens/gtfn/codegen.py | 8 +++ .../codegens/gtfn/itir_to_gtfn_ir.py | 7 +++ .../runners/dace_common/utility.py | 19 +++---- .../runners/dace_iterator/itir_to_tasklet.py | 14 ++++++ src/gt4py/next/type_system/type_info.py | 11 +++- .../next/type_system/type_specifications.py | 17 +++++-- .../next/type_system/type_translation.py | 10 +--- tests/next_tests/definitions.py | 2 + tests/next_tests/integration_tests/cases.py | 1 + .../ffront_tests/test_execution.py | 13 +++++ 17 files changed, 191 insertions(+), 34 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..8d2acca918 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -62,6 +62,7 @@ uint32 = np.uint32 uint64 = np.uint64 +float16 = np.float16 float32 = np.float32 float64 = np.float64 @@ -94,7 +95,7 @@ INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) -FloatingScalar: TypeAlias = Union[float32, float64, float] +FloatingScalar: TypeAlias = Union[float16, float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 3b711212a3..8e5b66a626 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -14,7 +14,19 @@ from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np -from numpy import float32, float64, int32, int64 +from numpy import ( + float16, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -30,12 +42,19 @@ TYPE_BUILTINS = [ common.Field, common.Dimension, + int8, + uint8, + int16, + uint16, int32, + uint32, int64, + uint64, + float16, float32, float64, *PYTHON_TYPE_BUILTINS, -] +] # TODO(tehrengruber): validate matches itir type builtins? TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 264ac2685c..a90e6f0e08 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -332,21 +332,56 @@ def int(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def int8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def int16(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint16(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int32(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint32(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int64(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint64(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def float(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def float16(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def float32(*args): raise BackendNotSelectedError() @@ -387,7 +422,20 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"} -TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"} +TYPEBUILTINS = { + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + "float16", + "float32", + "float64", + "bool", +} # TODO(tehrengruber): This list already exists in ir.py; unify. MATH_BUILTINS = ( UNARY_MATH_NUMBER_BUILTINS | UNARY_MATH_FP_BUILTINS diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index afe0cec402..498031b180 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -85,7 +85,20 @@ TupleAxis: TypeAlias = type[None] Axis: TypeAlias = Union[FieldAxis, TupleAxis] Scalar: TypeAlias = ( - SupportsInt | SupportsFloat | np.int32 | np.int64 | np.float32 | np.float64 | np.bool_ + SupportsInt + | SupportsFloat + | np.int8 + | np.uint8 + | np.int16 + | np.uint16 + | np.int32 + | np.uint32 + | np.int64 + | np.uint64 + | np.float16 + | np.float32 + | np.float64 + | np.bool_ ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 42da4c83a6..7c0ca8751e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -161,8 +161,17 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib #: builtin / dtype used to construct integer indices, like domain bounds INTEGER_INDEX_BUILTIN = "int32" -INTEGER_BUILTINS = {"int32", "int64"} -FLOATING_POINT_BUILTINS = {"float32", "float64"} +INTEGER_BUILTINS = { + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", +} # Todo: should we distinguish int and uint? +FLOATING_POINT_BUILTINS = {"float16", "float32", "float64"} TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} BUILTINS = { diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index d112a9c256..fb21dfc93c 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -19,13 +19,27 @@ def render_scalar_type(scalar_type: ts.ScalarType) -> str: - match scalar_type.kind: + match scalar_type.kind: # TODO: merge with dict in itir_tp_gtfn case ts.ScalarKind.BOOL: return "bool" + case ts.ScalarKind.INT8: + return "std::int8_t" + case ts.ScalarKind.UINT8: + return "std::uint8_t" + case ts.ScalarKind.INT16: + return "std::int16_t" + case ts.ScalarKind.UINT16: + return "std::uint16_t" case ts.ScalarKind.INT32: return "std::int32_t" + case ts.ScalarKind.UINT32: + return "std::uint32_t" case ts.ScalarKind.INT64: return "std::int64_t" + case ts.ScalarKind.UINT64: + return "std::uint64_t" + case ts.ScalarKind.FLOAT16: + return "std::float16_t" case ts.ScalarKind.FLOAT32: return "float" case ts.ScalarKind.FLOAT64: diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 0533adac81..9701b6eb61 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -69,7 +69,8 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator): # Targets add_library({{project_name}} MODULE) - target_compile_features({{project_name}} PRIVATE cxx_std_17) + #target_compile_features({{project_name}} PRIVATE cxx_std_17) + target_compile_features({{project_name}} PRIVATE cxx_std_23) set_target_properties({{project_name}} PROPERTIES PREFIX "" SUFFIX ".{{bin_output_suffix}}") target_sources({{project_name}} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 92dbcedeaa..5f5eb1a923 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -50,10 +50,17 @@ class GTFNCodegen(codegen.TemplatedGenerator): "maximum": "std::max", "fmod": "std::fmod", "power": "std::pow", + "float16": "std::float16_t", "float32": "float", "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", "int32": "std::int32_t", + "uint32": "std::uint32_t", "int64": "std::int64_t", + "uint64": "std::uint64_t", "bool": "bool", "plus": "std::plus{}", "minus": "std::minus{}", @@ -256,6 +263,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll """ #include #include + #include #include #include #include diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..79c75a4220 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -52,10 +52,17 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: t = t.kind.name.lower() try: return { + "float16": "std::float16_t", "float32": "float", "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", "int32": "std::int32_t", + "uint32": "std::uint32_t", "int64": "std::int64_t", + "uint64": "std::uint64_t", "bool": "bool", "axis_literal": None, # TODO: domain? }[t] diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index d678fdab7f..fd5c4823ef 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -24,17 +24,14 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" - if type_.kind == ts.ScalarKind.BOOL: - return dace.bool_ - elif type_.kind == ts.ScalarKind.INT32: - return dace.int32 - elif type_.kind == ts.ScalarKind.INT64: - return dace.int64 - elif type_.kind == ts.ScalarKind.FLOAT32: - return dace.float32 - elif type_.kind == ts.ScalarKind.FLOAT64: - return dace.float64 - raise ValueError(f"Scalar type '{type_}' not supported.") + + match type_.kind: + case ts.ScalarKind.BOOL: + return dace.bool_ + case ts.ScalarKind(): + return getattr(dace, type_.kind.name.lower()) + case _: + raise ValueError(f"Scalar type '{type_}' not supported.") def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 991053b4a5..1d408f2287 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -40,11 +40,18 @@ _TYPE_MAPPING = { "float": dace.float64, + "float16": dace.float16, "float32": dace.float32, "float64": dace.float64, "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, + "int8": dace.int8, + "uint8": dace.uint8, + "int16": dace.int16, + "uint16": dace.uint16, "int32": dace.int32, + "uint32": dace.uint32, "int64": dace.int64, + "uint64": dace.uint64, "bool": dace.bool_, } @@ -109,11 +116,18 @@ def get_reduce_identity_value(op_name_: str, type_: Any): "fmod": "fmod({}, {})", "power": "math.pow({}, {})", "float": "dace.float64({})", + "float16": "dace.float16({})", "float32": "dace.float32({})", "float64": "dace.float64({})", "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int8": "dace.int8({})", + "uint8": "dace.uint8({})", + "int16": "dace.int16({})", + "uint16": "dace.uint16({})", "int32": "dace.int32({})", + "uint32": "dace.uint32({})", "int64": "dace.int64({})", + "uint64": "dace.uint64({})", "bool": "dace.bool_({})", "plus": "({} + {})", "minus": "({} - {})", diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5bda9a6f2e..9953321bf8 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -234,7 +234,11 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + return extract_dtype(symbol_type).kind in [ + ts.ScalarKind.FLOAT16, + ts.ScalarKind.FLOAT32, + ts.ScalarKind.FLOAT64, + ] def is_integer(symbol_type: ts.TypeSpec) -> bool: @@ -251,7 +255,12 @@ def is_integer(symbol_type: ts.TypeSpec) -> bool: False """ return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in { + ts.ScalarKind.INT8, + ts.ScalarKind.UINT8, + ts.ScalarKind.INT16, + ts.ScalarKind.UINT16, ts.ScalarKind.INT32, + ts.ScalarKind.UINT32, ts.ScalarKind.INT64, } diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..14beefeb7b 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -72,11 +72,18 @@ def __str__(self) -> str: class ScalarKind(IntEnum): BOOL = 1 - INT32 = 32 - INT64 = 64 - FLOAT32 = 1032 - FLOAT64 = 1064 - STRING = 3001 + INT8 = 2 + UINT8 = 3 + INT16 = 4 + UINT16 = 5 + INT32 = 6 + UINT32 = 7 + INT64 = 8 + UINT64 = 9 + FLOAT16 = 10 + FLOAT32 = 11 + FLOAT64 = 12 + STRING = 13 @dataclass(frozen=True) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..89744ad059 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -42,16 +42,10 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: match dt: case np.bool_: return ts.ScalarKind.BOOL - case np.int32: - return ts.ScalarKind.INT32 - case np.int64: - return ts.ScalarKind.INT64 - case np.float32: - return ts.ScalarKind.FLOAT32 - case np.float64: - return ts.ScalarKind.FLOAT64 case np.str_: return ts.ScalarKind.STRING + case np.dtype(): + return getattr(ts.ScalarKind, dt.name.upper()) case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 123384a098..015af8abb1 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" +USES_HALF_PRECISION = "uses_half_precision" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -145,6 +146,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_HALF_PRECISION, XFAIL, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ (ALL, SKIP, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d85cd5b3df..0009b9887a 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -61,6 +61,7 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] +IHalfField: TypeAlias = gtx.Field[[IDim], np.float16] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 36d6debf9d..8bb4ee4eee 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1201,3 +1201,16 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: cases.verify_with_default_data( cartesian_case, consume_constants, ref=lambda input: constants.PI * constants.E * input ) + + +@pytest.mark.uses_half_precision +def test_half_precision(cartesian_case): + dtype = np.float16 + + @gtx.field_operator + def multiply_by_two(input: cases.IHalfField, input2: cases.IFloatField) -> cases.IHalfField: + return dtype(2) * input * astype(input2, dtype) + + cases.verify_with_default_data( + cartesian_case, multiply_by_two, ref=lambda input, input2: dtype(2) * input * input2 + ) From b4e7fe5fe689c763d0d269fb4cd2a31cfe98d963 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Oct 2024 10:37:58 +0200 Subject: [PATCH 02/17] Extend support for further datatypes --- src/gt4py/_core/definitions.py | 3 +++ src/gt4py/next/program_processors/codegens/gtfn/codegen.py | 3 +++ src/gt4py/next/type_system/type_info.py | 7 +++++++ src/gt4py/next/type_system/type_translation.py | 4 ++++ .../feature_tests/ffront_tests/test_execution.py | 6 +++--- 5 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 8d2acca918..e05c925ef0 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -315,6 +315,9 @@ class Int64DType(SignedIntDType[int64]): class FloatingDType(DType[FloatingT]): pass +@dataclasses.dataclass(frozen=True) # TODO +class Float16DType(FloatingDType[float16]): + scalar_type: Final[Type[float16]] = dataclasses.field(default=float16, init=False) @dataclasses.dataclass(frozen=True) class Float32DType(FloatingDType[float32]): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 5f5eb1a923..f8330cb64c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -105,9 +105,12 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: return self.asfloat(node.value) + "f" case "double": return self.asfloat(node.value) + case "std::float16_t": + return "std::float16_t("+self.asfloat(node.value)+")" # TODO: this is not a proper solution case "bool": return node.value.lower() case _: + # TODO: we should probably shouldn't just allow anything here. revisit. return node.value IntegralConstant = as_fmt("{value}_c") diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 9953321bf8..c7b2bb4d7d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -328,10 +328,17 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: assert is_arithmetic(arithmetic_type) return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? + ts.ScalarKind.FLOAT16: (np.finfo(np.float16).min, np.finfo(np.float16).max), # todo: cleanup? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), + ts.ScalarKind.INT8: (np.iinfo(np.int8).min, np.iinfo(np.int8).max), + ts.ScalarKind.UINT8: (np.iinfo(np.uint8).min, np.iinfo(np.uint8).max), + ts.ScalarKind.INT16: (np.iinfo(np.int16).min, np.iinfo(np.int16).max), + ts.ScalarKind.UINT16: (np.iinfo(np.uint16).min, np.iinfo(np.uint16).max), ts.ScalarKind.INT32: (np.iinfo(np.int32).min, np.iinfo(np.int32).max), + ts.ScalarKind.UINT32: (np.iinfo(np.uint32).min, np.iinfo(np.uint32).max), ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), + ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), }[arithmetic_type.kind] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 89744ad059..d8824dd93c 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -239,6 +239,8 @@ def as_dtype(type_: ts.ScalarType) -> core_defs.DType: return core_defs.Int32DType() elif type_.kind == ts.ScalarKind.INT64: return core_defs.Int64DType() + elif type_.kind == ts.ScalarKind.FLOAT16: # TODO + return core_defs.Float16DType() elif type_.kind == ts.ScalarKind.FLOAT32: return core_defs.Float32DType() elif type_.kind == ts.ScalarKind.FLOAT64: @@ -259,6 +261,8 @@ def from_dtype(dtype: core_defs.DType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.INT32) elif dtype == core_defs.Int64DType(): return ts.ScalarType(kind=ts.ScalarKind.INT64) + elif dtype == core_defs.Float16DType(): #TODO + return ts.ScalarType(kind=ts.ScalarKind.FLOAT16) elif dtype == core_defs.Float32DType(): return ts.ScalarType(kind=ts.ScalarKind.FLOAT32) elif dtype == core_defs.Float64DType(): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8bb4ee4eee..bde4813b11 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1208,9 +1208,9 @@ def test_half_precision(cartesian_case): dtype = np.float16 @gtx.field_operator - def multiply_by_two(input: cases.IHalfField, input2: cases.IFloatField) -> cases.IHalfField: - return dtype(2) * input * astype(input2, dtype) + def multiply_by_two(input: cases.IHalfField, input2: cases.IFloatField, scalar: np.float16) -> cases.IHalfField: + return dtype(2) * input * astype(input2, dtype) * scalar ** 2.0 cases.verify_with_default_data( - cartesian_case, multiply_by_two, ref=lambda input, input2: dtype(2) * input * input2 + cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** 2.0 ) From 60df4e232958281075c15b081fcd4510fc9e5cfb Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 5 Nov 2024 17:07:47 +0100 Subject: [PATCH 03/17] bfloat16 changes --- pyproject.toml | 1 + src/gt4py/_core/definitions.py | 18 +++++++++++++-- src/gt4py/next/ffront/fbuiltins.py | 7 ++++++ src/gt4py/next/ffront/foast_to_itir.py | 6 +++++ src/gt4py/next/iterator/builtins.py | 13 ++++++++++- src/gt4py/next/iterator/embedded.py | 8 +++++++ src/gt4py/next/iterator/ir.py | 7 +++++- src/gt4py/next/otf/binding/cpp_interface.py | 2 ++ .../codegens/gtfn/codegen.py | 8 +++++++ .../codegens/gtfn/itir_to_gtfn_ir.py | 1 + .../runners/dace_iterator/itir_to_tasklet.py | 2 +- src/gt4py/next/type_system/type_info.py | 12 ++++++++-- .../next/type_system/type_specifications.py | 7 +++--- .../next/type_system/type_translation.py | 10 ++++++++ tests/next_tests/integration_tests/cases.py | 13 ++++++++++- .../ffront_tests/test_execution.py | 23 +++++++++++++++---- 16 files changed, 123 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3c3efab625..9e84ba8dcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ jax-cuda12 = ['jax[cuda12_pip]>=0.4.18; python_version>="3.10"'] performance = ['scipy>=1.9.2'] rocm-43 = ['cupy-rocm-4-3'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] +half-precision = ['ml_dtypes'] [project.scripts] gtpyc = 'gt4py.cartesian.cli:gtpyc' diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index e05c925ef0..97a2e9c287 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -14,6 +14,10 @@ import functools import math import numbers +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None import numpy as np import numpy.typing as npt @@ -63,6 +67,8 @@ uint64 = np.uint64 float16 = np.float16 +if bfloat16: + bfloat16 = bfloat16 float32 = np.float32 float64 = np.float64 @@ -94,8 +100,10 @@ IntegralT = TypeVar("IntegralT", bound=IntegralScalar) INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) - -FloatingScalar: TypeAlias = Union[float16, float32, float64, float] +if bfloat16: + FloatingScalar: TypeAlias = Union[float16, bfloat16, float32, float64, float] +else: + FloatingScalar: TypeAlias = Union[float16, float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -319,6 +327,12 @@ class FloatingDType(DType[FloatingT]): class Float16DType(FloatingDType[float16]): scalar_type: Final[Type[float16]] = dataclasses.field(default=float16, init=False) +if bfloat16: + @dataclasses.dataclass(frozen=True) # TODO + class BFloat16DType(FloatingDType[bfloat16]): + scalar_type: Final[Type[bfloat16]] = dataclasses.field(default=bfloat16, init=False) + + @dataclasses.dataclass(frozen=True) class Float32DType(FloatingDType[float32]): scalar_type: Final[Type[float32]] = dataclasses.field(default=float32, init=False) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8e5b66a626..68cfab88cd 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -27,6 +27,10 @@ uint32, uint64, ) +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -55,6 +59,9 @@ float64, *PYTHON_TYPE_BUILTINS, ] # TODO(tehrengruber): validate matches itir type builtins? +if bfloat16: + TYPE_BUILTINS.append(bfloat16) + TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 7936eda1cf..dc49f3babe 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -31,6 +31,10 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None def foast_to_itir(inp: FOP) -> itir.Expr: @@ -461,6 +465,8 @@ def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: raise FieldOperatorLoweringError( f"Type cast only supports literal arguments, {node.type} not supported." ) + # if bfloat16 and node_kind == 'bfloat16': + # val = float(val) val = target_type(val) return im.promote_to_const_iterator(im.literal(str(val), node_kind)) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index a90e6f0e08..c5b62b667e 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -7,7 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.dispatcher import Dispatcher - +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None builtin_dispatch = Dispatcher() @@ -377,6 +380,11 @@ def float(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +if bfloat16: + @builtin_dispatch + def bfloat16(*args): + raise BackendNotSelectedError() + @builtin_dispatch def float16(*args): raise BackendNotSelectedError() @@ -436,6 +444,9 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "float64", "bool", } # TODO(tehrengruber): This list already exists in ir.py; unify. +if bfloat16: + TYPEBUILTINS.add("bfloat16") + MATH_BUILTINS = ( UNARY_MATH_NUMBER_BUILTINS | UNARY_MATH_FP_BUILTINS diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 498031b180..4b36a968f2 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -20,6 +20,10 @@ import numpy as np import numpy.typing as npt +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None from gt4py import eve from gt4py._core import definitions as core_defs @@ -100,6 +104,8 @@ | np.float64 | np.bool_ ) +if bfloat16: + Scalar = Scalar | bfloat16 class SparseTag(Tag): ... @@ -552,6 +558,8 @@ def promote_scalars(val: CompositeOfScalarOrField): impl: Callable if math_builtin_name == "gamma": continue # treated explicitly + elif math_builtin_name == "bfloat16" and bfloat16: + impl = bfloat16 # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent # with compiled backends. Currently using Python types to preserve diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 7c0ca8751e..0793a88fe9 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -15,7 +15,10 @@ from gt4py.eve.utils import noninstantiable from gt4py.next import common from gt4py.next.type_system import type_specifications as ts - +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None DimensionKind = common.DimensionKind @@ -172,6 +175,8 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "uint64", } # Todo: should we distinguish int and uint? FLOATING_POINT_BUILTINS = {"float16", "float32", "float64"} +if bfloat16: + FLOATING_POINT_BUILTINS.add("bfloat16") TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} BUILTINS = { diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index fb21dfc93c..2fbd49839e 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -40,6 +40,8 @@ def render_scalar_type(scalar_type: ts.ScalarType) -> str: return "std::uint64_t" case ts.ScalarKind.FLOAT16: return "std::float16_t" + case ts.ScalarKind.BFLOAT16: + return "std::bfloat16_t" case ts.ScalarKind.FLOAT32: return "float" case ts.ScalarKind.FLOAT64: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index f8330cb64c..43e0617f68 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -7,6 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from typing import Any, Collection, Final, Union +try: + import ml_dtypes +except ModuleNotFoundError: + bfloat16 = None from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako @@ -78,6 +82,8 @@ class GTFNCodegen(codegen.TemplatedGenerator): "mod": "std::modulus{}", "not_": "std::logical_not{}", } + if ml_dtypes: + _builtins_mapping["bfloat16"] = "std::bfloat16_t" Sym = as_fmt("{id}") @@ -107,6 +113,8 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: return self.asfloat(node.value) case "std::float16_t": return "std::float16_t("+self.asfloat(node.value)+")" # TODO: this is not a proper solution + case "std::bfloat16_t": + return "std::bfloat16_t("+self.asfloat(node.value)+")" # TODO: this is not a proper solution case "bool": return node.value.lower() case _: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 79c75a4220..70d291c394 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -53,6 +53,7 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: try: return { "float16": "std::float16_t", + "bfloat16": "std::bfloat16_t", "float32": "float", "float64": "double", "int8": "std::int8_t", diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 1d408f2287..bc731ce7b8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -40,7 +40,7 @@ _TYPE_MAPPING = { "float": dace.float64, - "float16": dace.float16, + "float16": dace.float16, # TODO: bfloat16? "float32": dace.float32, "float64": dace.float64, "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index c7b2bb4d7d..f36e939fbd 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -23,6 +23,10 @@ ) import numpy as np +try: + from ml_dtypes import bfloat16, finfo +except ModuleNotFoundError: + bfloat16 = None from gt4py.eve.utils import XIterable, xiter from gt4py.next import common @@ -236,6 +240,7 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: """ return extract_dtype(symbol_type).kind in [ ts.ScalarKind.FLOAT16, + ts.ScalarKind.BFLOAT16, ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64, ] @@ -327,7 +332,7 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: assert is_arithmetic(arithmetic_type) - return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? + bounds = { # type: ignore[return-value] # why resolved to `tuple[object, object]`? ts.ScalarKind.FLOAT16: (np.finfo(np.float16).min, np.finfo(np.float16).max), # todo: cleanup? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), @@ -339,7 +344,10 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num ts.ScalarKind.UINT32: (np.iinfo(np.uint32).min, np.iinfo(np.uint32).max), ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), - }[arithmetic_type.kind] + } + if bfloat16: + bounds[ts.ScalarKind.BFLOAT16] = (finfo(bfloat16).min, finfo(bfloat16).max) + return bounds[arithmetic_type.kind] def is_type_or_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> bool: diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 14beefeb7b..b1f5bfceea 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -81,9 +81,10 @@ class ScalarKind(IntEnum): INT64 = 8 UINT64 = 9 FLOAT16 = 10 - FLOAT32 = 11 - FLOAT64 = 12 - STRING = 13 + BFLOAT16 = 11 + FLOAT32 = 12 + FLOAT64 = 13 + STRING = 14 @dataclass(frozen=True) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index d8824dd93c..ce51c6633c 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -18,6 +18,10 @@ import numpy as np import numpy.typing as npt +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping @@ -46,6 +50,8 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: return ts.ScalarKind.STRING case np.dtype(): return getattr(ts.ScalarKind, dt.name.upper()) + case _ if bfloat16 and dt == bfloat16: + return ts.ScalarKind.BFLOAT16 case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: @@ -241,6 +247,8 @@ def as_dtype(type_: ts.ScalarType) -> core_defs.DType: return core_defs.Int64DType() elif type_.kind == ts.ScalarKind.FLOAT16: # TODO return core_defs.Float16DType() + elif type_.kind == ts.ScalarKind.BFLOAT16: # TODO + return core_defs.BFloat16DType() elif type_.kind == ts.ScalarKind.FLOAT32: return core_defs.Float32DType() elif type_.kind == ts.ScalarKind.FLOAT64: @@ -263,6 +271,8 @@ def from_dtype(dtype: core_defs.DType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.INT64) elif dtype == core_defs.Float16DType(): #TODO return ts.ScalarType(kind=ts.ScalarKind.FLOAT16) + elif dtype == core_defs.BFloat16DType(): #TODO + return ts.ScalarType(kind=ts.ScalarKind.BFLOAT16) elif dtype == core_defs.Float32DType(): return ts.ScalarType(kind=ts.ScalarKind.FLOAT32) elif dtype == core_defs.Float64DType(): diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 0009b9887a..a6fd6c4d12 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -17,6 +17,10 @@ import numpy as np import pytest +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -62,6 +66,8 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IHalfField: TypeAlias = gtx.Field[[IDim], np.float16] # type: ignore [valid-type] +if bfloat16: + IBFloatField: TypeAlias = gtx.Field[[IDim], bfloat16] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] @@ -198,7 +204,7 @@ class UniqueInitializer(DataInitializer): def scalar_value(self) -> ScalarValue: start = self.start self.start += 1 - return np.int64(start) + return start def field( self, @@ -423,6 +429,11 @@ def verify( assert out_comp is not None out_comp_ndarray = field_utils.asnumpy(out_comp) ref_ndarray = field_utils.asnumpy(ref) + + if bfloat16 and out_comp_ndarray.dtype == bfloat16: + out_comp_ndarray = out_comp_ndarray.astype(np.float32) + ref_ndarray = ref_ndarray.astype(np.float32) + assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index ec6a4fd4eb..043df446f1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -10,6 +10,10 @@ import numpy as np import pytest +try: + from ml_dtypes import bfloat16 +except ModuleNotFoundError: + bfloat16 = None import gt4py.next as gtx from gt4py.next import ( @@ -1205,13 +1209,24 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: @pytest.mark.uses_half_precision -def test_half_precision(cartesian_case): +def test_flaot16(cartesian_case): dtype = np.float16 @gtx.field_operator def multiply_by_two(input: cases.IHalfField, input2: cases.IFloatField, scalar: np.float16) -> cases.IHalfField: - return dtype(2) * input * astype(input2, dtype) * scalar ** 2.0 - + return dtype(2) * input * astype(input2, dtype) * scalar ** dtype(1.0) #TODO fails with 0.5 cases.verify_with_default_data( - cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** 2.0 + cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** dtype(1.0) ) + + +@pytest.mark.uses_half_precision +def test_bfloat16(cartesian_case): + dtype = bfloat16 + + @gtx.field_operator + def multiply_by_two_(input: cases.IBFloatField) -> cases.IBFloatField: + return input #* astype(input2, dtype) * scalar ** dtype(1.0) #TODO + cases.verify_with_default_data( + cartesian_case, multiply_by_two_, ref=lambda input: input #* input2 * scalar ** dtype(1.0) #TODO + ) \ No newline at end of file From 2a5ea800cb5c5b2fbd31e181f348d20c9b0fb2d1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 7 Nov 2024 15:11:27 +0100 Subject: [PATCH 04/17] BFloat16 working --- src/gt4py/next/ffront/foast_to_itir.py | 5 +++-- src/gt4py/next/program_processors/runners/gtfn.py | 8 ++++++++ tests/next_tests/integration_tests/cases.py | 2 +- .../feature_tests/ffront_tests/test_execution.py | 8 ++++---- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index dc49f3babe..c163762dd7 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -465,8 +465,9 @@ def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: raise FieldOperatorLoweringError( f"Type cast only supports literal arguments, {node.type} not supported." ) - # if bfloat16 and node_kind == 'bfloat16': - # val = float(val) + # TODO: why? + if bfloat16 and node_kind == 'bfloat16': + val = float(val) val = target_type(val) return im.promote_to_const_iterator(im.literal(str(val), node_kind)) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 2275576081..1dcbba85c2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -32,6 +32,14 @@ def convert_arg(arg: Any) -> Any: if isinstance(arg, common.Field): arr = arg.ndarray origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) + + # TODO: bloody hack just to get a dlpack compatible array of bfloat16s + from ml_dtypes import bfloat16 + import jax + import jax.numpy as jnp + if arr.dtype == bfloat16: + arr = jnp.asarray(arr, dtype=jnp.bfloat16) + return arr, origin else: return arg diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index a6fd6c4d12..61438b6192 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -198,7 +198,7 @@ class UniqueInitializer(DataInitializer): data containers. """ - start: int = 0 + start: int = 1 # PR comment: do not start from zero as this has the same value as zero-initialized memory @property def scalar_value(self) -> ScalarValue: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 043df446f1..11f78aed4b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1209,7 +1209,7 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: @pytest.mark.uses_half_precision -def test_flaot16(cartesian_case): +def test_float16(cartesian_case): dtype = np.float16 @gtx.field_operator @@ -1225,8 +1225,8 @@ def test_bfloat16(cartesian_case): dtype = bfloat16 @gtx.field_operator - def multiply_by_two_(input: cases.IBFloatField) -> cases.IBFloatField: - return input #* astype(input2, dtype) * scalar ** dtype(1.0) #TODO + def multiply_by_two(input: cases.IBFloatField, input2: cases.IBFloatField, scalar: bfloat16) -> cases.IBFloatField: + return dtype(2) * input * astype(input2, dtype) * scalar ** dtype(1.0) #TODO fails with 0.5 cases.verify_with_default_data( - cartesian_case, multiply_by_two_, ref=lambda input: input #* input2 * scalar ** dtype(1.0) #TODO + cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** dtype(1.0) ) \ No newline at end of file From b95921c9ecdef18e1a127eff0bcd9eace3477f2e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Nov 2024 14:09:32 +0100 Subject: [PATCH 05/17] Cleanup ml_dtypes import --- pyproject.toml | 2 +- src/gt4py/_core/definitions.py | 18 +++++++++--------- src/gt4py/next/ffront/fbuiltins.py | 7 ++++--- src/gt4py/next/ffront/foast_to_itir.py | 6 +++--- src/gt4py/next/iterator/builtins.py | 8 ++++---- src/gt4py/next/iterator/embedded.py | 12 ++++++------ src/gt4py/next/iterator/ir.py | 6 +++--- .../codegens/gtfn/codegen.py | 2 +- .../next/program_processors/runners/gtfn.py | 13 ++++++++----- src/gt4py/next/type_system/type_info.py | 8 ++++---- src/gt4py/next/type_system/type_translation.py | 6 +++--- tests/next_tests/integration_tests/cases.py | 10 +++++----- 12 files changed, 51 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9e84ba8dcb..3cd2f8cc3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ jax-cuda12 = ['jax[cuda12_pip]>=0.4.18; python_version>="3.10"'] performance = ['scipy>=1.9.2'] rocm-43 = ['cupy-rocm-4-3'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] -half-precision = ['ml_dtypes'] +half-precision = ['ml_dtypes', 'jax'] [project.scripts] gtpyc = 'gt4py.cartesian.cli:gtpyc' diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 97a2e9c287..a68b850499 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -15,9 +15,9 @@ import math import numbers try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None import numpy as np import numpy.typing as npt @@ -67,8 +67,8 @@ uint64 = np.uint64 float16 = np.float16 -if bfloat16: - bfloat16 = bfloat16 +if ml_dtypes: + bfloat16 = ml_dtypes.bfloat16 float32 = np.float32 float64 = np.float64 @@ -100,8 +100,8 @@ IntegralT = TypeVar("IntegralT", bound=IntegralScalar) INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) -if bfloat16: - FloatingScalar: TypeAlias = Union[float16, bfloat16, float32, float64, float] +if ml_dtypes: + FloatingScalar: TypeAlias = Union[float16, ml_dtypes.bfloat16, float32, float64, float] else: FloatingScalar: TypeAlias = Union[float16, float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) @@ -327,10 +327,10 @@ class FloatingDType(DType[FloatingT]): class Float16DType(FloatingDType[float16]): scalar_type: Final[Type[float16]] = dataclasses.field(default=float16, init=False) -if bfloat16: +if ml_dtypes: @dataclasses.dataclass(frozen=True) # TODO - class BFloat16DType(FloatingDType[bfloat16]): - scalar_type: Final[Type[bfloat16]] = dataclasses.field(default=bfloat16, init=False) + class BFloat16DType(FloatingDType[ml_dtypes.bfloat16]): + scalar_type: Final[Type[ml_dtypes.bfloat16]] = dataclasses.field(default=ml_dtypes.bfloat16, init=False) @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 68cfab88cd..84b5804ad3 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -28,9 +28,9 @@ uint64, ) try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -59,7 +59,8 @@ float64, *PYTHON_TYPE_BUILTINS, ] # TODO(tehrengruber): validate matches itir type builtins? -if bfloat16: +if ml_dtypes: + from ml_dtypes import bfloat16 TYPE_BUILTINS.append(bfloat16) TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c163762dd7..5751fb2ea2 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -32,9 +32,9 @@ from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None def foast_to_itir(inp: FOP) -> itir.Expr: @@ -466,7 +466,7 @@ def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: f"Type cast only supports literal arguments, {node.type} not supported." ) # TODO: why? - if bfloat16 and node_kind == 'bfloat16': + if ml_dtypes and node_kind == 'bfloat16': val = float(val) val = target_type(val) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index c5b62b667e..941f99fd6c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -8,9 +8,9 @@ from gt4py.next.iterator.dispatcher import Dispatcher try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None builtin_dispatch = Dispatcher() @@ -380,7 +380,7 @@ def float(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() -if bfloat16: +if ml_dtypes: @builtin_dispatch def bfloat16(*args): raise BackendNotSelectedError() @@ -444,7 +444,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "float64", "bool", } # TODO(tehrengruber): This list already exists in ir.py; unify. -if bfloat16: +if ml_dtypes: TYPEBUILTINS.add("bfloat16") MATH_BUILTINS = ( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 4b36a968f2..28311ffbea 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -21,9 +21,9 @@ import numpy as np import numpy.typing as npt try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None from gt4py import eve from gt4py._core import definitions as core_defs @@ -104,8 +104,8 @@ | np.float64 | np.bool_ ) -if bfloat16: - Scalar = Scalar | bfloat16 +if ml_dtypes: + Scalar = Scalar | ml_dtypes.bfloat16 class SparseTag(Tag): ... @@ -558,8 +558,8 @@ def promote_scalars(val: CompositeOfScalarOrField): impl: Callable if math_builtin_name == "gamma": continue # treated explicitly - elif math_builtin_name == "bfloat16" and bfloat16: - impl = bfloat16 # treated explicitly + elif math_builtin_name == "bfloat16" and ml_dtypes: + impl = ml_dtypes.bfloat16 # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent # with compiled backends. Currently using Python types to preserve diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 0793a88fe9..c110057a18 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -16,9 +16,9 @@ from gt4py.next import common from gt4py.next.type_system import type_specifications as ts try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None DimensionKind = common.DimensionKind @@ -175,7 +175,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "uint64", } # Todo: should we distinguish int and uint? FLOATING_POINT_BUILTINS = {"float16", "float32", "float64"} -if bfloat16: +if ml_dtypes.bfloat16: FLOATING_POINT_BUILTINS.add("bfloat16") TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 43e0617f68..4601970966 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -10,7 +10,7 @@ try: import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1dcbba85c2..213fa5a9c5 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -34,11 +34,14 @@ def convert_arg(arg: Any) -> Any: origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) # TODO: bloody hack just to get a dlpack compatible array of bfloat16s - from ml_dtypes import bfloat16 - import jax - import jax.numpy as jnp - if arr.dtype == bfloat16: - arr = jnp.asarray(arr, dtype=jnp.bfloat16) + try: + from ml_dtypes import bfloat16 + import jax.numpy as jnp + if arr.dtype == bfloat16: + arr = jnp.asarray(arr, dtype=jnp.bfloat16) + except: + raise ValueError("ml_dtypes and jax must to be installed.") + return arr, origin else: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index f36e939fbd..52e7a169e4 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -24,9 +24,9 @@ import numpy as np try: - from ml_dtypes import bfloat16, finfo + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None from gt4py.eve.utils import XIterable, xiter from gt4py.next import common @@ -345,8 +345,8 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), } - if bfloat16: - bounds[ts.ScalarKind.BFLOAT16] = (finfo(bfloat16).min, finfo(bfloat16).max) + if ml_dtypes: + bounds[ts.ScalarKind.BFLOAT16] = (ml_dtypes.finfo(ml_dtypes.bfloat16).min, ml_dtypes.finfo(ml_dtypes.bfloat16).max) return bounds[arithmetic_type.kind] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index ce51c6633c..d20c2c86f9 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -19,9 +19,9 @@ import numpy as np import numpy.typing as npt try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping @@ -50,7 +50,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: return ts.ScalarKind.STRING case np.dtype(): return getattr(ts.ScalarKind, dt.name.upper()) - case _ if bfloat16 and dt == bfloat16: + case _ if ml_dtypes and dt == ml_dtypes.bfloat16: return ts.ScalarKind.BFLOAT16 case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 61438b6192..1c40f88a57 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -18,9 +18,9 @@ import numpy as np import pytest try: - from ml_dtypes import bfloat16 + import ml_dtypes except ModuleNotFoundError: - bfloat16 = None + ml_dtypes = None import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -66,8 +66,8 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IHalfField: TypeAlias = gtx.Field[[IDim], np.float16] # type: ignore [valid-type] -if bfloat16: - IBFloatField: TypeAlias = gtx.Field[[IDim], bfloat16] # type: ignore [valid-type] +if ml_dtypes: + IBFloatField: TypeAlias = gtx.Field[[IDim], ml_dtypes.bfloat16] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] @@ -430,7 +430,7 @@ def verify( out_comp_ndarray = field_utils.asnumpy(out_comp) ref_ndarray = field_utils.asnumpy(ref) - if bfloat16 and out_comp_ndarray.dtype == bfloat16: + if ml_dtypes and out_comp_ndarray.dtype == ml_dtypes.bfloat16: out_comp_ndarray = out_comp_ndarray.astype(np.float32) ref_ndarray = ref_ndarray.astype(np.float32) From c654ca211a6e2df08d8c9fd1d1ce211cbde049cd Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 11:38:52 +0100 Subject: [PATCH 06/17] Undo float16 and bfloat16 changes --- pyproject.toml | 1 - src/gt4py/_core/definitions.py | 24 ++-------------- src/gt4py/next/ffront/fbuiltins.py | 9 ------ src/gt4py/next/iterator/builtins.py | 17 ----------- src/gt4py/next/iterator/embedded.py | 9 ------ src/gt4py/next/iterator/ir.py | 8 +----- src/gt4py/next/otf/binding/cpp_interface.py | 4 --- .../compilation/build_systems/cmake_lists.py | 3 +- .../codegens/gtfn/codegen.py | 12 -------- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 -- .../next/program_processors/runners/gtfn.py | 11 -------- src/gt4py/next/type_system/type_info.py | 9 ------ .../next/type_system/type_specifications.py | 2 -- .../next/type_system/type_translation.py | 14 ---------- tests/next_tests/definitions.py | 1 - tests/next_tests/integration_tests/cases.py | 11 -------- .../ffront_tests/test_execution.py | 28 ------------------- 17 files changed, 5 insertions(+), 160 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 101ef15176..d086363ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,6 @@ jax-cuda12 = ['jax[cuda12_pip]>=0.4.18; python_version>="3.10"'] performance = ['scipy>=1.9.2'] rocm-43 = ['cupy-rocm-4-3'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] -half-precision = ['ml_dtypes', 'jax'] [project.scripts] gtpyc = 'gt4py.cartesian.cli:gtpyc' diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index b6e3ff4689..5c7635272a 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -14,10 +14,6 @@ import functools import math import numbers -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None import numpy as np import numpy.typing as npt @@ -66,9 +62,6 @@ uint32 = np.uint32 uint64 = np.uint64 -float16 = np.float16 -if ml_dtypes: - bfloat16 = ml_dtypes.bfloat16 float32 = np.float32 float64 = np.float64 @@ -100,10 +93,8 @@ IntegralT = TypeVar("IntegralT", bound=IntegralScalar) INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) -if ml_dtypes: - FloatingScalar: TypeAlias = Union[float16, ml_dtypes.bfloat16, float32, float64, float] -else: - FloatingScalar: TypeAlias = Union[float16, float32, float64, float] + +FloatingScalar: TypeAlias = Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -153,7 +144,7 @@ def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorSh # -- Data type descriptors -- -class DTypeKind(eve.StrEnum): +class DTypeKind(eve.StrEnum):\ """ Kind of a specific data type. @@ -323,15 +314,6 @@ class Int64DType(SignedIntDType[int64]): class FloatingDType(DType[FloatingT]): pass -@dataclasses.dataclass(frozen=True) # TODO -class Float16DType(FloatingDType[float16]): - scalar_type: Final[Type[float16]] = dataclasses.field(default=float16, init=False) - -if ml_dtypes: - @dataclasses.dataclass(frozen=True) # TODO - class BFloat16DType(FloatingDType[ml_dtypes.bfloat16]): - scalar_type: Final[Type[ml_dtypes.bfloat16]] = dataclasses.field(default=ml_dtypes.bfloat16, init=False) - @dataclasses.dataclass(frozen=True) class Float32DType(FloatingDType[float32]): diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index f2a967794b..5f7944ec83 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,6 @@ import numpy as np from numpy import ( - float16, float32, float64, int8, @@ -27,10 +26,6 @@ uint32, uint64, ) -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None from gt4py._core import definitions as core_defs from gt4py.next import common @@ -53,14 +48,10 @@ uint32, int64, uint64, - float16, float32, float64, *PYTHON_TYPE_BUILTINS, ] # TODO(tehrengruber): validate matches itir type builtins? -if ml_dtypes: - from ml_dtypes import bfloat16 - TYPE_BUILTINS.append(bfloat16) TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 3e9ca8b542..263549c49c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -7,10 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.dispatcher import Dispatcher -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None builtin_dispatch = Dispatcher() @@ -385,16 +381,6 @@ def float(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() -if ml_dtypes: - @builtin_dispatch - def bfloat16(*args): - raise BackendNotSelectedError() - -@builtin_dispatch -def float16(*args): - raise BackendNotSelectedError() - - @builtin_dispatch def float32(*args): raise BackendNotSelectedError() @@ -444,13 +430,10 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "uint32", "int64", "uint64", - "float16", "float32", "float64", "bool", } # TODO(tehrengruber): This list already exists in ir.py; unify. -if ml_dtypes: - TYPEBUILTINS.add("bfloat16") MATH_BUILTINS = ( UNARY_MATH_NUMBER_BUILTINS diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f2954379ac..f6c0783f74 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -20,10 +20,6 @@ import numpy as np import numpy.typing as npt -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None from gt4py import eve from gt4py._core import definitions as core_defs @@ -100,13 +96,10 @@ | np.uint32 | np.int64 | np.uint64 - | np.float16 | np.float32 | np.float64 | np.bool_ ) -if ml_dtypes: - Scalar = Scalar | ml_dtypes.bfloat16 class SparseTag(Tag): ... @@ -611,8 +604,6 @@ def promote_scalars(val: CompositeOfScalarOrField): impl: Callable if math_builtin_name == "gamma": continue # treated explicitly - elif math_builtin_name == "bfloat16" and ml_dtypes: - impl = ml_dtypes.bfloat16 # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent # with compiled backends. Currently using Python types to preserve diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 9d9f23cbd0..643090aaa7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -15,10 +15,6 @@ from gt4py.eve.utils import noninstantiable from gt4py.next import common from gt4py.next.type_system import type_specifications as ts -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None DimensionKind = common.DimensionKind @@ -158,9 +154,7 @@ class FunctionDefinition(Node, SymbolTableTrait): "int64", "uint64", } # Todo: should we distinguish int and uint? -FLOATING_POINT_BUILTINS = {"float16", "float32", "float64"} -if ml_dtypes.bfloat16: - FLOATING_POINT_BUILTINS.add("bfloat16") +FLOATING_POINT_BUILTINS = {"float32", "float64"} TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} BUILTINS = { diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index 2fbd49839e..14ab42fa83 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -38,10 +38,6 @@ def render_scalar_type(scalar_type: ts.ScalarType) -> str: return "std::int64_t" case ts.ScalarKind.UINT64: return "std::uint64_t" - case ts.ScalarKind.FLOAT16: - return "std::float16_t" - case ts.ScalarKind.BFLOAT16: - return "std::bfloat16_t" case ts.ScalarKind.FLOAT32: return "float" case ts.ScalarKind.FLOAT64: diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 9701b6eb61..0533adac81 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -69,8 +69,7 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator): # Targets add_library({{project_name}} MODULE) - #target_compile_features({{project_name}} PRIVATE cxx_std_17) - target_compile_features({{project_name}} PRIVATE cxx_std_23) + target_compile_features({{project_name}} PRIVATE cxx_std_17) set_target_properties({{project_name}} PROPERTIES PREFIX "" SUFFIX ".{{bin_output_suffix}}") target_sources({{project_name}} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 8d9244ab89..6b0ead5239 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -7,10 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause from typing import Any, Collection, Final, Union -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako @@ -54,7 +50,6 @@ class GTFNCodegen(codegen.TemplatedGenerator): "maximum": "std::max", "fmod": "std::fmod", "power": "std::pow", - "float16": "std::float16_t", "float32": "float", "float64": "double", "int8": "std::int8_t", @@ -82,8 +77,6 @@ class GTFNCodegen(codegen.TemplatedGenerator): "mod": "std::modulus{}", "not_": "std::logical_not{}", } - if ml_dtypes: - _builtins_mapping["bfloat16"] = "std::bfloat16_t" Sym = as_fmt("{id}") @@ -111,10 +104,6 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: return self.asfloat(node.value) + "f" case "double": return self.asfloat(node.value) - case "std::float16_t": - return "std::float16_t("+self.asfloat(node.value)+")" # TODO: this is not a proper solution - case "std::bfloat16_t": - return "std::bfloat16_t("+self.asfloat(node.value)+")" # TODO: this is not a proper solution case "bool": return node.value.lower() case _: @@ -274,7 +263,6 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll """ #include #include - #include #include #include #include diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 94eb7a6cbd..ba90c08261 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -52,8 +52,6 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: t = t.kind.name.lower() try: return { - "float16": "std::float16_t", - "bfloat16": "std::bfloat16_t", "float32": "float", "float64": "double", "int8": "std::int8_t", diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 641b531cae..c0a9be9168 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -36,17 +36,6 @@ def convert_arg(arg: Any) -> Any: if isinstance(arg, common.Field): arr = arg.ndarray origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) - - # TODO: bloody hack just to get a dlpack compatible array of bfloat16s - try: - from ml_dtypes import bfloat16 - import jax.numpy as jnp - if arr.dtype == bfloat16: - arr = jnp.asarray(arr, dtype=jnp.bfloat16) - except: - raise ValueError("ml_dtypes and jax must to be installed.") - - return arr, origin else: return arg diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index d028d7d8b0..9a02e925d2 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -23,10 +23,6 @@ ) import numpy as np -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None from gt4py.eve.utils import XIterable, xiter from gt4py.next import common @@ -239,8 +235,6 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: True """ return extract_dtype(symbol_type).kind in [ - ts.ScalarKind.FLOAT16, - ts.ScalarKind.BFLOAT16, ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64, ] @@ -333,7 +327,6 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: assert is_arithmetic(arithmetic_type) bounds = { # type: ignore[return-value] # why resolved to `tuple[object, object]`? - ts.ScalarKind.FLOAT16: (np.finfo(np.float16).min, np.finfo(np.float16).max), # todo: cleanup? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), ts.ScalarKind.INT8: (np.iinfo(np.int8).min, np.iinfo(np.int8).max), @@ -345,8 +338,6 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), } - if ml_dtypes: - bounds[ts.ScalarKind.BFLOAT16] = (ml_dtypes.finfo(ml_dtypes.bfloat16).min, ml_dtypes.finfo(ml_dtypes.bfloat16).max) return bounds[arithmetic_type.kind] diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index b6e80cd9ee..0878083a1e 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -81,8 +81,6 @@ class ScalarKind(IntEnum): UINT32 = 7 INT64 = 8 UINT64 = 9 - FLOAT16 = 10 - BFLOAT16 = 11 FLOAT32 = 12 FLOAT64 = 13 STRING = 14 diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index d20c2c86f9..89744ad059 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -18,10 +18,6 @@ import numpy as np import numpy.typing as npt -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping @@ -50,8 +46,6 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: return ts.ScalarKind.STRING case np.dtype(): return getattr(ts.ScalarKind, dt.name.upper()) - case _ if ml_dtypes and dt == ml_dtypes.bfloat16: - return ts.ScalarKind.BFLOAT16 case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: @@ -245,10 +239,6 @@ def as_dtype(type_: ts.ScalarType) -> core_defs.DType: return core_defs.Int32DType() elif type_.kind == ts.ScalarKind.INT64: return core_defs.Int64DType() - elif type_.kind == ts.ScalarKind.FLOAT16: # TODO - return core_defs.Float16DType() - elif type_.kind == ts.ScalarKind.BFLOAT16: # TODO - return core_defs.BFloat16DType() elif type_.kind == ts.ScalarKind.FLOAT32: return core_defs.Float32DType() elif type_.kind == ts.ScalarKind.FLOAT64: @@ -269,10 +259,6 @@ def from_dtype(dtype: core_defs.DType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.INT32) elif dtype == core_defs.Int64DType(): return ts.ScalarType(kind=ts.ScalarKind.INT64) - elif dtype == core_defs.Float16DType(): #TODO - return ts.ScalarType(kind=ts.ScalarKind.FLOAT16) - elif dtype == core_defs.BFloat16DType(): #TODO - return ts.ScalarType(kind=ts.ScalarKind.BFLOAT16) elif dtype == core_defs.Float32DType(): return ts.ScalarType(kind=ts.ScalarKind.FLOAT32) elif dtype == core_defs.Float64DType(): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 5b06e94109..efe85af92f 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -113,7 +113,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" -USES_HALF_PRECISION = "uses_half_precision" USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 54b496b68f..7ba1ee45e2 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -17,10 +17,6 @@ import numpy as np import pytest -try: - import ml_dtypes -except ModuleNotFoundError: - ml_dtypes = None import gt4py.next as gtx from gt4py._core import definitions as core_defs @@ -65,9 +61,6 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] -IHalfField: TypeAlias = gtx.Field[[IDim], np.float16] # type: ignore [valid-type] -if ml_dtypes: - IBFloatField: TypeAlias = gtx.Field[[IDim], ml_dtypes.bfloat16] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] @@ -431,10 +424,6 @@ def verify( out_comp_ndarray = field_utils.asnumpy(out_comp) ref_ndarray = field_utils.asnumpy(ref) - if ml_dtypes and out_comp_ndarray.dtype == ml_dtypes.bfloat16: - out_comp_ndarray = out_comp_ndarray.astype(np.float32) - ref_ndarray = ref_ndarray.astype(np.float32) - assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9c59837f3e..524c556b76 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -9,10 +9,6 @@ from functools import reduce import numpy as np import pytest -try: - from ml_dtypes import bfloat16 -except ModuleNotFoundError: - bfloat16 = None import gt4py.next as gtx from gt4py.next import ( @@ -1233,28 +1229,4 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: cases.verify_with_default_data( cartesian_case, consume_constants, ref=lambda input: constants.PI * constants.E * input - ) - - -@pytest.mark.uses_half_precision -def test_float16(cartesian_case): - dtype = np.float16 - - @gtx.field_operator - def multiply_by_two(input: cases.IHalfField, input2: cases.IFloatField, scalar: np.float16) -> cases.IHalfField: - return dtype(2) * input * astype(input2, dtype) * scalar ** dtype(1.0) #TODO fails with 0.5 - cases.verify_with_default_data( - cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** dtype(1.0) - ) - - -@pytest.mark.uses_half_precision -def test_bfloat16(cartesian_case): - dtype = bfloat16 - - @gtx.field_operator - def multiply_by_two(input: cases.IBFloatField, input2: cases.IBFloatField, scalar: bfloat16) -> cases.IBFloatField: - return dtype(2) * input * astype(input2, dtype) * scalar ** dtype(1.0) #TODO fails with 0.5 - cases.verify_with_default_data( - cartesian_case, multiply_by_two, ref=lambda input, input2, scalar: dtype(2) * input * input2 * scalar ** dtype(1.0) ) \ No newline at end of file From 84341b7e0e28e4e9806f388c6844d96805880a91 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 11:44:46 +0100 Subject: [PATCH 07/17] Minor and pre-commit --- src/gt4py/_core/definitions.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 13 +------------ src/gt4py/next/iterator/builtins.py | 1 + src/gt4py/next/iterator/ir.py | 1 + src/gt4py/next/type_system/type_info.py | 5 ++--- tests/next_tests/definitions.py | 1 - tests/next_tests/integration_tests/cases.py | 1 - .../feature_tests/ffront_tests/test_execution.py | 3 +-- 8 files changed, 7 insertions(+), 20 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 5c7635272a..8f62788b8f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -144,7 +144,7 @@ def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorSh # -- Data type descriptors -- -class DTypeKind(eve.StrEnum):\ +class DTypeKind(eve.StrEnum): """ Kind of a specific data type. diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 5f7944ec83..ba98c80f6b 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -14,18 +14,7 @@ from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np -from numpy import ( - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) +from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64 from gt4py._core import definitions as core_defs from gt4py.next import common diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 263549c49c..ce776b8398 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -8,6 +8,7 @@ from gt4py.next.iterator.dispatcher import Dispatcher + builtin_dispatch = Dispatcher() diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 643090aaa7..50330cf84e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -16,6 +16,7 @@ from gt4py.next import common from gt4py.next.type_system import type_specifications as ts + DimensionKind = common.DimensionKind diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 9a02e925d2..39ebd366a5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -326,7 +326,7 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: assert is_arithmetic(arithmetic_type) - bounds = { # type: ignore[return-value] # why resolved to `tuple[object, object]`? + return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), ts.ScalarKind.INT8: (np.iinfo(np.int8).min, np.iinfo(np.int8).max), @@ -337,8 +337,7 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num ts.ScalarKind.UINT32: (np.iinfo(np.uint32).min, np.iinfo(np.uint32).max), ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), - } - return bounds[arithmetic_type.kind] + }[arithmetic_type.kind] def is_type_or_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> bool: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index efe85af92f..bed6e89a52 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -113,7 +113,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" -USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 7ba1ee45e2..98103fe901 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -423,7 +423,6 @@ def verify( assert out_comp is not None out_comp_ndarray = field_utils.asnumpy(out_comp) ref_ndarray = field_utils.asnumpy(ref) - assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 524c556b76..9de4449ac2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -9,7 +9,6 @@ from functools import reduce import numpy as np import pytest - import gt4py.next as gtx from gt4py.next import ( astype, @@ -1229,4 +1228,4 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: cases.verify_with_default_data( cartesian_case, consume_constants, ref=lambda input: constants.PI * constants.E * input - ) \ No newline at end of file + ) From a4919f1939443acf4d161af1ffeb5a43906e01ef Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 14:39:28 +0100 Subject: [PATCH 08/17] Fixed ScalarKind enum and Doctests and moved builtins from ir.py to builtins.py --- .../ffront/foast_passes/type_deduction.py | 6 +- src/gt4py/next/ffront/past_to_itir.py | 6 +- src/gt4py/next/ffront/type_info.py | 2 +- src/gt4py/next/iterator/builtins.py | 97 +++++++++++-------- src/gt4py/next/iterator/ir.py | 91 +---------------- .../next/iterator/ir_utils/domain_utils.py | 6 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 20 ++-- .../iterator/transforms/constant_folding.py | 4 +- .../next/iterator/transforms/infer_domain.py | 6 +- .../next/iterator/transforms/prune_casts.py | 4 +- .../next/iterator/transforms/trace_shifts.py | 6 +- .../next/iterator/type_system/inference.py | 12 +-- .../iterator/type_system/type_synthesizer.py | 12 +-- .../codegens/gtfn/gtfn_ir.py | 6 +- .../runners/dace_fieldview/gtir_dataflow.py | 4 +- .../dace_fieldview/gtir_python_codegen.py | 4 +- src/gt4py/next/type_system/type_info.py | 2 +- .../next/type_system/type_specifications.py | 6 +- .../iterator_tests/test_program.py | 5 +- tests/next_tests/toy_connectivity.py | 10 +- .../ffront_tests/test_past_to_gtir.py | 6 +- .../iterator_tests/test_pretty_parser.py | 4 +- .../iterator_tests/test_pretty_printer.py | 4 +- .../transforms_tests/test_global_tmps.py | 4 +- .../gtfn_tests/test_gtfn_module.py | 6 +- 25 files changed, 129 insertions(+), 204 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d334487ae1..ec8f58da68 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -106,15 +106,15 @@ def promote_to_mask_type( >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) ... ) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype) ... ) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) """ if isinstance(input_type, ts.ScalarType) or not all( item in input_type.dims for item in mask_type.dims diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4ec12bb76b..bcff7c4c6e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -24,7 +24,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.stages import AOT_PRG -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -222,7 +222,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: itir.Sym( id=_size_arg_from_field(param.id, dim_idx), type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ), ) ) @@ -359,7 +359,7 @@ def _construct_itir_domain_arg( else: lower = self._visit_slice_bound( slices[dim_i].lower if slices else None, - im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), dim_size, ) upper = self._visit_slice_bound( diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8160a2c42d..10843823cb 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -171,7 +171,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType ... ts.ScalarType(kind=ts.ScalarKind.INT64), ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), ... ) - FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index ce776b8398..9dac8ac61c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -398,6 +398,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] UNARY_MATH_NUMBER_BUILTINS = {"abs"} +UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { "sin", "cos", @@ -421,8 +422,24 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "trunc", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"} -TYPEBUILTINS = { +BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod"} +BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS = { + "plus", + "minus", + "multiplies", + "divides", + "mod", + "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 + *BINARY_MATH_NUMBER_BUILTINS, +} +BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} +BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} + + + +#: builtin / dtype used to construct integer indices, like domain bounds +INTEGER_INDEX_BUILTIN = "int32" +INTEGER_BUILTINS = { "int8", "uint8", "int16", @@ -431,55 +448,51 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "uint32", "int64", "uint64", - "float32", - "float64", - "bool", -} # TODO(tehrengruber): This list already exists in ir.py; unify. - -MATH_BUILTINS = ( - UNARY_MATH_NUMBER_BUILTINS - | UNARY_MATH_FP_BUILTINS - | UNARY_MATH_FP_PREDICATE_BUILTINS - | BINARY_MATH_NUMBER_BUILTINS - | TYPEBUILTINS -) +} # Todo: should we distinguish int and uint? +FLOATING_POINT_BUILTINS = {"float32", "float64"} +TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} + +MATH_BUILTINS = { + * UNARY_MATH_NUMBER_BUILTINS, + * UNARY_MATH_FP_BUILTINS, + * UNARY_MATH_FP_PREDICATE_BUILTINS, + * BINARY_MATH_NUMBER_BUILTINS, + "power", + * TYPEBUILTINS, +} + +ARITHMETIC_BUILTINS = { + *UNARY_MATH_NUMBER_BUILTINS, + *UNARY_LOGICAL_BUILTINS, + *UNARY_MATH_FP_BUILTINS, + *UNARY_MATH_FP_PREDICATE_BUILTINS, + *BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS, + "power", + *BINARY_MATH_COMPARISON_BUILTINS, + *BINARY_LOGICAL_BUILTINS, +} + BUILTINS = { - "deref", + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "can_deref", + "cartesian_domain", + "cast_", + "deref", + "if_", + "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", - "neighbors", "list_get", + "lift", "make_const_list", + "make_tuple", "map_", - "lift", + "named_range", + "neighbors", "reduce", - "plus", - "minus", - "multiplies", - "divides", - "floordiv", - "mod", - "make_tuple", - "tuple_get", - "if_", - "cast_", - "greater", - "less", - "less_equal", - "greater_equal", - "eq", - "not_eq", - "not_", - "and_", - "or_", - "xor_", "scan", - "cartesian_domain", + "tuple_get", "unstructured_domain", - "named_range", - "as_fieldop", - "index", - *MATH_BUILTINS, + *ARITHMETIC_BUILTINS, } __all__ = [*BUILTINS] diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 50330cf84e..ab68907f3b 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -15,6 +15,7 @@ from gt4py.eve.utils import noninstantiable from gt4py.next import common from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.builtins import BUILTINS DimensionKind = common.DimensionKind @@ -93,96 +94,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -UNARY_MATH_NUMBER_BUILTINS = {"abs"} -UNARY_LOGICAL_BUILTINS = {"not_"} -UNARY_MATH_FP_BUILTINS = { - "sin", - "cos", - "tan", - "arcsin", - "arccos", - "arctan", - "sinh", - "cosh", - "tanh", - "arcsinh", - "arccosh", - "arctanh", - "sqrt", - "exp", - "log", - "gamma", - "cbrt", - "floor", - "ceil", - "trunc", -} -UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = { - "minimum", - "maximum", - "fmod", - "plus", - "minus", - "multiplies", - "divides", - "mod", - "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 -} -BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} -BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} - -ARITHMETIC_BUILTINS = { - *UNARY_MATH_NUMBER_BUILTINS, - *UNARY_LOGICAL_BUILTINS, - *UNARY_MATH_FP_BUILTINS, - *UNARY_MATH_FP_PREDICATE_BUILTINS, - *BINARY_MATH_NUMBER_BUILTINS, - "power", - *BINARY_MATH_COMPARISON_BUILTINS, - *BINARY_LOGICAL_BUILTINS, -} - -#: builtin / dtype used to construct integer indices, like domain bounds -INTEGER_INDEX_BUILTIN = "int32" -INTEGER_BUILTINS = { - "int8", - "uint8", - "int16", - "uint16", - "int32", - "uint32", - "int64", - "uint64", -} # Todo: should we distinguish int and uint? -FLOATING_POINT_BUILTINS = {"float32", "float64"} -TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} - -BUILTINS = { - "tuple_get", - "cast_", - "cartesian_domain", - "unstructured_domain", - "make_tuple", - "shift", - "neighbors", - "named_range", - "list_get", - "map_", - "make_const_list", - "lift", - "reduce", - "deref", - "can_deref", - "scan", - "if_", - "index", # `index(dim)` creates a dim-field that has the current index at each point - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) - *ARITHMETIC_BUILTINS, - *TYPEBUILTINS, -} - - class Stmt(Node): ... diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4a023f7535..c84e2c0228 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -13,7 +13,7 @@ from typing import Any, Literal, Mapping, Optional from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -127,7 +127,7 @@ def translate( else: # note: ugly but cheap re-computation, but should disappear horizontal_sizes = { - k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) + k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } @@ -137,7 +137,7 @@ def translate( assert new_dim not in new_ranges or old_dim == new_dim new_range = SymbolicRange( - im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), horizontal_sizes[new_dim.value], ) new_ranges = dict( diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..18d72f394b 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -11,7 +11,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -29,7 +29,7 @@ def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = Non >>> a = sym("a", "float32") >>> a.id, a.type - (SymbolName('a'), ScalarType(kind=, shape=None)) + (SymbolName('a'), ScalarType(kind=, shape=None)) """ if isinstance(sym_or_name, itir.Sym): assert not type_ @@ -53,7 +53,7 @@ def ref( >>> a = ref("a", "float32") >>> a.id, a.type - (SymbolRef('a'), ScalarType(kind=, shape=None)) + (SymbolRef('a'), ScalarType(kind=, shape=None)) """ if isinstance(ref_or_name, itir.SymRef): assert not type_ @@ -71,7 +71,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type=ScalarType(kind=, shape=None)) + Literal(value='3', type=ScalarType(kind=, shape=None)) >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') @@ -134,7 +134,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) """ def __init__(self, expr): @@ -238,7 +238,7 @@ def make_tuple(*args): def tuple_get(index: str | int, tuple_expr): """Create a tuple_get FunCall, shorthand for ``call("tuple_get")(index, tuple_expr)``.""" - return call("tuple_get")(literal(str(index), itir.INTEGER_INDEX_BUILTIN), tuple_expr) + return call("tuple_get")(literal(str(index), builtins.INTEGER_INDEX_BUILTIN), tuple_expr) def if_(cond, true_val, false_val): @@ -316,11 +316,11 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. >>> literal_from_value(1.0) - Literal(value='1.0', type=ScalarType(kind=, shape=None)) + Literal(value='1.0', type=ScalarType(kind=, shape=None)) >>> literal_from_value(1) - Literal(value='1', type=ScalarType(kind=, shape=None)) + Literal(value='1', type=ScalarType(kind=, shape=None)) >>> literal_from_value(2147483648) - Literal(value='2147483648', type=ScalarType(kind=, shape=None)) + Literal(value='2147483648', type=ScalarType(kind=, shape=None)) >>> literal_from_value(True) Literal(value='True', type=ScalarType(kind=, shape=None)) """ @@ -335,7 +335,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: assert isinstance(type_spec, ts.ScalarType) typename = type_spec.kind.name.lower() - assert typename in itir.TYPEBUILTINS + assert typename in builtins.TYPEBUILTINS return literal(str(val), typename) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..18f890e08f 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator import embedded, ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im @@ -44,7 +44,7 @@ def visit_FunCall(self, node: ir.FunCall): and all(isinstance(arg, ir.Literal) for arg in new_node.args) ): # `1 + 1` -> `2` try: - if new_node.fun.id in ir.ARITHMETIC_BUILTINS: + if new_node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(new_node.fun.id)) arg_values = [ getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f26d3f9ec2..3a52569829 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -15,7 +15,7 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, @@ -383,8 +383,8 @@ def _infer_expr( elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) elif ( - cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) - or cpm.is_call_to(expr, itir.TYPEBUILTINS) + cpm.is_call_to(expr, builtins.ARITHMETIC_BUILTINS) + or cpm.is_call_to(expr, builtins.TYPEBUILTINS) or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index c825f68a5f..0216a0c0fb 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts @@ -31,7 +31,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: assert ( value.type and isinstance(type_constructor, ir.SymRef) - and (type_constructor.id in ir.TYPEBUILTINS) + and (type_constructor.id in builtins.TYPEBUILTINS) ) dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 68346b6622..c8003cb9ba 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,7 +13,7 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -278,9 +278,9 @@ def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] - elif node.id in ir.TYPEBUILTINS: + elif node.id in builtins.TYPEBUILTINS: return Sentinel.TYPE - elif node.id in (ir.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): + elif node.id in (builtins.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): return _combine raise ValueError(f"Undefined symbol {node.id}") diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1b980783fa..b2a85ef1f3 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -17,7 +17,7 @@ from gt4py.eve import concepts from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to from gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts @@ -147,7 +147,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> power(float_type, int_type) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) Now, consider a simple lambda function that squares its argument using the power builtin. A type synthesizer for this function is simple to formulate, but merely gives us the return @@ -159,7 +159,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... type_synthesizer=lambda base: power(base, int_type) ... ) >>> square_func_type_synthesizer(float_type, offset_provider_type={}) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such the type inference algorithm has to defer typing until then. This task is handled transparently @@ -173,7 +173,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... store_inferred_type_in_node=True, ... ) >>> o_type_synthesizer(float_type, offset_provider_type={}) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type ... ) @@ -529,7 +529,7 @@ def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionTyp def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( - value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) + value=ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) ) else: assert isinstance(node.value, str) and node.value in self.dimensions @@ -578,7 +578,7 @@ def visit_FunCall( self.visit(value, ctx=ctx) # ensure types in value are also inferred assert ( isinstance(type_constructor, itir.SymRef) - and type_constructor.id in itir.TYPEBUILTINS + and type_constructor.id in builtins.TYPEBUILTINS ) return ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..ff2489113c 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -14,7 +14,7 @@ from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -81,7 +81,7 @@ def _register_builtin_type_synthesizer( @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS + fun_names=builtins.UNARY_MATH_NUMBER_BUILTINS | builtins.UNARY_MATH_FP_BUILTINS ) def _(val: ts.ScalarType) -> ts.ScalarType: return val @@ -92,21 +92,21 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: return base -@_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: assert lhs == rhs return lhs @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_FP_PREDICATE_BUILTINS | itir.UNARY_LOGICAL_BUILTINS + fun_names=builtins.UNARY_MATH_FP_PREDICATE_BUILTINS | builtins.UNARY_LOGICAL_BUILTINS ) def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @_register_builtin_type_synthesizer( - fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS + fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS ) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -193,7 +193,7 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType: def index(arg: ts.DimensionType) -> ts.FieldType: return ts.FieldType( dims=[arg.dim], - dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), + dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())), ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 85a100a88d..621aac6033 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -13,7 +13,7 @@ from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef @@ -230,8 +230,8 @@ class TemporaryAllocation(Node): "reduce", "index", ] -ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS -TYPEBUILTINS = itir.TYPEBUILTINS +ARITHMETIC_BUILTINS = builtins.ARITHMETIC_BUILTINS +TYPEBUILTINS = builtins.TYPEBUILTINS BUILTINS = {*GTFN_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS} diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index a3653fb519..7b1bf1edf9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -29,7 +29,7 @@ from gt4py import eve from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator import builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -1146,7 +1146,7 @@ def _make_unstructured_shift( def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type IndexDType: Final = dace_utils.as_dace_type( - ts.ScalarType(kind=getattr(ts.ScalarKind, gtir.INTEGER_INDEX_BUILTIN.upper())) + ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) ) assert isinstance(node.fun, gtir.FunCall) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 4bdb602f5f..883f2f946b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,7 +14,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator import builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -75,7 +75,7 @@ def builtin_cast(*args: Any) -> str: val, target_type = args - assert target_type in gtir.TYPEBUILTINS + assert target_type in builtins.TYPEBUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 39ebd366a5..28d8d3c187 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -173,7 +173,7 @@ def apply_to_primitive_constituents( ... with_path_arg=True, ... tuple_constructor=lambda *elements: dict(elements), ... ) - {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} + {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} """ if isinstance(symbol_types[0], ts.TupleType): assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0878083a1e..2c48599f0a 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -81,9 +81,9 @@ class ScalarKind(IntEnum): UINT32 = 7 INT64 = 8 UINT64 = 9 - FLOAT32 = 12 - FLOAT64 = 13 - STRING = 14 + FLOAT32 = 10 + FLOAT64 = 11 + STRING = 12 @dataclass(frozen=True) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index c79f8dbb6b..09dc04acb1 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -18,6 +18,7 @@ index, named_range, shift, + INTEGER_INDEX_BUILTIN, ) from gt4py.next.iterator.runtime import fendef, fundef, set_at @@ -68,7 +69,7 @@ def test_index_builtin(program_processor): program_processor, validate = program_processor isize = 10 - out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) run_processor(index_program_simple, program_processor, out, isize, offset_provider={}) if validate: @@ -91,7 +92,7 @@ def test_index_builtin_shift(program_processor): program_processor, validate = program_processor isize = 10 - out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I}) if validate: diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 50db24b880..154b666c5d 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -9,7 +9,7 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir Vertex = gtx.Dimension("Vertex") @@ -46,7 +46,7 @@ [7, 17, 1, 16], [8, 15, 2, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) @@ -63,7 +63,7 @@ [8, 1, 6, 4], [6, 2, 7, 5], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) @@ -89,7 +89,7 @@ [7, 1], [8, 2], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) @@ -107,7 +107,7 @@ [7, 13, 6, 16], [8, 14, 7, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index c813285bd0..379afa3be3 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -17,7 +17,7 @@ from gt4py.next import errors from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -108,14 +108,14 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) itir.Literal, value="1", type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ), ), P( itir.Literal, value="2", type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ), ), ], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index af9084f407..f825c3823b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -111,7 +111,7 @@ def test_tuple_get(): testee = "x[42]" expected = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 6b45f470b7..b0f7021bc0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -200,7 +200,7 @@ def test_shift(): def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) expected = "x[42]" actual = pformat(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 9d51dc4f33..52d77e5fda 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -9,7 +9,7 @@ from typing import Optional from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import global_tmps, infer_domain from gt4py.next.iterator.type_system import inference as type_inference @@ -19,7 +19,7 @@ IDim = common.Dimension(value="IDim") JDim = common.Dimension(value="JDim") KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 0586d48703..1afd6e8113 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -13,7 +13,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -41,8 +41,8 @@ def program_example(): fun=itir.SymRef(id="named_range"), args=[ itir.AxisLiteral(value="I"), - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal("10", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), + im.literal("10", builtins.INTEGER_INDEX_BUILTIN), ], ) ], From 20fba484287a8f461dc649eba972fe8a52201c8e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 14:41:43 +0100 Subject: [PATCH 09/17] Run pre-commit --- src/gt4py/next/iterator/builtins.py | 13 ++++++------- src/gt4py/next/iterator/ir.py | 2 +- .../next/iterator/transforms/constant_folding.py | 2 +- src/gt4py/next/iterator/transforms/prune_casts.py | 2 +- src/gt4py/next/iterator/transforms/trace_shifts.py | 2 +- src/gt4py/next/iterator/type_system/inference.py | 4 +++- .../program_processors/codegens/gtfn/gtfn_ir.py | 2 +- 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 9dac8ac61c..c20d4d07e6 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -436,7 +436,6 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} - #: builtin / dtype used to construct integer indices, like domain bounds INTEGER_INDEX_BUILTIN = "int32" INTEGER_BUILTINS = { @@ -453,12 +452,12 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} MATH_BUILTINS = { - * UNARY_MATH_NUMBER_BUILTINS, - * UNARY_MATH_FP_BUILTINS, - * UNARY_MATH_FP_PREDICATE_BUILTINS, - * BINARY_MATH_NUMBER_BUILTINS, + *UNARY_MATH_NUMBER_BUILTINS, + *UNARY_MATH_FP_BUILTINS, + *UNARY_MATH_FP_PREDICATE_BUILTINS, + *BINARY_MATH_NUMBER_BUILTINS, "power", - * TYPEBUILTINS, + *TYPEBUILTINS, } ARITHMETIC_BUILTINS = { @@ -473,7 +472,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] } BUILTINS = { - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "can_deref", "cartesian_domain", "cast_", diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ab68907f3b..ea5cf84d86 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -14,8 +14,8 @@ from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable from gt4py.next import common -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.builtins import BUILTINS +from gt4py.next.type_system import type_specifications as ts DimensionKind = common.DimensionKind diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 18f890e08f..7215d0787a 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import embedded, ir, builtins +from gt4py.next.iterator import builtins, embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index 0216a0c0fb..409b5ee556 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir, builtins +from gt4py.next.iterator import builtins, ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index c8003cb9ba..38ab398941 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,7 +13,7 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir, builtins +from gt4py.next.iterator import builtins, ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index b2a85ef1f3..b7e0729a4b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -529,7 +529,9 @@ def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionTyp def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( - value=ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) + value=ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) ) else: assert isinstance(node.value, str) and node.value in self.dimensions diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 621aac6033..39ac04ecc0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -13,7 +13,7 @@ from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common -from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator import builtins from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef From 84b51ec9f5114326a405be2779b08f7346a46b36 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 15:17:40 +0100 Subject: [PATCH 10/17] Fix tests --- src/gt4py/next/iterator/builtins.py | 1 + .../feature_tests/ffront_tests/test_execution.py | 6 ++++-- .../integration_tests/feature_tests/test_util_cases.py | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index c20d4d07e6..02f6ebbe5c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -492,6 +492,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "tuple_get", "unstructured_domain", *ARITHMETIC_BUILTINS, + *TYPEBUILTINS, } __all__ = [*BUILTINS] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9de4449ac2..55847cbcf0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -318,7 +318,9 @@ def testee(a: int32, b: int32, c: cases.IField) -> cases.IField: # not inlined return tmp2 * tmp2 * c - cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * c) + cases.verify_with_default_data( + cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * a * b * a * b * c + ) @pytest.mark.uses_scalar_in_domain_and_fo @@ -1121,7 +1123,7 @@ def implicit_broadcast_scalar(inp: cases.EmptyField): inp = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() out = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() - cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(0)) + cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(1)) def test_implicit_broadcast_mixed_dim(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index eaeb76b404..3e3df069bf 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -35,8 +35,8 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a.asnumpy()) == 0 - assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a.asnumpy()) == 1 + assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) b = cases.allocate(cartesian_case, mixed_args, "b")() @@ -45,7 +45,7 @@ def test_allocate_default_unique(cartesian_case): c = cases.allocate(cartesian_case, mixed_args, "c")() assert np.min(c.asnumpy()) == b + 1 - assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + 1 def test_allocate_return_default_zeros(cartesian_case): From 52c7e6ec7c4f2bd2059b2ebd91739aeea9d9dbca Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 15:51:48 +0100 Subject: [PATCH 11/17] Use itir_to_gtfn.py:pytype_to_cpptype in cpp_interface.py:render_scalar_type to avoid duplication --- src/gt4py/next/otf/binding/cpp_interface.py | 36 ++++--------------- .../codegens/gtfn/itir_to_gtfn_ir.py | 1 + .../binding_tests/test_cpp_interface.py | 16 ++++----- 3 files changed, 16 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index 14ab42fa83..70afcfedcb 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -10,6 +10,7 @@ from gt4py.next.otf import languages from gt4py.next.otf.binding import interface +from gt4py.next.program_processors.codegens.gtfn import itir_to_gtfn_ir from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -19,35 +20,12 @@ def render_scalar_type(scalar_type: ts.ScalarType) -> str: - match scalar_type.kind: # TODO: merge with dict in itir_tp_gtfn - case ts.ScalarKind.BOOL: - return "bool" - case ts.ScalarKind.INT8: - return "std::int8_t" - case ts.ScalarKind.UINT8: - return "std::uint8_t" - case ts.ScalarKind.INT16: - return "std::int16_t" - case ts.ScalarKind.UINT16: - return "std::uint16_t" - case ts.ScalarKind.INT32: - return "std::int32_t" - case ts.ScalarKind.UINT32: - return "std::uint32_t" - case ts.ScalarKind.INT64: - return "std::int64_t" - case ts.ScalarKind.UINT64: - return "std::uint64_t" - case ts.ScalarKind.FLOAT32: - return "float" - case ts.ScalarKind.FLOAT64: - return "double" - case ts.ScalarKind.STRING: - return "std::string" - case _: - raise AssertionError( - f"Scalar kind '{scalar_type}' is not implemented when it should be." - ) + try: + return itir_to_gtfn_ir.pytype_to_cpptype(scalar_type) # type: ignore[return-value] # always returns a str for ts.ScalarType + except Exception: + raise AssertionError( + f"Scalar kind '{scalar_type}' is not implemented when it should be." + ) from None def render_function_declaration(function: interface.Function, body: str) -> str: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ba90c08261..ed1cf77904 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -63,6 +63,7 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: "int64": "std::int64_t", "uint64": "std::uint64_t", "bool": "bool", + "string": "string", "axis_literal": None, # TODO: domain? }[t] except KeyError: diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index b1e051c82b..51b6bf512b 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -33,10 +33,10 @@ def test_render_function_declaration_scalar(function_scalar_example): expected = format_source( "cpp", """\ - decltype(auto) example(double a, std::int64_t b) { +decltype(auto) example(double a, std::int64_t b) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -81,11 +81,11 @@ def test_render_function_declaration_buffer(function_buffer_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf, ArgT1 &&b_buf) { +template + decltype(auto) example(ArgT0&& a_buf, ArgT1&& b_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -132,11 +132,11 @@ def test_render_function_declaration_tuple(function_tuple_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf) { +template + decltype(auto) example(ArgT0&& a_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected From 5b886ce0ff5db00bba808e36c2714ea5c2863e89 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 30 Dec 2024 16:03:25 +0100 Subject: [PATCH 12/17] Renaming and update comments --- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/builtins.py | 6 +++--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- src/gt4py/next/iterator/transforms/infer_domain.py | 2 +- src/gt4py/next/iterator/transforms/prune_casts.py | 2 +- src/gt4py/next/iterator/transforms/trace_shifts.py | 2 +- src/gt4py/next/iterator/type_system/inference.py | 2 +- src/gt4py/next/program_processors/codegens/gtfn/codegen.py | 2 +- src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py | 2 +- .../program_processors/codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../runners/dace_fieldview/gtir_python_codegen.py | 2 +- tests/next_tests/integration_tests/cases.py | 2 +- 12 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ba98c80f6b..cef7fc101f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -40,7 +40,7 @@ float32, float64, *PYTHON_TYPE_BUILTINS, -] # TODO(tehrengruber): validate matches itir type builtins? +] # TODO(tehrengruber): validate matches iterator.builtins.TYPE_BUILTINS? TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 02f6ebbe5c..b20b64e7da 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -449,7 +449,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "uint64", } # Todo: should we distinguish int and uint? FLOATING_POINT_BUILTINS = {"float32", "float64"} -TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} +TYPE_BUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} MATH_BUILTINS = { *UNARY_MATH_NUMBER_BUILTINS, @@ -457,7 +457,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] *UNARY_MATH_FP_PREDICATE_BUILTINS, *BINARY_MATH_NUMBER_BUILTINS, "power", - *TYPEBUILTINS, + *TYPE_BUILTINS, } ARITHMETIC_BUILTINS = { @@ -492,7 +492,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "tuple_get", "unstructured_domain", *ARITHMETIC_BUILTINS, - *TYPEBUILTINS, + *TYPE_BUILTINS, } __all__ = [*BUILTINS] diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 18d72f394b..c5cf2efa5a 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -335,7 +335,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: assert isinstance(type_spec, ts.ScalarType) typename = type_spec.kind.name.lower() - assert typename in builtins.TYPEBUILTINS + assert typename in builtins.TYPE_BUILTINS return literal(str(val), typename) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 3a52569829..f3c3185225 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -384,7 +384,7 @@ def _infer_expr( return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, builtins.ARITHMETIC_BUILTINS) - or cpm.is_call_to(expr, builtins.TYPEBUILTINS) + or cpm.is_call_to(expr, builtins.TYPE_BUILTINS) or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index 409b5ee556..3276f47042 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -31,7 +31,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: assert ( value.type and isinstance(type_constructor, ir.SymRef) - and (type_constructor.id in builtins.TYPEBUILTINS) + and (type_constructor.id in builtins.TYPE_BUILTINS) ) dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 38ab398941..4c44d660f6 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -278,7 +278,7 @@ def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] - elif node.id in builtins.TYPEBUILTINS: + elif node.id in builtins.TYPE_BUILTINS: return Sentinel.TYPE elif node.id in (builtins.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): return _combine diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index b7e0729a4b..0b69293f3a 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -580,7 +580,7 @@ def visit_FunCall( self.visit(value, ctx=ctx) # ensure types in value are also inferred assert ( isinstance(type_constructor, itir.SymRef) - and type_constructor.id in builtins.TYPEBUILTINS + and type_constructor.id in builtins.TYPE_BUILTINS ) return ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 6b0ead5239..eb71eabc3c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -107,7 +107,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: case "bool": return node.value.lower() case _: - # TODO: we should probably shouldn't just allow anything here. revisit. + # TODO(tehrengruber): we should probably shouldn't just allow anything here. Revisit. return node.value IntegralConstant = as_fmt("{value}_c") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 39ac04ecc0..831694791a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -231,7 +231,7 @@ class TemporaryAllocation(Node): "index", ] ARITHMETIC_BUILTINS = builtins.ARITHMETIC_BUILTINS -TYPEBUILTINS = builtins.TYPEBUILTINS +TYPEBUILTINS = builtins.TYPE_BUILTINS BUILTINS = {*GTFN_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ed1cf77904..d688138905 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -64,7 +64,7 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: "uint64": "std::uint64_t", "bool": "bool", "string": "string", - "axis_literal": None, # TODO: domain? + "axis_literal": None, # TODO(tehrengruber): domain? }[t] except KeyError: raise TypeError(f"Unsupported type '{t}'.") from None diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 883f2f946b..c5ba0d9551 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -75,7 +75,7 @@ def builtin_cast(*args: Any) -> str: val, target_type = args - assert target_type in builtins.TYPEBUILTINS + assert target_type in builtins.TYPE_BUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 98103fe901..5070dd71a5 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -192,7 +192,7 @@ class UniqueInitializer(DataInitializer): data containers. """ - start: int = 1 # PR comment: do not start from zero as this has the same value as zero-initialized memory + start: int = 1 @property def scalar_value(self) -> ScalarValue: From 1f8de33192a4abca4c0116cafba20bb8d90b74b1 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 14 Jan 2025 14:02:14 +0100 Subject: [PATCH 13/17] Fix type_deduction doctest --- src/gt4py/next/ffront/foast_passes/type_deduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1030d25e2f..26bcadaef1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -79,7 +79,7 @@ def construct_tuple_type( ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) - [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] + [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ element_types_new = true_branch_types for i, element in enumerate(true_branch_types): From 23d406adb144800c1b7d397b18b54eb4289522cb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 17 Jan 2025 03:02:24 +0100 Subject: [PATCH 14/17] Small cleanup --- src/gt4py/next/iterator/builtins.py | 23 +++----- src/gt4py/next/iterator/embedded.py | 55 ++++++++----------- .../iterator/type_system/type_synthesizer.py | 2 +- 3 files changed, 31 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index b20b64e7da..7096f7ab56 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -422,15 +422,16 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "trunc", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod"} -BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS = { +BINARY_MATH_NUMBER_BUILTINS = { "plus", "minus", "multiplies", "divides", "mod", "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 - *BINARY_MATH_NUMBER_BUILTINS, + "minimum", + "maximum", + "fmod", } BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} @@ -438,7 +439,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] #: builtin / dtype used to construct integer indices, like domain bounds INTEGER_INDEX_BUILTIN = "int32" -INTEGER_BUILTINS = { +INTEGER_TYPE_BUILTINS = { "int8", "uint8", "int16", @@ -447,25 +448,15 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "uint32", "int64", "uint64", -} # Todo: should we distinguish int and uint? -FLOATING_POINT_BUILTINS = {"float32", "float64"} -TYPE_BUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} - -MATH_BUILTINS = { - *UNARY_MATH_NUMBER_BUILTINS, - *UNARY_MATH_FP_BUILTINS, - *UNARY_MATH_FP_PREDICATE_BUILTINS, - *BINARY_MATH_NUMBER_BUILTINS, - "power", - *TYPE_BUILTINS, } +FLOATING_POINT_TYPE_BUILTINS = {"float32", "float64"} +TYPE_BUILTINS = {*INTEGER_TYPE_BUILTINS, *FLOATING_POINT_TYPE_BUILTINS, "bool"} ARITHMETIC_BUILTINS = { *UNARY_MATH_NUMBER_BUILTINS, *UNARY_LOGICAL_BUILTINS, *UNARY_MATH_FP_BUILTINS, *UNARY_MATH_FP_PREDICATE_BUILTINS, - *BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS, "power", *BINARY_MATH_COMPARISON_BUILTINS, *BINARY_LOGICAL_BUILTINS, diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 956edf8480..736677ca62 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -15,6 +15,7 @@ import dataclasses import itertools import math +import operator import sys import warnings @@ -509,36 +510,6 @@ def named_range(tag: Tag | common.Dimension, start: int, end: int) -> NamedRange return (tag, range(start, end)) -@builtins.minus.register(EMBEDDED) -def minus(first, second): - return first - second - - -@builtins.plus.register(EMBEDDED) -def plus(first, second): - return first + second - - -@builtins.multiplies.register(EMBEDDED) -def multiplies(first, second): - return first * second - - -@builtins.divides.register(EMBEDDED) -def divides(first, second): - return first / second - - -@builtins.floordiv.register(EMBEDDED) -def floordiv(first, second): - return first // second - - -@builtins.mod.register(EMBEDDED) -def mod(first, second): - return first % second - - @builtins.eq.register(EMBEDDED) def eq(first, second): return first == second @@ -597,8 +568,28 @@ def promote_scalars(val: CompositeOfScalarOrField): ) -for math_builtin_name in builtins.MATH_BUILTINS: - python_builtins = {"int": int, "float": float, "bool": bool, "str": str} +for math_builtin_name in builtins.ARITHMETIC_BUILTINS | builtins.TYPE_BUILTINS: + python_builtins: dict[str, Callable] = { + "int": int, + "float": float, + "bool": bool, + "str": str, + "plus": operator.add, + "minus": operator.sub, + "multiplies": operator.mul, + "divides": operator.truediv, + "mod": operator.mod, + "floordiv": operator.floordiv, + "eq": operator.eq, + "less": operator.lt, + "greater": operator.gt, + "greater_equal": operator.ge, + "less_equal": operator.le, + "not_eq": operator.ne, + "and_": operator.and_, + "or_": operator.or_, + "xor": operator.xor, + } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable if math_builtin_name == "gamma": diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index d8ef7d2f16..f5aeac7943 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -92,7 +92,7 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: return base -@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_ADDITIONAL_BUILTINS) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: if isinstance(lhs, ts.DeferredType): return rhs From 2d08ed3933b5dbfcc434a5cd0c3172042e8e4fa9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 17 Jan 2025 03:25:15 +0100 Subject: [PATCH 15/17] Small cleanup --- src/gt4py/next/iterator/builtins.py | 1 + src/gt4py/next/iterator/embedded.py | 3 +- src/gt4py/next/otf/binding/cpp_interface.py | 14 ++------ src/gt4py/next/otf/binding/nanobind.py | 6 ++-- src/gt4py/next/otf/cpp_utils.py | 32 +++++++++++++++++++ .../codegens/gtfn/codegen.py | 7 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 26 ++------------- 7 files changed, 47 insertions(+), 42 deletions(-) create mode 100644 src/gt4py/next/otf/cpp_utils.py diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7096f7ab56..959f451e01 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -457,6 +457,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] *UNARY_LOGICAL_BUILTINS, *UNARY_MATH_FP_BUILTINS, *UNARY_MATH_FP_PREDICATE_BUILTINS, + *BINARY_MATH_NUMBER_BUILTINS, "power", *BINARY_MATH_COMPARISON_BUILTINS, *BINARY_LOGICAL_BUILTINS, diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 736677ca62..e3f87d220e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -580,6 +580,7 @@ def promote_scalars(val: CompositeOfScalarOrField): "divides": operator.truediv, "mod": operator.mod, "floordiv": operator.floordiv, + "not_": operator.not_, "eq": operator.eq, "less": operator.lt, "greater": operator.gt, @@ -588,7 +589,7 @@ def promote_scalars(val: CompositeOfScalarOrField): "not_eq": operator.ne, "and_": operator.and_, "or_": operator.or_, - "xor": operator.xor, + "xor_": operator.xor, } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index 70afcfedcb..17eee4d5c6 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -8,9 +8,8 @@ from typing import Final, Sequence -from gt4py.next.otf import languages +from gt4py.next.otf import cpp_utils, languages from gt4py.next.otf.binding import interface -from gt4py.next.program_processors.codegens.gtfn import itir_to_gtfn_ir from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -19,21 +18,12 @@ ) -def render_scalar_type(scalar_type: ts.ScalarType) -> str: - try: - return itir_to_gtfn_ir.pytype_to_cpptype(scalar_type) # type: ignore[return-value] # always returns a str for ts.ScalarType - except Exception: - raise AssertionError( - f"Scalar kind '{scalar_type}' is not implemented when it should be." - ) from None - - def render_function_declaration(function: interface.Function, body: str) -> str: template_params: list[str] = [] rendered_params: list[str] = [] for index, param in enumerate(function.parameters): if isinstance(param.type_, ts.ScalarType): - rendered_params.append(f"{render_scalar_type(param.type_)} {param.name}") + rendered_params.append(f"{cpp_utils.pytype_to_cpptype(param.type_)} {param.name}") elif ti.is_type_or_tuple_of_type(param.type_, (ts.FieldType, ts.ScalarType)): template_param = f"ArgT{index}" template_params.append(f"class {template_param}") diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 3abf49788f..a2cf480d7f 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -14,7 +14,7 @@ import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator -from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf import cpp_utils, languages, stages, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.type_system import type_specifications as ts @@ -88,13 +88,13 @@ def _type_string(type_: ts.TypeSpec) -> str: ndims = len(type_.dims) # cannot be ListType: the concept is represented as Field with local Dimension in this interface assert isinstance(type_.dtype, ts.ScalarType) - dtype = cpp_interface.render_scalar_type(type_.dtype) + dtype = cpp_utils.pytype_to_cpptype(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>" return f"std::pair<{buffer_t}, {origin_t}>" elif isinstance(type_, ts.ScalarType): - return cpp_interface.render_scalar_type(type_) + return cpp_utils.pytype_to_cpptype(type_) else: raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.") diff --git a/src/gt4py/next/otf/cpp_utils.py b/src/gt4py/next/otf/cpp_utils.py new file mode 100644 index 0000000000..8b2af40eb5 --- /dev/null +++ b/src/gt4py/next/otf/cpp_utils.py @@ -0,0 +1,32 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.type_system import type_specifications as ts + + +def pytype_to_cpptype(t: ts.ScalarType | str) -> str: + if isinstance(t, ts.ScalarType): + t = t.kind.name.lower() + try: + return { + "float32": "float", + "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", + "int32": "std::int32_t", + "uint32": "std::uint32_t", + "int64": "std::int64_t", + "uint64": "std::uint64_t", + "bool": "bool", + "string": "string", + }[t] + except KeyError: + raise TypeError(f"Unsupported type '{t}'.") from None diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index eb71eabc3c..c6bf28d8e0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -11,8 +11,8 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import common +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn import gtfn_im_ir, gtfn_ir, gtfn_ir_common -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import pytype_to_cpptype class GTFNCodegen(codegen.TemplatedGenerator): @@ -98,8 +98,11 @@ def asfloat(value: str) -> str: return value def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: + if node.type == "axis_literal": + return node.value + # TODO(tehrengruber): isn't this wrong and int32 should be casted to an actual int32? - match pytype_to_cpptype(node.type): + match cpp_utils.pytype_to_cpptype(node.type): case "float": return self.asfloat(node.value) + "f" case "double": diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 10662ba1ec..104e2eccc1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -17,6 +17,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -47,29 +48,6 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: - if isinstance(t, ts.ScalarType): - t = t.kind.name.lower() - try: - return { - "float32": "float", - "float64": "double", - "int8": "std::int8_t", - "uint8": "std::uint8_t", - "int16": "std::int16_t", - "uint16": "std::uint16_t", - "int32": "std::int32_t", - "uint32": "std::uint32_t", - "int64": "std::int64_t", - "uint64": "std::uint64_t", - "bool": "bool", - "string": "string", - "axis_literal": None, # TODO(tehrengruber): domain? - }[t] - except KeyError: - raise TypeError(f"Unsupported type '{t}'.") from None - - _vertical_dimension = "gtfn::unstructured::dim::vertical" _horizontal_dimension = "gtfn::unstructured::dim::horizontal" @@ -714,7 +692,7 @@ def dtype_to_cpp(x: ts.DataType) -> str: assert all(isinstance(i, ts.ScalarType) for i in x.types) return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) - res = pytype_to_cpptype(x) + res = cpp_utils.pytype_to_cpptype(x) assert isinstance(res, str) return res From cd337e5fcac8a32ae1773235bfe60932fcb86d8e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 17 Jan 2025 13:50:08 +0100 Subject: [PATCH 16/17] Minor fix --- src/gt4py/next/ffront/past_to_itir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c88a9ac62e..4bc1dfb2f8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -218,7 +218,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` assert all(field_dims == fields_dims[0] for field_dims in fields_dims) index_type = ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ) for dim_idx in range(len(fields_dims[0])): size_params.append( From d37fe9f017ba42ab6993ac5d8e1c454a775c7b7c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 20 Jan 2025 19:25:40 +0100 Subject: [PATCH 17/17] Fix test and cleanup unused builtins --- src/gt4py/next/iterator/embedded.py | 54 +---------------------------- 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e3f87d220e..970e88e8c5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -402,27 +402,6 @@ def gamma(a): return res.item() -@builtins.and_.register(EMBEDDED) -def and_(a, b): - if isinstance(a, Column): - return np.logical_and(a, b) - return a and b - - -@builtins.or_.register(EMBEDDED) -def or_(a, b): - if isinstance(a, Column): - return np.logical_or(a, b) - return a or b - - -@builtins.xor_.register(EMBEDDED) -def xor_(a, b): - if isinstance(a, Column): - return np.logical_xor(a, b) - return a ^ b - - @builtins.tuple_get.register(EMBEDDED) def tuple_get(i, tup): if isinstance(tup, Column): @@ -510,36 +489,6 @@ def named_range(tag: Tag | common.Dimension, start: int, end: int) -> NamedRange return (tag, range(start, end)) -@builtins.eq.register(EMBEDDED) -def eq(first, second): - return first == second - - -@builtins.greater.register(EMBEDDED) -def greater(first, second): - return first > second - - -@builtins.less.register(EMBEDDED) -def less(first, second): - return first < second - - -@builtins.less_equal.register(EMBEDDED) -def less_equal(first, second): - return first <= second - - -@builtins.greater_equal.register(EMBEDDED) -def greater_equal(first, second): - return first >= second - - -@builtins.not_eq.register(EMBEDDED) -def not_eq(first, second): - return first != second - - CompositeOfScalarOrField: TypeAlias = Scalar | common.Field | tuple["CompositeOfScalarOrField", ...] @@ -580,7 +529,6 @@ def promote_scalars(val: CompositeOfScalarOrField): "divides": operator.truediv, "mod": operator.mod, "floordiv": operator.floordiv, - "not_": operator.not_, "eq": operator.eq, "less": operator.lt, "greater": operator.gt, @@ -593,7 +541,7 @@ def promote_scalars(val: CompositeOfScalarOrField): } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name == "gamma": + if math_builtin_name in ["gamma", "not_"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent