diff --git a/Cargo.toml b/Cargo.toml
index a648b09bc..c75066c5d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }
+dlpark = { version = "0.3.0", optional = true }
+
[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "1.0", default-features = false }
@@ -73,6 +75,8 @@ rayon = ["rayon_", "std"]
matrixmultiply-threading = ["matrixmultiply/threading"]
+dlpack = ["dep:dlpark"]
+
[profile.bench]
debug = true
[profile.dev.package.numeric-tests]
diff --git a/src/data_traits.rs b/src/data_traits.rs
index acf4b0b7a..7095db73d 100644
--- a/src/data_traits.rs
+++ b/src/data_traits.rs
@@ -17,9 +17,12 @@ use alloc::sync::Arc;
use alloc::vec::Vec;
use crate::{
- ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
+ ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr
};
+#[cfg(feature = "dlpack")]
+use crate::ManagedRepr;
+
/// Array representation trait.
///
/// For an array that meets the invariants of the `ArrayBase` type. This trait
@@ -346,6 +349,24 @@ unsafe impl RawData for OwnedRepr {
private_impl! {}
}
+#[cfg(feature = "dlpack")]
+unsafe impl RawData for ManagedRepr {
+ type Elem = A;
+
+ fn _data_slice(&self) -> Option<&[A]> {
+ Some(self.as_slice())
+ }
+
+ fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool {
+ let slc = self.as_slice();
+ let ptr = slc.as_ptr() as *mut A;
+ let end = unsafe { ptr.add(slc.len()) };
+ self_ptr >= ptr && self_ptr <= end
+ }
+
+ private_impl! {}
+}
+
unsafe impl RawDataMut for OwnedRepr {
#[inline]
fn try_ensure_unique(_: &mut ArrayBase)
@@ -382,6 +403,28 @@ unsafe impl Data for OwnedRepr {
}
}
+#[cfg(feature = "dlpack")]
+unsafe impl Data for ManagedRepr {
+ #[inline]
+ fn into_owned(self_: ArrayBase) -> Array
+ where
+ A: Clone,
+ D: Dimension,
+ {
+ self_.to_owned()
+ }
+
+ #[inline]
+ fn try_into_owned_nocopy(
+ self_: ArrayBase,
+ ) -> Result, ArrayBase>
+ where
+ D: Dimension,
+ {
+ Err(self_)
+ }
+}
+
unsafe impl DataMut for OwnedRepr {}
unsafe impl RawDataClone for OwnedRepr
diff --git a/src/dlpack.rs b/src/dlpack.rs
new file mode 100644
index 000000000..d80919e11
--- /dev/null
+++ b/src/dlpack.rs
@@ -0,0 +1,94 @@
+use core::ptr::NonNull;
+use std::marker::PhantomData;
+
+use dlpark::prelude::*;
+
+use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData};
+
+impl ToTensor for ArrayBase
+where
+ A: InferDtype,
+ S: RawData,
+ D: Dimension,
+{
+ fn data_ptr(&self) -> *mut std::ffi::c_void {
+ self.as_ptr() as *mut std::ffi::c_void
+ }
+
+ fn byte_offset(&self) -> u64 {
+ 0
+ }
+
+ fn device(&self) -> Device {
+ Device::CPU
+ }
+
+ fn dtype(&self) -> DataType {
+ A::infer_dtype()
+ }
+
+ fn shape(&self) -> CowIntArray {
+ dlpark::prelude::CowIntArray::from_owned(
+ self.shape().into_iter().map(|&x| x as i64).collect(),
+ )
+ }
+
+ fn strides(&self) -> Option {
+ Some(dlpark::prelude::CowIntArray::from_owned(
+ self.strides().into_iter().map(|&x| x as i64).collect(),
+ ))
+ }
+}
+
+pub struct ManagedRepr {
+ managed_tensor: ManagedTensor,
+ _ty: PhantomData,
+}
+
+impl ManagedRepr {
+ pub fn new(managed_tensor: ManagedTensor) -> Self {
+ Self {
+ managed_tensor,
+ _ty: PhantomData,
+ }
+ }
+
+ pub fn as_slice(&self) -> &[A] {
+ self.managed_tensor.as_slice()
+ }
+
+ pub fn as_ptr(&self) -> *const A {
+ self.managed_tensor.data_ptr() as *const A
+ }
+}
+
+unsafe impl Sync for ManagedRepr where A: Sync {}
+unsafe impl Send for ManagedRepr where A: Send {}
+
+impl FromDLPack for ManagedArray {
+ fn from_dlpack(dlpack: NonNull) -> Self {
+ let managed_tensor = ManagedTensor::new(dlpack);
+ let shape: Vec = managed_tensor
+ .shape()
+ .into_iter()
+ .map(|x| *x as _)
+ .collect();
+
+ let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
+ (Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
+ (None, true) => managed_tensor
+ .calculate_contiguous_strides()
+ .into_iter()
+ .map(|x| x as _)
+ .collect(),
+ (None, false) => panic!("dlpack: invalid strides"),
+ };
+ let ptr = managed_tensor.data_ptr() as *mut A;
+
+ let managed_repr = ManagedRepr::::new(managed_tensor);
+ unsafe {
+ ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr))
+ .with_strides_dim(strides.into_dimension(), shape.into_dimension())
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 07e5ed680..2927f80fc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -207,6 +207,9 @@ mod zip;
mod dimension;
+#[cfg(feature = "dlpack")]
+mod dlpack;
+
pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip};
pub use crate::layout::Layout;
@@ -1346,6 +1349,12 @@ pub type Array = ArrayBase, D>;
/// instead of either a view or a uniquely owned copy.
pub type CowArray<'a, A, D> = ArrayBase, D>;
+
+/// An array from managed memory
+#[cfg(feature = "dlpack")]
+pub type ManagedArray = ArrayBase, D>;
+
+
/// A read-only array view.
///
/// An array view represents an array or a part of it, created from
@@ -1420,6 +1429,10 @@ pub type RawArrayViewMut = ArrayBase, D>;
pub use data_repr::OwnedRepr;
+#[cfg(feature = "dlpack")]
+pub use dlpack::ManagedRepr;
+
+
/// ArcArray's representation.
///
/// *Don’t use this type directly—use the type alias
diff --git a/tests/dlpack.rs b/tests/dlpack.rs
new file mode 100644
index 000000000..c0ba6e307
--- /dev/null
+++ b/tests/dlpack.rs
@@ -0,0 +1,17 @@
+#![cfg(feature = "dlpack")]
+
+use dlpark::prelude::*;
+use ndarray::ManagedArray;
+
+#[test]
+fn test_dlpack() {
+ let arr = ndarray::arr1(&[1i32, 2, 3]);
+ let ptr = arr.as_ptr();
+ let dlpack = arr.into_dlpack();
+ let arr2 = ManagedArray::::from_dlpack(dlpack);
+ let ptr2 = arr2.as_ptr();
+ assert_eq!(ptr, ptr2);
+ let arr3 = arr2.to_owned();
+ let ptr3 = arr3.as_ptr();
+ assert_ne!(ptr2, ptr3);
+}