diff --git a/src/ragged/_spec_set_functions.py b/src/ragged/_spec_set_functions.py index 0ab8af6..259ccd9 100644 --- a/src/ragged/_spec_set_functions.py +++ b/src/ragged/_spec_set_functions.py @@ -64,10 +64,10 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]: x_flat = ak.ravel(x._impl) # pylint: disable=W0212 if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 return unique_all_result( - values=ragged.array([]), - indices=ragged.array([]), - inverse_indices=ragged.array([]), - counts=ragged.array([]), + values=ragged.array(np.empty(0, x.dtype)), + indices=ragged.array(np.empty(0, np.int64)), + inverse_indices=ragged.array(np.empty(0, np.int64)), + counts=ragged.array(np.empty(0, np.int64)), ) values, indices, inverse_indices, counts = np.unique( x_flat.layout.data, # pylint: disable=E1101 @@ -123,7 +123,8 @@ def unique_counts(x: array, /) -> tuple[array, array]: x_flat = ak.ravel(x._impl) # pylint: disable=W0212 if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 return unique_counts_result( - values=ragged.array([]), counts=ragged.array([]) + values=ragged.array(np.empty(0, x.dtype)), + counts=ragged.array(np.empty(0, np.int64)), ) values, counts = np.unique( x_flat.layout.data, # pylint: disable=E1101 @@ -174,7 +175,8 @@ def unique_inverse(x: array, /) -> tuple[array, array]: x_flat = ak.ravel(x._impl) # pylint: disable=W0212 if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 return unique_inverse_result( - values=ragged.array([]), inverse_indices=ragged.array([]) + values=ragged.array(np.empty(0, x.dtype)), + inverse_indices=ragged.array(np.empty(0, np.int64)), ) values, inverse_indices = np.unique( x_flat.layout.data, # pylint: disable=E1101 @@ -213,7 +215,7 @@ def unique_values(x: array, /) -> array: else: x_flat = ak.ravel(x._impl) # pylint: disable=W0212 if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 - return ragged.array([]) + return ragged.array(np.empty(0, x.dtype)) return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101 else: err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]