From 52ca23424c28506b01d9b47ad6a42f95d553cedc Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 15 Feb 2020 13:04:30 -0500 Subject: [PATCH 1/3] Add benches for op with scalar and strided array --- benches/bench1.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/benches/bench1.rs b/benches/bench1.rs index 35a1d6e7e..291a25e97 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -431,6 +431,22 @@ fn scalar_add_2(bench: &mut test::Bencher) { bench.iter(|| n + &a); } +#[bench] +fn scalar_add_strided_1(bench: &mut test::Bencher) { + let a = + Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| &a + n); +} + +#[bench] +fn scalar_add_strided_2(bench: &mut test::Bencher) { + let a = + Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| n + &a); +} + #[bench] fn scalar_sub_1(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); From 65e13d9fde4b8aaf59216f3c7245404829501216 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 15 Feb 2020 13:21:23 -0500 Subject: [PATCH 2/3] Allow `&arr (op) scalar` output to be any elem type This change has two benefits: * The new implementation applies to more combinations of types. For example, it now applies to `&Array2` and `Complex`. * The new implementation avoids cloning the elements twice, and it avoids iterating over the elements twice. (The old implementation called `.to_owned()` followed by the arithmetic operation, while the new implementation clones the elements and performs the arithmetic operation in the same iteration.) On my machine, this change improves the performance for both contiguous and discontiguous arrays. (`scalar_add_1/2` go from ~530 ns/iter to ~380 ns/iter, and `scalar_add_strided_1/2` go from ~1540 ns/iter to ~1420 ns/iter.) --- src/impl_ops.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 4804356e8..5e57d610e 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -152,15 +152,15 @@ impl $trt for ArrayBase #[doc=$doc] /// between the reference `self` and the scalar `x`, /// and return the result as a new `Array`. -impl<'a, A, S, D, B> $trt for &'a ArrayBase - where A: Clone + $trt, +impl<'a, A, S, D, B, C> $trt for &'a ArrayBase + where A: Clone + $trt, S: Data, D: Dimension, B: ScalarOperand, { - type Output = Array; - fn $mth(self, x: B) -> Array { - self.to_owned().$mth(x) + type Output = Array; + fn $mth(self, x: B) -> Self::Output { + self.map(move |elt| elt.clone() $operator x.clone()) } } ); From b2a7d0b8852455b069b45a1414da6b672ecdafac Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 15 Feb 2020 15:15:06 -0500 Subject: [PATCH 3/3] Generalize lhs scalar ops to more combos of types This doesn't have a noticeable impact on the results of the `scalar_add_2` and `scalar_add_strided_2` benchmarks. --- src/impl_ops.rs | 134 ++++++++++++++++++++++-------------------------- 1 file changed, 60 insertions(+), 74 deletions(-) diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 5e57d610e..8999999f1 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -166,56 +166,42 @@ impl<'a, A, S, D, B, C> $trt for &'a ArrayBase ); ); -// Pick the expression $a for commutative and $b for ordered binop -macro_rules! if_commutative { - (Commute { $a:expr } or { $b:expr }) => { - $a - }; - (Ordered { $a:expr } or { $b:expr }) => { - $b - }; -} - macro_rules! impl_scalar_lhs_op { - // $commutative flag. Reuse the self + scalar impl if we can. - // We can do this safely since these are the primitive numeric types - ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( -// these have no doc -- they are not visible in rustdoc -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result (based on `self`). -impl $trt> for $scalar - where S: DataOwned + DataMut, - D: Dimension, + ($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result (based on `self`). +impl $trt> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: DataOwned + DataMut, + D: Dimension, { type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase { - if_commutative!($commutative { - rhs.$mth(self) - } or {{ - let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { - *elt = self $operator *elt; - }); - rhs - }}) + fn $mth(self, mut rhs: ArrayBase) -> ArrayBase { + rhs.unordered_foreach_mut(move |elt| { + *elt = self.clone() $operator elt.clone(); + }); + rhs } } -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result as a new `Array`. -impl<'a, S, D> $trt<&'a ArrayBase> for $scalar - where S: Data, - D: Dimension, +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result as a new `Array`. +impl<'a, A, S, D, B> $trt<&'a ArrayBase> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: Data, + D: Dimension, { - type Output = Array<$scalar, D>; - fn $mth(self, rhs: &ArrayBase) -> Array<$scalar, D> { - if_commutative!($commutative { - rhs.$mth(self) - } or { - self.$mth(rhs.to_owned()) - }) + type Output = Array; + fn $mth(self, rhs: &ArrayBase) -> Array { + rhs.map(move |elt| self.clone() $operator elt.clone()) } } ); @@ -241,16 +227,16 @@ mod arithmetic_ops { macro_rules! all_scalar_ops { ($int_scalar:ty) => ( - impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift"); - impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift"); + impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition"); + impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!($int_scalar, /, Div, div, "division"); + impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift"); + impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift"); ); } all_scalar_ops!(i8); @@ -264,31 +250,31 @@ mod arithmetic_ops { all_scalar_ops!(i128); all_scalar_ops!(u128); - impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f32, +, Add, add, "addition"); + impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f32, /, Div, div, "division"); + impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f64, +, Add, add, "addition"); + impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f64, /, Div, div, "division"); + impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); impl Neg for ArrayBase where