-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Support for Array Api namespace as allocator #1771
base: main
Are you sure you want to change the base?
Conversation
src/gt4py/next/allocators.py
Outdated
allocator: FieldBufferAllocationUtil = actual_allocator, | ||
device: core_defs.Device = device, | ||
) -> core_defs.NDArrayObject: | ||
# TODO check how to get from FieldBufferAllocationUtil to FieldBufferAllocatorProtocol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- think about long names -> maybe rename to
NDArray...
__copy__
,__deepcopy__
as_ndarray(allocator: ConcreteAllocator, copy: Optional[bool])
- can
TensorBuffer
be removed?
|
self.ndarray, | ||
domain=self.domain, | ||
dtype=self.dtype, | ||
copy=True, # aligned_index??? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually this is a missing piece in the current implementation. We should probably provide this as a default in all functions.
def get_array_allocation_namespace( | ||
allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace], | ||
device: Optional[core_defs.Device] = None, | ||
) -> GTArrayAllocationNamespace: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe put aligned_index here and add it as a default to the construction functions.
Same for device.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some basic comments, but it's not a detailed review. I agree with should use this strategy short-term to be able to use jax
and similar immutable data array api libraries, but I'm not sure yet if the current implementation strategy is good enough or it should be refactored. I'd propose to discuss it further offline.
def empty(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... | ||
def zeros(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... | ||
def ones(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... | ||
def full( | ||
self, shape: Sequence[int], fill_value: Scalar, *, dtype: Any = None, device: Any = None | ||
) -> Any: ... | ||
def asarray(self, obj: Any, *, dtype: Any = None, copy: Any = None) -> Any: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess using NDArrayObject
as output type doesn't work with mypy, right?
# TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697 | ||
|
||
|
||
def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should start using TypeIs
(https://devdocs.io/python~3.13/library/typing#typing.TypeIs) instead of TypeGuard
whenever possible.
def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]: | |
def is_array_api_namespace(obj: Any) -> TypeIs[ArrayApiNamespace]: |
if device is None and allocator is None: | ||
raise ValueError("No 'device' or 'allocator' specified.") | ||
actual_allocator = get_allocator(allocator) | ||
if actual_allocator is None: | ||
assert device is not None # for mypy | ||
actual_allocator = device_allocators[device.device_type] | ||
elif device is None: | ||
device = core_defs.Device(actual_allocator.__gt_device_type__, 0) | ||
elif device.device_type != actual_allocator.__gt_device_type__: | ||
raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this part is gone, the docstring needs to be updated, since most of the parameters are now compulsory.
@@ -41,6 +43,63 @@ | |||
) | |||
|
|||
|
|||
class GTArrayAllocationNamespace(Protocol): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really like this name but haven't found a better one yet
device: Optional[core_defs.Device], aligned_index: Optional[Sequence[common.NamedIndex]] | ||
) -> None: | ||
if aligned_index is not None: | ||
raise NotImplementedError("Aligned index is not support for Array API namespaces.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise NotImplementedError("Aligned index is not support for Array API namespaces.") | |
raise NotImplementedError("Aligned index is not supported for Array API namespaces.") |
def test_copy(copy, nd_array_implementation): | ||
testee = _make_field_or_scalar([[0, 1], [2, 3]], nd_array_implementation) | ||
result = copy(testee) | ||
assert np.array_equal(testee.ndarray, result.ndarray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also test that testee
and result
(and their ndarrays) are not the same object
return request.param | ||
|
||
|
||
def test_empty(allocator_device_refnamespace): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a suggestion: we could avoid the intermediate fixture here and just parametrize the test explicitly:
def test_empty(allocator_device_refnamespace): | |
@pytest.mark.parametrize("allocator, device, xp", [*allocator_device_refnamespace_params()]) | |
def test_empty(allocator, device, xp): |
|
||
assert np.array_equal(a.ndarray, ref) | ||
|
||
def test_copy(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this test the same than the one in the embedded test_nd_array_field
tests?
Allow Array API namespaces as allocators in
gtx.constructors
. This allows e.g. to construct jax fields in non-hacky way.Additional:
__copy__
,__deepcopy__
to NDArrayField (with same memory layout as source Field)