Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: support for field origin in lowering to SDFG #1818

Open
wants to merge 286 commits into
base: main
Choose a base branch
from

Conversation

edopao
Copy link
Contributor

@edopao edopao commented Jan 22, 2025

This PR adds support for GT4Py field arguments with non-zero start index, for example:

inp = constructors.empty(common.domain({IDim: (1, 9)}), ...)

which was supported in baseline only for temporary fields, by means of a data structure called field_offsets. This data structure is removed for two reasons:

  1. the name "offset" is a left-over from previous design based on dace array offset
  2. offset has a different meaning in GT4Py

We introduce the GT4Py concept of field origin and use it for both temporary fields and program arguments. The field origin corresponds to the start of the field domain range.

This PR also changes the symbolic definition of array shape. Before, the array shape was defined as [data_size_0, data_size_1, ...], now the size corresponds to the range extent stop - start as [(data_0_range_1 - data_0_range_0), (data_1_range_1 - data_1_range_0), ...].

The translation stage of the dace workflow is extended with an option disable_field_origin_on_program_arguments to set the field range start symbols to constant value zero. This is needed for the dace orchestration, because the signature of a dace-orchestrated program does not provide the domain origin.

edopao and others added 30 commits December 9, 2024 12:09
@edopao edopao changed the title feat[next][dace]: support for filed origin in lowering to SDFG feat[next][dace]: support for field origin in lowering to SDFG Jan 23, 2025
@edopao edopao marked this pull request as ready for review January 23, 2025 23:13
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks generally good, but there are a few things.

The most important thing is, that are sometimes a not defined origin is converted to all zero, sometimes it is kept as None.
I made several comments about that using one form, i.e. implicitly converting to all zero, would be more consistent and could simplify the code.
Although it would no longer be possible to distinguish between the case of the case "explicitly set to zero" and "not defined", but I am not sure if this would be an issue.

outer_sdfg_state: dace.SDFGState,
symbol_mapping: dict[str, dace.symbolic.SymbolicType],
) -> FieldopData:
"""Helper method to map a field data container from a nested SDFG to the parent SDFG."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the doc string is a bit incomplete.
How does something along the lines:
"Make the data descriptor, self refers to available, which is located inside a NestedSDFG available in its parent SDFG.
Thus it will turn self into a non transient and create a new data descriptor inside the parent."

Furthermore, something that comes to my mind, why was this function not needed before?
I mean in the end you replaced offset with origin, so the change is not that big?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functionality existed before, but it was located in gtir_sdfg inside construct_output_for_nested_sdfg().
I decided to remove the method make_copy(), because it didn't mean much, and instead created this one.

outer_desc = self.dc_node.desc(sdfg)
assert isinstance(outer_desc, dace.data.Array)
if self.origin is None:
outer_field_origin = [0] * ndims
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At other places you set such values to None.
So why do you not do it here as well or why do you return None at other places?
Okay "not defined" is conceptional different from "having value 0", but still there are a lot of if that you could remove if you would follow this approach.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, origin can be None only inside FieldopData, in case of ScalarType data. For FieldType data, it cannot be undefined.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is right, I did not consider this case.

dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i))
)
else:
# the size of global dimensions for a regular field is the symbolic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you doing with transients, i.e. where does their symbols came from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_make_array_shape_and_strides() is only called for non-transient arrays. I will add this comment:

        This method is only called for non-transient arrays, which require symbolic
        memory layout. The memory layout of transient arrays, used for temporary
        fields, is assigned by DaCe during lowering to SDFG.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add "[...] used for temporary field, the default of DaCe, which is row major, is used and might be changed during optimization".
I think this makes it a bit clearer that the lowering does not really care about the strides beside being that they are correct (but not necessarily optimal).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

@@ -201,6 +199,24 @@ def _collect_symbols_in_domain_expressions(
)


def _get_field_doman_subset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just my impression, but I am not sure if "subset" is the right term to use here.
For me a subset is more like a range, such as 3:6.
This here is more like a concrete access, such as a[i], so I would suggest to name the function more something like _make_access_index_for_field() or something.

If you keep this function, then there is a typo in "domain".

@@ -941,10 +901,47 @@ def visit_SymRef(
return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)


def _remove_field_origin_symbols(ir: gtir.Program, sdfg: dace.SDFG) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function confused my.
First it was not clear that this function is essentially undoing all the work of this PR.
Until I realized that this function is only used for build_sdfg_from_gtir() when disable_field_origin_on_program_arguments is set to True.
I would add this comment to the doc string.

The second thing I do not understand is why do you only do it for transients.
I guess it is because they are never set, but I am not fully sure about that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment added. I do not understand the second comment. I actually collect the range start symbols of the program parameters, that is the non-transient arrays in the top-level SDFG.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second part is also wrong.
I wanted to write was "why do you not do it for transients".
However, in another comment you wrote that only non transient data descriptor have this symbolic sizes.

Copy link
Contributor Author

@edopao edopao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review.

outer_sdfg_state: dace.SDFGState,
symbol_mapping: dict[str, dace.symbolic.SymbolicType],
) -> FieldopData:
"""Helper method to map a field data container from a nested SDFG to the parent SDFG."""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functionality existed before, but it was located in gtir_sdfg inside construct_output_for_nested_sdfg().
I decided to remove the method make_copy(), because it didn't mean much, and instead created this one.

outer_desc = self.dc_node.desc(sdfg)
assert isinstance(outer_desc, dace.data.Array)
if self.origin is None:
outer_field_origin = [0] * ndims
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, origin can be None only inside FieldopData, in case of ScalarType data. For FieldType data, it cannot be undefined.

dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i))
)
else:
# the size of global dimensions for a regular field is the symbolic
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_make_array_shape_and_strides() is only called for non-transient arrays. I will add this comment:

        This method is only called for non-transient arrays, which require symbolic
        memory layout. The memory layout of transient arrays, used for temporary
        fields, is assigned by DaCe during lowering to SDFG.

@@ -279,7 +291,8 @@ def make_field(
raise NotImplementedError(
"Fields with more than one local dimension are not supported."
)
return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset)
field_origin = gtx_dace_utils.get_symbolic_origin(data_node.data, field_type)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_field is actually called for SymRefs to global symbols, that is non-transient data.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some points that needs some further work, but nothing serious.

@@ -78,8 +75,8 @@ class FieldopData:
Args:
dc_node: DaCe access node to the data storage.
gt_type: GT4Py type definition, which includes the field domain information.
origin: List of start indices, in each dimension, when the dimension range
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a check to ensure that origin is None when you construct a FiledopData for scalar, i.e. adding a __post_init__ that does this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

origin: List of start indices, in each dimension, when the dimension range
does not start from zero; assume zero, if origin is not set.
origin: List of start indices, in each dimension, for `FieldType` data.
Set to `None` only for `ScalarType` data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Set to `None` only for `ScalarType` data.
Has to be `None` only for `ScalarType` data. For fields it is assumed to be all zero if not given.

I am also thinking that you should enforce the that origin is set correctly during construction.
Since this is a dataclass you have to implement __post_init__().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

lambda m: dace.sdfg.replace_properties_dict(outer_desc, m),
)
# Same applies to the symbols used as field origin (the domain range start)
assert self.origin is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should do such a check at construction time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, done. However I will have to ignore a type-checking warning ([union-attr]).

@@ -161,24 +168,18 @@ def get_symbol_mapping(
"""
if isinstance(self.gt_type, ts.ScalarType):
return {}
assert self.origin is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should enforce such constraints in the constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


In case of `ScalarType` data, the descriptor is constructed with `offset=None`.
In case of `ScalarType` data, the descriptor is constructed with `origin=None`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
In case of `ScalarType` data, the descriptor is constructed with `origin=None`.
In case of `ScalarType` data, the `FieldopData` is constructed with `origin=None`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test that data_node is a transient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wrong, the transient property does not always hold. The lowering creates a FieldopData also for access nodes to global arrays. I will modify the code comment.
A refactoring PR could move the declaration of this method to a type module, close to the FieldopData type declaration.

@@ -835,12 +837,19 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))}

if nsdfg_symbols_mapping is None:
# `None` means that all free symbols are mapped to the symbols available in parent SDFG
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# `None` means that all free symbols are mapped to the symbols available in parent SDFG
# `None` means that all free symbols are mapped to the symbols available in parent SDFG by the `add_nested_sdfg()` function.

This case also means that we never need to do a remapping where the names inside and outside are different.

I would actually kill this if because as far as I can see nsdfg_symbol_mapping is either None or {"__cond": ...}.

if nsdfg_symbols_mapping is None:
# `None` means that all free symbols are mapped to the symbols available in parent SDFG
pass
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not really understand this case, is it really needed?

Comment on lines +636 to +637
for psym, arg in lambda_args_mapping:
nsdfg_symbols_mapping |= gtir_translators.get_arg_symbol_mapping(psym.id, arg, sdfg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just my paranoia, for me there is a potential sever error.
Could it be that one argument needs the mapping {'x': 'y'} while another argument needs {'x': 'z'}, or is this not possible or is implementing the check not worth it?

@@ -321,28 +335,46 @@ def unique_tasklet_name(self, name: str) -> str:

def _make_array_shape_and_strides(
self, name: str, dims: Sequence[gtx_common.Dimension]
) -> tuple[list[dace.symbol], list[dace.symbol]]:
) -> tuple[list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> tuple[list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]:
) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]:

outer_node = head_state.add_access(inner_data.dc_node.data)
outer_data = inner_data.make_copy(outer_node)
# This must be a symbol captured from the lambda parent scope.
outer_node = head_state.add_access(inner_dataname)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am not mistaken, then this implicitly assumes that the data container on the inside and the outside always have the same name.
If the above is true, then I am not sure if this correct all the time.
If the above statement is false, then what does it means.

Furthermore, what do you mean with the symbol?
How can a symbol be an output?
I am sure I miss something here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants