Skip to content

Commit

Permalink
refactor[next][dace]: normalize SDFG field type with local dimension (#…
Browse files Browse the repository at this point in the history
…1808)

Fields with a local dimension can be passed as program arguments. The
corresponding `FieldType` parameter type in GTIR contains the local
dimension in the list of field domain dimensions, while the data type of
the field elements is `ScalarType`. For example:

`ts.FieldType(dims=[Vertex, V2EDim], dtype=FLOAT_TYPE)`

where:
```
V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL)
FLOAT_TYPE = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
```

Except for the program arguments, the internal representation in the
SDFG lowering should only contain global dimensions in the field domain,
and use `ListType` for the element type in case of a list of values:

`ts.FieldType(dims=[Vertex], dtype=ts.ListType(element_type=FLOAT_TYPE,
offset_type=V2EDim))`

With this PR, the normalized form is used across the SDFG lowering. The
`make_field` helper method is modified to convert the type definition of
field arguments to the normalized form.
  • Loading branch information
edopao authored Jan 21, 2025
1 parent 7e566fc commit ae603cb
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,31 +111,13 @@ def get_local_view(
(dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i])
for i, dim in enumerate(self.gt_type.dims)
]
local_dims = [
dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL
]
if len(local_dims) == 0:
return gtir_dataflow.IteratorExpr(
self.dc_node, self.gt_type.dtype, field_domain, it_indices
)

elif len(local_dims) == 1:
field_dtype = ts.ListType(
element_type=self.gt_type.dtype, offset_type=local_dims[0]
)
field_domain = [
(dim, offset)
for dim, offset in field_domain
if dim.kind != gtx_common.DimensionKind.LOCAL
]
return gtir_dataflow.IteratorExpr(
self.dc_node, field_dtype, field_domain, it_indices
)

else:
raise ValueError(
f"Unexpected data field {self.dc_node.data} with more than one local dimension."
)
# The property below is ensured by calling `make_field()` to construct `FieldopData`.
# The `make_field` constructor ensures that any local dimension, if present, is converted
# to `ListType` element type, while the field domain consists of all global dimensions.
assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims)
return gtir_dataflow.IteratorExpr(
self.dc_node, self.gt_type.dtype, field_domain, it_indices
)

raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.")

Expand Down Expand Up @@ -305,29 +287,24 @@ def _create_field_operator_impl(
raise TypeError(
f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
)
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_shape = domain_shape
field_subset = domain_subset
else:
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
field_dtype = output_edge.result.gt_dtype.element_type
if field_dtype != output_type.dtype.element_type:
if output_edge.result.gt_dtype.element_type != output_type.dtype.element_type:
raise TypeError(
f"Type mismatch, expected {output_type.dtype.element_type} got {field_dtype}."
f"Type mismatch, expected {output_type.dtype.element_type} got {output_edge.result.gt_dtype.element_type}."
)
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

Expand All @@ -336,8 +313,8 @@ def _create_field_operator_impl(

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
ts.FieldType(domain_dims, output_edge.result.gt_dtype),
offset=(domain_offset if set(domain_offset) != {0} else None),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _create_scan_field_operator_impl(
f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
)
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
field_shape = domain_shape
# the scan field operator computes a column of scalar values
assert len(dataflow_output_desc.shape) == 1
else:
Expand All @@ -147,15 +147,12 @@ def _create_scan_field_operator_impl(
assert len(dataflow_output_desc.shape) == 2
# the lines below extend the array with the local dimension added by the field operator
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[1]]
field_offset = [*domain_offset, dataflow_output_desc.offset[1]]
field_subset = field_subset + dace_subsets.Range.from_string(
f"0:{dataflow_output_desc.shape[1]}"
)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, field_desc = sdfg_builder.add_temp_array(
sdfg, field_shape, dataflow_output_desc.dtype
)
Expand All @@ -173,8 +170,8 @@ def _create_scan_field_operator_impl(

return gtir_translators.FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
ts.FieldType(domain_dims, output_edge.result.gt_dtype),
offset=(domain_offset if set(domain_offset) != {0} else None),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,52 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType
def make_field(
self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType
) -> gtir_builtin_translators.FieldopData:
if isinstance(data_type, ts.FieldType):
domain_offset = self.field_offsets.get(data_node.data, None)
"""
Helper method to build the field data type associated with an access node in the SDFG.
In case of `ScalarType` data, the descriptor is constructed with `offset=None`.
In case of `FieldType` data, the field origin is added to the data descriptor.
Besides, if the `FieldType` contains a local dimension, the descriptor is converted
to a canonical form where the field domain consists of all global dimensions
(the grid axes) and the field data type is `ListType`, with `offset_type` equal
to the field local dimension.
Args:
data_node: The access node to the SDFG data storage.
data_type: The GT4Py data descriptor, which can either come from a field parameter
of an expression node, or from an intermediate field in a previous expression.
Returns:
The descriptor associated with the SDFG data storage, filled with field origin.
"""
if isinstance(data_type, ts.ScalarType):
return gtir_builtin_translators.FieldopData(data_node, data_type, offset=None)
domain_offset = self.field_offsets.get(data_node.data, None)
local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL]
if len(local_dims) == 0:
# do nothing: the field domain consists of all global dimensions
field_type = data_type
elif len(local_dims) == 1:
local_dim = local_dims[0]
local_dim_index = data_type.dims.index(local_dim)
# the local dimension is converted into `ListType` data element
if not isinstance(data_type.dtype, ts.ScalarType):
raise ValueError(f"Invalid field type {data_type}.")
if local_dim_index != (len(data_type.dims) - 1):
raise ValueError(
f"Invalid field domain: expected the local dimension to be at the end, found at position {local_dim_index}."
)
if local_dim.value not in self.offset_provider_type:
raise ValueError(
f"The provided local dimension {local_dim} does not match any offset provider type."
)
local_type = ts.ListType(element_type=data_type.dtype, offset_type=local_dim)
field_type = ts.FieldType(dims=data_type.dims[:local_dim_index], dtype=local_type)
else:
domain_offset = None
return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset)
raise NotImplementedError(
"Fields with more than one local dimension are not supported."
)
return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset)

def get_symbol_type(self, symbol_name: str) -> ts.DataType:
return self.global_symbols[symbol_name]
Expand Down

0 comments on commit ae603cb

Please sign in to comment.