From 70fbc809de38b7a2250d12986a41b5b045130e45 Mon Sep 17 00:00:00 2001 From: roblabla Date: Mon, 10 Oct 2022 18:40:28 +0200 Subject: [PATCH] Impl a lifetime-relaxed broadcast for ArrayView ArrayView::broadcast has a lifetime that depends on &self instead of its internal buffer. This prevents writing some types of functions in an allocation-free way. For instance, take the numpy `meshgrid` function: It could be implemented like so: ```rust fn meshgrid_2d<'a, 'b>(coords_x: ArrayView1<'a, X>, coords_y: ArrayView1<'b, X>) -> (ArrayView2<'a, X>, ArrayView2<'b, X>) { let x_len = coords_x.shape()[0]; let y_len = coords_y.shape()[0]; let coords_x_s = coords_x.into_shape((1, y_len)).unwrap(); let coords_x_b = coords_x_s.broadcast((x_len, y_len)).unwrap(); let coords_y_s = coords_y.into_shape((x_len, 1)).unwrap(); let coords_y_b = coords_y_s.broadcast((x_len, y_len)).unwrap(); (coords_x_b, coords_y_b) } ``` Unfortunately, this doesn't work, because `coords_x_b` is bound to the lifetime of `coord_x_s`, instead of being bound to 'a. This commit introduces a new function, broadcast_ref, that does just that. --- src/impl_views/methods.rs | 84 +++++++++++++++++++++++++++++++++++++++ src/impl_views/mod.rs | 1 + 2 files changed, 85 insertions(+) create mode 100644 src/impl_views/methods.rs diff --git a/src/impl_views/methods.rs b/src/impl_views/methods.rs new file mode 100644 index 000000000..d3b00aa32 --- /dev/null +++ b/src/impl_views/methods.rs @@ -0,0 +1,84 @@ +// Copyright 2014-2016 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::imp_prelude::*; +use crate::dimension::IntoDimension; +use crate::dimension::size_of_shape_checked; + +impl<'a, A, D> ArrayView<'a, A, D> +where + D: Dimension, +{ + /// Broadcasts an `ArrayView`. See [`ArrayBase::broadcast`]. + /// + /// This is a specialized version of [`ArrayBase::broadcast`] that transfers + /// the view's lifetime to the output. + pub fn broadcast_ref(&self, dim: E) -> Option> + where + E: IntoDimension, + { + /// Return new stride when trying to grow `from` into shape `to` + /// + /// Broadcasting works by returning a "fake stride" where elements + /// to repeat are in axes with 0 stride, so that several indexes point + /// to the same element. + /// + /// **Note:** Cannot be used for mutable iterators, since repeating + /// elements would create aliasing pointers. + fn upcast(to: &D, from: &E, stride: &E) -> Option { + // Make sure the product of non-zero axis lengths does not exceed + // `isize::MAX`. This is the only safety check we need to perform + // because all the other constraints of `ArrayBase` are guaranteed + // to be met since we're starting from a valid `ArrayBase`. + let _ = size_of_shape_checked(to).ok()?; + + let mut new_stride = to.clone(); + // begin at the back (the least significant dimension) + // size of the axis has to either agree or `from` has to be 1 + if to.ndim() < from.ndim() { + return None; + } + + { + let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev(); + for ((er, es), dr) in from + .slice() + .iter() + .rev() + .zip(stride.slice().iter().rev()) + .zip(new_stride_iter.by_ref()) + { + /* update strides */ + if *dr == *er { + /* keep stride */ + *dr = *es; + } else if *er == 1 { + /* dead dimension, zero stride */ + *dr = 0 + } else { + return None; + } + } + + /* set remaining strides to zero */ + for dr in new_stride_iter { + *dr = 0; + } + } + Some(new_stride) + } + let dim = dim.into_dimension(); + + // Note: zero strides are safe precisely because we return an read-only view + let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) { + Some(st) => st, + None => return None, + }; + unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } + } +} diff --git a/src/impl_views/mod.rs b/src/impl_views/mod.rs index 487cc3cb2..fda58242a 100644 --- a/src/impl_views/mod.rs +++ b/src/impl_views/mod.rs @@ -1,6 +1,7 @@ mod constructors; mod conversions; mod indexing; +mod methods; mod splitting; pub use constructors::*;