diff --git a/Cargo.toml b/Cargo.toml index a648b09bc..c75066c5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +dlpark = { version = "0.3.0", optional = true } + [dev-dependencies] defmac = "0.2" quickcheck = { version = "1.0", default-features = false } @@ -73,6 +75,8 @@ rayon = ["rayon_", "std"] matrixmultiply-threading = ["matrixmultiply/threading"] +dlpack = ["dep:dlpark"] + [profile.bench] debug = true [profile.dev.package.numeric-tests] diff --git a/src/data_traits.rs b/src/data_traits.rs index acf4b0b7a..7095db73d 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,9 +17,12 @@ use alloc::sync::Arc; use alloc::vec::Vec; use crate::{ - ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, + ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr }; +#[cfg(feature = "dlpack")] +use crate::ManagedRepr; + /// Array representation trait. /// /// For an array that meets the invariants of the `ArrayBase` type. This trait @@ -346,6 +349,24 @@ unsafe impl RawData for OwnedRepr { private_impl! {} } +#[cfg(feature = "dlpack")] +unsafe impl RawData for ManagedRepr { + type Elem = A; + + fn _data_slice(&self) -> Option<&[A]> { + Some(self.as_slice()) + } + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { + let slc = self.as_slice(); + let ptr = slc.as_ptr() as *mut A; + let end = unsafe { ptr.add(slc.len()) }; + self_ptr >= ptr && self_ptr <= end + } + + private_impl! {} +} + unsafe impl RawDataMut for OwnedRepr { #[inline] fn try_ensure_unique(_: &mut ArrayBase) @@ -382,6 +403,28 @@ unsafe impl Data for OwnedRepr { } } +#[cfg(feature = "dlpack")] +unsafe impl Data for ManagedRepr { + #[inline] + fn into_owned(self_: ArrayBase) -> Array + where + A: Clone, + D: Dimension, + { + self_.to_owned() + } + + #[inline] + fn try_into_owned_nocopy( + self_: ArrayBase, + ) -> Result, ArrayBase> + where + D: Dimension, + { + Err(self_) + } +} + unsafe impl DataMut for OwnedRepr {} unsafe impl RawDataClone for OwnedRepr diff --git a/src/dlpack.rs b/src/dlpack.rs new file mode 100644 index 000000000..d80919e11 --- /dev/null +++ b/src/dlpack.rs @@ -0,0 +1,94 @@ +use core::ptr::NonNull; +use std::marker::PhantomData; + +use dlpark::prelude::*; + +use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData}; + +impl ToTensor for ArrayBase +where + A: InferDtype, + S: RawData, + D: Dimension, +{ + fn data_ptr(&self) -> *mut std::ffi::c_void { + self.as_ptr() as *mut std::ffi::c_void + } + + fn byte_offset(&self) -> u64 { + 0 + } + + fn device(&self) -> Device { + Device::CPU + } + + fn dtype(&self) -> DataType { + A::infer_dtype() + } + + fn shape(&self) -> CowIntArray { + dlpark::prelude::CowIntArray::from_owned( + self.shape().into_iter().map(|&x| x as i64).collect(), + ) + } + + fn strides(&self) -> Option { + Some(dlpark::prelude::CowIntArray::from_owned( + self.strides().into_iter().map(|&x| x as i64).collect(), + )) + } +} + +pub struct ManagedRepr { + managed_tensor: ManagedTensor, + _ty: PhantomData, +} + +impl ManagedRepr { + pub fn new(managed_tensor: ManagedTensor) -> Self { + Self { + managed_tensor, + _ty: PhantomData, + } + } + + pub fn as_slice(&self) -> &[A] { + self.managed_tensor.as_slice() + } + + pub fn as_ptr(&self) -> *const A { + self.managed_tensor.data_ptr() as *const A + } +} + +unsafe impl Sync for ManagedRepr where A: Sync {} +unsafe impl Send for ManagedRepr where A: Send {} + +impl FromDLPack for ManagedArray { + fn from_dlpack(dlpack: NonNull) -> Self { + let managed_tensor = ManagedTensor::new(dlpack); + let shape: Vec = managed_tensor + .shape() + .into_iter() + .map(|x| *x as _) + .collect(); + + let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) { + (Some(s), _) => s.into_iter().map(|&x| x as _).collect(), + (None, true) => managed_tensor + .calculate_contiguous_strides() + .into_iter() + .map(|x| x as _) + .collect(), + (None, false) => panic!("dlpack: invalid strides"), + }; + let ptr = managed_tensor.data_ptr() as *mut A; + + let managed_repr = ManagedRepr::::new(managed_tensor); + unsafe { + ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr)) + .with_strides_dim(strides.into_dimension(), shape.into_dimension()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 07e5ed680..2927f80fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,9 @@ mod zip; mod dimension; +#[cfg(feature = "dlpack")] +mod dlpack; + pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; @@ -1346,6 +1349,12 @@ pub type Array = ArrayBase, D>; /// instead of either a view or a uniquely owned copy. pub type CowArray<'a, A, D> = ArrayBase, D>; + +/// An array from managed memory +#[cfg(feature = "dlpack")] +pub type ManagedArray = ArrayBase, D>; + + /// A read-only array view. /// /// An array view represents an array or a part of it, created from @@ -1420,6 +1429,10 @@ pub type RawArrayViewMut = ArrayBase, D>; pub use data_repr::OwnedRepr; +#[cfg(feature = "dlpack")] +pub use dlpack::ManagedRepr; + + /// ArcArray's representation. /// /// *Don’t use this type directly—use the type alias diff --git a/tests/dlpack.rs b/tests/dlpack.rs new file mode 100644 index 000000000..c0ba6e307 --- /dev/null +++ b/tests/dlpack.rs @@ -0,0 +1,17 @@ +#![cfg(feature = "dlpack")] + +use dlpark::prelude::*; +use ndarray::ManagedArray; + +#[test] +fn test_dlpack() { + let arr = ndarray::arr1(&[1i32, 2, 3]); + let ptr = arr.as_ptr(); + let dlpack = arr.into_dlpack(); + let arr2 = ManagedArray::::from_dlpack(dlpack); + let ptr2 = arr2.as_ptr(); + assert_eq!(ptr, ptr2); + let arr3 = arr2.to_owned(); + let ptr3 = arr3.as_ptr(); + assert_ne!(ptr2, ptr3); +}