Skip to content

Commit

Permalink
Make iterators covariant in element type
Browse files Browse the repository at this point in the history
The internal Baseiter type underlies most of the ndarray iterators, and
it used `*mut A` for element type A. Update it to use `NonNull<A>` which
behaves identically except it's guaranteed to be non-null and is
covariant w.r.t the parameter A.

Add compile test from the issue.

Fixes #1290
  • Loading branch information
bluss committed Aug 6, 2024
1 parent 84fe611 commit 75b5f93
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/impl_owned_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ where D: Dimension

// iter is a raw pointer iterator traversing the array in memory order now with the
// sorted axes.
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
let mut dropped_elements = 0;

let mut last_ptr = data_ptr;
Expand Down
8 changes: 4 additions & 4 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -209,7 +209,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -220,7 +220,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down Expand Up @@ -262,7 +262,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down
5 changes: 2 additions & 3 deletions src/iterators/into_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
where D: Dimension
{
/// Create a new by-value iterator that consumes `array`
pub(crate) fn new(mut array: Array<A, D>) -> Self
pub(crate) fn new(array: Array<A, D>) -> Self
{
unsafe {
let array_head_ptr = array.ptr;
let ptr = array.as_mut_ptr();
let mut array_data = array.data;
let data_len = array_data.release_all_elements();
debug_assert!(data_len >= array.dim.size());
let has_unreachable_elements = array.dim.size() != data_len;
let inner = Baseiter::new(ptr, array.dim, array.strides);
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);

IntoIter {
array_data,
Expand Down
16 changes: 9 additions & 7 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use alloc::vec::Vec;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr;
use std::ptr::NonNull;

use crate::Ix1;

Expand All @@ -38,7 +39,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
#[derive(Debug)]
pub struct Baseiter<A, D>
{
ptr: *mut A,
ptr: NonNull<A>,
dim: D,
strides: D,
index: Option<D>,
Expand All @@ -50,7 +51,7 @@ impl<A, D: Dimension> Baseiter<A, D>
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
{
Baseiter {
ptr,
Expand All @@ -74,7 +75,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
};
let offset = D::stride_offset(&index, &self.strides);
self.index = self.dim.next_for(index);
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -99,7 +100,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
let mut i = 0;
let i_end = len - elem_index;
while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr());
i += 1;
}
}
Expand Down Expand Up @@ -145,7 +146,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
self.index = None;
}

unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn nth_back(&mut self, n: usize) -> Option<*mut A>
Expand All @@ -158,7 +159,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
if index == self.dim {
self.index = None;
}
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
} else {
self.index = None;
None
Expand All @@ -178,7 +179,8 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
accum = g(
accum,
self.ptr
.offset(Ix1::stride_offset(&self.dim, &self.strides)),
.offset(Ix1::stride_offset(&self.dim, &self.strides))
.as_ptr(),
);
}
}
Expand Down
34 changes: 31 additions & 3 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#![allow(
clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names
)]
#![allow(clippy::deref_addrof, clippy::unreadable_literal)]

use ndarray::prelude::*;
use ndarray::{arr3, indices, s, Slice, Zip};
Expand Down Expand Up @@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_>
self.drops.set(self.drops.get() + 1);
}
}

#[test]
fn test_impl_iter_compiles()
{
// Requires that the iterators are covariant in the element type

// base case: std
fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = slice_iter_non_empty_indices;

// ndarray case
fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = array_iter_non_empty_indices;
}

0 comments on commit 75b5f93

Please sign in to comment.