Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Support for Array Api namespace as allocator #1771

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

havogt
Copy link
Contributor

@havogt havogt commented Dec 4, 2024

Allow Array API namespaces as allocators in gtx.constructors. This allows e.g. to construct jax fields in non-hacky way.

Additional:

  • enable JAX embedded execution
  • add __copy__, __deepcopy__ to NDArrayField (with same memory layout as source Field)

src/gt4py/next/constructors.py Outdated Show resolved Hide resolved
allocator: FieldBufferAllocationUtil = actual_allocator,
device: core_defs.Device = device,
) -> core_defs.NDArrayObject:
# TODO check how to get from FieldBufferAllocationUtil to FieldBufferAllocatorProtocol
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • think about long names -> maybe rename to NDArray...
  • __copy__, __deepcopy__
  • as_ndarray(allocator: ConcreteAllocator, copy: Optional[bool])
  • can TensorBuffer be removed?

@havogt
Copy link
Contributor Author

havogt commented Dec 6, 2024

  • think about long names -> maybe rename to NDArray...
  • __copy__, __deepcopy__
  • as_ndarray(allocator: ConcreteAllocator, copy: Optional[bool])
  • can TensorBuffer be removed?

@havogt havogt changed the title feat[next]: allocator in field feat[next]: Support for array api namespace as allocator Dec 20, 2024
@havogt havogt changed the title feat[next]: Support for array api namespace as allocator feat[next]: Support for Array Api namespace as allocator Dec 20, 2024
self.ndarray,
domain=self.domain,
dtype=self.dtype,
copy=True, # aligned_index???
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is a missing piece in the current implementation. We should probably provide this as a default in all functions.

def get_array_allocation_namespace(
allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace],
device: Optional[core_defs.Device] = None,
) -> GTArrayAllocationNamespace:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe put aligned_index here and add it as a default to the construction functions.
Same for device.

Copy link
Contributor

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some basic comments, but it's not a detailed review. I agree with should use this strategy short-term to be able to use jax and similar immutable data array api libraries, but I'm not sure yet if the current implementation strategy is good enough or it should be refactored. I'd propose to discuss it further offline.

Comment on lines +517 to +523
def empty(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def zeros(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def ones(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def full(
self, shape: Sequence[int], fill_value: Scalar, *, dtype: Any = None, device: Any = None
) -> Any: ...
def asarray(self, obj: Any, *, dtype: Any = None, copy: Any = None) -> Any: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess using NDArrayObject as output type doesn't work with mypy, right?

# TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697


def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should start using TypeIs (https://devdocs.io/python~3.13/library/typing#typing.TypeIs) instead of TypeGuard whenever possible.

Suggested change
def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]:
def is_array_api_namespace(obj: Any) -> TypeIs[ArrayApiNamespace]:

Comment on lines -326 to -335
if device is None and allocator is None:
raise ValueError("No 'device' or 'allocator' specified.")
actual_allocator = get_allocator(allocator)
if actual_allocator is None:
assert device is not None # for mypy
actual_allocator = device_allocators[device.device_type]
elif device is None:
device = core_defs.Device(actual_allocator.__gt_device_type__, 0)
elif device.device_type != actual_allocator.__gt_device_type__:
raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this part is gone, the docstring needs to be updated, since most of the parameters are now compulsory.

@@ -41,6 +43,63 @@
)


class GTArrayAllocationNamespace(Protocol):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like this name but haven't found a better one yet

device: Optional[core_defs.Device], aligned_index: Optional[Sequence[common.NamedIndex]]
) -> None:
if aligned_index is not None:
raise NotImplementedError("Aligned index is not support for Array API namespaces.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise NotImplementedError("Aligned index is not support for Array API namespaces.")
raise NotImplementedError("Aligned index is not supported for Array API namespaces.")

def test_copy(copy, nd_array_implementation):
testee = _make_field_or_scalar([[0, 1], [2, 3]], nd_array_implementation)
result = copy(testee)
assert np.array_equal(testee.ndarray, result.ndarray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also test that testee and result (and their ndarrays) are not the same object

return request.param


def test_empty(allocator_device_refnamespace):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a suggestion: we could avoid the intermediate fixture here and just parametrize the test explicitly:

Suggested change
def test_empty(allocator_device_refnamespace):
@pytest.mark.parametrize("allocator, device, xp", [*allocator_device_refnamespace_params()])
def test_empty(allocator, device, xp):


assert np.array_equal(a.ndarray, ref)

def test_copy():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this test the same than the one in the embedded test_nd_array_field tests?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants