feat[next][dace]: support for field origin in lowering to SDFG #1818
base: main
…_with_non_zero_domain_start
It looks generally good, but there are a few things.
The most important one is that an undefined origin is sometimes converted to all zeros and sometimes kept as `None`.
I made several comments suggesting that using one form, i.e. implicitly converting to all zeros, would be more consistent and could simplify the code.
It would then no longer be possible to distinguish between "explicitly set to zero" and "not defined", but I am not sure if this would be an issue.
src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py
outer_sdfg_state: dace.SDFGState,
symbol_mapping: dict[str, dace.symbolic.SymbolicType],
) -> FieldopData:
    """Helper method to map a field data container from a nested SDFG to the parent SDFG."""
I think the docstring is a bit incomplete.
How about something along the lines of:
"Make the data descriptor that `self` refers to, which is located inside a nested SDFG, available in its parent SDFG.
Thus it will turn `self` into a non-transient and create a new data descriptor inside the parent."
Furthermore, something that comes to my mind: why was this function not needed before?
I mean, in the end you replaced `offset` with `origin`, so the change is not that big?
This functionality existed before, but it was located in `gtir_sdfg` inside `construct_output_for_nested_sdfg()`.
I decided to remove the method `make_copy()`, because it didn't mean much, and instead created this one.
outer_desc = self.dc_node.desc(sdfg)
assert isinstance(outer_desc, dace.data.Array)
if self.origin is None:
    outer_field_origin = [0] * ndims
In other places you set such values to `None`.
So why do you not do it here as well, or why do you return `None` in other places?
Okay, "not defined" is conceptually different from "having value 0", but there are still a lot of `if` statements that you could remove if you followed this approach.
I agree, `origin` can be `None` only inside `FieldopData`, in case of `ScalarType` data. For `FieldType` data, it cannot be undefined.
This is right, I did not consider this case.
src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py
    dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i))
)
else:
    # the size of global dimensions for a regular field is the symbolic
What are you doing with transients, i.e. where do their symbols come from?
`_make_array_shape_and_strides()` is only called for non-transient arrays. I will add this comment:
This method is only called for non-transient arrays, which require symbolic
memory layout. The memory layout of transient arrays, used for temporary
fields, is assigned by DaCe during lowering to SDFG.
You should add "[...] used for temporary fields, the default of DaCe, which is row major, is used and might be changed during optimization".
I think this makes it a bit clearer that the lowering does not really care about the strides, beyond them being correct (but not necessarily optimal).
Correct.
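For readers unfamiliar with the term, the row-major default mentioned above can be illustrated with a small sketch. It uses plain Python integers rather than DaCe symbols, and the helper name `row_major_strides` is hypothetical; it is not part of DaCe or GT4Py.

```python
from typing import Sequence


def row_major_strides(shape: Sequence[int]) -> list[int]:
    """Return strides, in number of elements, for a row-major (C-contiguous) array."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension backwards: the stride of a
    # dimension is the number of elements spanned by all faster dimensions.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


print(row_major_strides([4, 3, 2]))  # [6, 2, 1]
```

Any layout with consistent strides is correct for the lowering; row major is only the starting point, and an optimization pass may replace it.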
@@ -201,6 +199,24 @@ def _collect_symbols_in_domain_expressions(
)

def _get_field_doman_subset(
This is just my impression, but I am not sure if "subset" is the right term to use here.
For me a subset is more like a range, such as `3:6`.
This is more like a concrete access, such as `a[i]`, so I would suggest naming the function something like `_make_access_index_for_field()`.
If you keep this function, then there is a typo in "domain".
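The distinction between a range subset and a concrete access index can be sketched as follows. Both helpers are hypothetical illustrations with plain integers; they are not the implementation of `_get_field_doman_subset()`, which works on symbolic expressions.

```python
from typing import Sequence


def make_access_index(point: Sequence[int], origin: Sequence[int]) -> list[int]:
    """Map one point in domain coordinates to a zero-based array index, like `a[i]`."""
    if len(point) != len(origin):
        raise ValueError("point and origin must have the same number of dimensions")
    return [p - o for p, o in zip(point, origin)]


def make_access_range(start: int, stop: int, origin: int) -> tuple[int, int]:
    """Map a 1D domain range to a zero-based array range, like `a[start:stop]`."""
    return (start - origin, stop - origin)


# Field defined on the domain IDim: (1, 9), so the origin is [1].
print(make_access_index([3], [1]))  # [2]  -> a single element
print(make_access_range(3, 6, 1))   # (2, 5) -> a range of elements
```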
@@ -941,10 +901,47 @@ def visit_SymRef(
return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)

def _remove_field_origin_symbols(ir: gtir.Program, sdfg: dace.SDFG) -> None:
This function confused me.
At first it was not clear to me that this function is essentially undoing all the work of this PR, until I realized that it is only used for `build_sdfg_from_gtir()` when `disable_field_origin_on_program_arguments` is set to `True`.
I would add this comment to the docstring.
The second thing I do not understand is why you only do it for transients.
I guess it is because they are never set, but I am not fully sure about that.
Comment added. I do not understand the second comment. I actually collect the range start symbols of the program parameters, that is the non-transient arrays in the top-level SDFG.
The second part is also wrong.
What I wanted to write was "why do you not do it for transients".
However, in another comment you wrote that only non-transient data descriptors have these symbolic sizes.
Thanks for the review.
@@ -279,7 +291,8 @@ def make_field(
  raise NotImplementedError(
      "Fields with more than one local dimension are not supported."
  )
- return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset)
+ field_origin = gtx_dace_utils.get_symbolic_origin(data_node.data, field_type)
`make_field` is actually called for `SymRefs` to global symbols, that is non-transient data.
There are some points that need further work, but nothing serious.
@@ -78,8 +75,8 @@ class FieldopData:
Args:
    dc_node: DaCe access node to the data storage.
    gt_type: GT4Py type definition, which includes the field domain information.
    origin: List of start indices, in each dimension, when the dimension range
I would add a check to ensure that `origin` is `None` when you construct a `FieldopData` for a scalar, i.e. add a `__post_init__` that does this.
Done
- origin: List of start indices, in each dimension, when the dimension range
-     does not start from zero; assume zero, if origin is not set.
+ origin: List of start indices, in each dimension, for `FieldType` data.
+     Set to `None` only for `ScalarType` data.
Suggested change:
- Set to `None` only for `ScalarType` data.
+ Has to be `None` only for `ScalarType` data. For fields it is assumed to be all zero if not given.
I am also thinking that you should enforce that `origin` is set correctly during construction.
Since this is a dataclass you have to implement `__post_init__()`.
Done
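A minimal sketch of the construction-time check discussed in this thread is given below. The classes `ScalarType`, `FieldType` and `FieldopDataSketch` are simplified stand-ins written for this illustration; the real `FieldopData` holds a DaCe access node and GT4Py type objects, and the check actually added in the PR may differ.

```python
import dataclasses
from typing import Optional, Union


@dataclasses.dataclass(frozen=True)
class ScalarType:  # stand-in for the GT4Py scalar type
    kind: str


@dataclasses.dataclass(frozen=True)
class FieldType:  # stand-in for the GT4Py field type
    dims: tuple[str, ...]


@dataclasses.dataclass(frozen=True)
class FieldopDataSketch:
    data_name: str
    gt_type: Union[ScalarType, FieldType]
    origin: Optional[tuple[int, ...]]

    def __post_init__(self) -> None:
        # Runs right after the generated __init__, so an invalid combination
        # of type and origin can never be constructed.
        if isinstance(self.gt_type, ScalarType):
            if self.origin is not None:
                raise ValueError("origin must be None for scalar data")
        else:
            if self.origin is None:
                raise ValueError("origin must be defined for field data")
            if len(self.origin) != len(self.gt_type.dims):
                raise ValueError("origin must have one entry per field dimension")


field = FieldopDataSketch("inp", FieldType(("IDim",)), (1,))
scalar = FieldopDataSketch("s", ScalarType("float64"), None)
```

With the invariant enforced once in the constructor, the scattered `assert self.origin is not None` statements only serve the type checker.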
    lambda m: dace.sdfg.replace_properties_dict(outer_desc, m),
)
# Same applies to the symbols used as field origin (the domain range start)
assert self.origin is not None
You should do such a check at construction time.
I agree, done. However, I will have to ignore a type-checking warning (`[union-attr]`).
@@ -161,24 +168,18 @@ def get_symbol_mapping(
"""
if isinstance(self.gt_type, ts.ScalarType):
    return {}
assert self.origin is not None
You should enforce such constraints in the constructor.
Done
- In case of `ScalarType` data, the descriptor is constructed with `offset=None`.
+ In case of `ScalarType` data, the descriptor is constructed with `origin=None`.
Suggested change:
- In case of `ScalarType` data, the descriptor is constructed with `origin=None`.
+ In case of `ScalarType` data, the `FieldopData` is constructed with `origin=None`.
Could you also add a test that `data_node` is a transient?
I was wrong, the transient property does not always hold. The lowering creates a `FieldopData` also for access nodes to global arrays. I will modify the code comment.
A refactoring PR could move the declaration of this method to a type module, close to the `FieldopData` type declaration.
@@ -835,12 +837,19 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))}

if nsdfg_symbols_mapping is None:
    # `None` means that all free symbols are mapped to the symbols available in parent SDFG
Suggested change:
- # `None` means that all free symbols are mapped to the symbols available in parent SDFG
+ # `None` means that all free symbols are mapped to the symbols available in parent SDFG by the `add_nested_sdfg()` function.
This case also means that we never need to do a remapping where the names inside and outside are different.
I would actually remove this `if`, because as far as I can see `nsdfg_symbols_mapping` is either `None` or `{"__cond": ...}`.
if nsdfg_symbols_mapping is None:
    # `None` means that all free symbols are mapped to the symbols available in parent SDFG
    pass
else:
I do not really understand this case, is it really needed?
for psym, arg in lambda_args_mapping:
    nsdfg_symbols_mapping |= gtir_translators.get_arg_symbol_mapping(psym.id, arg, sdfg)
This is just my paranoia, but I see a potential severe error here.
Could it be that one argument needs the mapping `{'x': 'y'}` while another argument needs `{'x': 'z'}`? Or is this not possible, or is implementing the check not worth it?
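The check asked about here could look roughly like the sketch below. It merges per-argument mappings and fails if two arguments map the same inner symbol to different outer values, instead of silently overwriting as `|=` does. The function `merge_symbol_mappings` is hypothetical and uses plain strings in place of DaCe symbolic expressions.

```python
from typing import Iterable, Mapping


def merge_symbol_mappings(mappings: Iterable[Mapping[str, str]]) -> dict[str, str]:
    """Merge symbol mappings, raising if one inner symbol gets two different targets."""
    merged: dict[str, str] = {}
    for mapping in mappings:
        for inner, outer in mapping.items():
            if inner in merged and merged[inner] != outer:
                raise ValueError(
                    f"Conflicting mapping for '{inner}': '{merged[inner]}' vs '{outer}'"
                )
            merged[inner] = outer
    return merged


print(merge_symbol_mappings([{"x": "y"}, {"x": "y", "a": "b"}]))  # {'x': 'y', 'a': 'b'}
# merge_symbol_mappings([{"x": "y"}, {"x": "z"}]) raises ValueError
```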
@@ -321,28 +335,46 @@ def unique_tasklet_name(self, name: str) -> str:

  def _make_array_shape_and_strides(
      self, name: str, dims: Sequence[gtx_common.Dimension]
- ) -> tuple[list[dace.symbol], list[dace.symbol]]:
+ ) -> tuple[list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]:
Suggested change:
- ) -> tuple[list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]:
+ ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]:
- outer_node = head_state.add_access(inner_data.dc_node.data)
- outer_data = inner_data.make_copy(outer_node)
+ # This must be a symbol captured from the lambda parent scope.
+ outer_node = head_state.add_access(inner_dataname)
If I am not mistaken, this implicitly assumes that the data container on the inside and on the outside always have the same name.
If that is true, then I am not sure this is correct all the time.
If it is false, then what does it mean?
Furthermore, what do you mean by "symbol"?
How can a symbol be an output?
I am sure I am missing something here.
This PR adds support for GT4Py field arguments with non-zero start index, for example:
inp = constructors.empty(common.domain({IDim: (1, 9)}), ...)
which was supported in baseline only for temporary fields, by means of a data structure called `field_offsets`. This data structure is removed for two reasons: [...]
We introduce the GT4Py concept of field origin and use it for both temporary fields and program arguments. The field origin corresponds to the start of the field domain range.
This PR also changes the symbolic definition of array shape. Before, the array shape was defined as `[data_size_0, data_size_1, ...]`; now the size corresponds to the range extent `stop - start`, as `[(data_0_range_1 - data_0_range_0), (data_1_range_1 - data_1_range_0), ...]`.
The translation stage of the dace workflow is extended with an option `disable_field_origin_on_program_arguments` to set the field range start symbols to constant value zero. This is needed for the dace orchestration, because the signature of a dace-orchestrated program does not provide the domain origin.
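As a rough illustration of the shape definition described above, the sketch below derives the origin and the shape from per-dimension `(start, stop)` ranges, and builds the symbolic shape expressions as strings following the naming pattern quoted in the description. Both helper functions are hypothetical; the actual lowering uses DaCe symbolic expressions and GT4Py utilities to generate the symbol names.

```python
from typing import Sequence


def origin_and_shape(ranges: Sequence[tuple[int, int]]) -> tuple[list[int], list[int]]:
    """Return (origin, shape) for per-dimension (start, stop) ranges."""
    origin = [start for start, _ in ranges]
    shape = [stop - start for start, stop in ranges]  # range extent: stop - start
    return origin, shape


def symbolic_shape(name: str, ndims: int) -> list[str]:
    """Build shape expressions following the pattern in the PR description."""
    return [f"({name}_{i}_range_1 - {name}_{i}_range_0)" for i in range(ndims)]


# Field constructed on the domain {IDim: (1, 9)}.
print(origin_and_shape([(1, 9)]))  # ([1], [8])
print(symbolic_shape("data", 2))
# ['(data_0_range_1 - data_0_range_0)', '(data_1_range_1 - data_1_range_0)']
```

When the range start symbols are fixed to zero, as with `disable_field_origin_on_program_arguments`, each shape expression reduces to the range stop symbol alone.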