From 6b6b89e6b1a89ad0d0548ad1d35e35d55b6b03a2 Mon Sep 17 00:00:00 2001
From: Dragoș Rotaru
Date: Thu, 28 Apr 2022 19:33:41 +0300
Subject: [PATCH] Better slicing (#1059)

* added better slice draft
* added slicing at the logical level
* first slice example works
* from sliced to slice again
* rename better_slice to strided_slice
* small cleanup of edsl/base.py
* first stab at pythonic/numpy-style slicing of Expressions
* replaced old slice with strided slice
* added slice for host fixed point at the logical layer
* added boolean kernels, but still need to do the implementation on the bit array level
* added rep kernels
* added slice support for uint64, added few missing kernels for loading constants
* more tests for slicing
* add ShapeType handling in Expression.__getitem__
* added missing kernel
* linting
* fixed flake8 issues

Co-authored-by: jvmncs
---
 moose/src/boolean/mod.rs                      |  56 ++++-
 moose/src/fixedpoint/ops.rs                   |  59 ++++++
 moose/src/floatingpoint/ops.rs                |  25 ++-
 moose/src/host/bitarray.rs                    |   6 +
 moose/src/host/ops.rs                         | 112 ++++++++--
 moose/src/integer/ops.rs                      |  60 +++++-
 moose/src/kernels/indexing.rs                 |  40 +++-
 moose/src/kernels/io.rs                       |   8 +
 moose/src/logical/ops.rs                      |  96 +++++++++
 moose/src/replicated/ops.rs                   |  44 +++-
 .../examples/linear-regression/linreg_test.py |   2 +-
 pymoose/pymoose/computation/operations.py     |   5 +
 pymoose/pymoose/computation/utils.py          |   8 +
 pymoose/pymoose/edsl/__init__.py              |   7 +-
 pymoose/pymoose/edsl/base.py                  | 200 ++++++++++++------
 pymoose/pymoose/edsl/tracer.py                |  18 ++
 .../pymoose/predictors/linear_predictor.py    |   4 +-
 .../rust_integration_tests/slicing_test.py    | 187 ++++++++++++++++
 pymoose/src/computation.rs                    |  44 ++++
 19 files changed, 892 insertions(+), 89 deletions(-)
 create mode 100644 pymoose/rust_integration_tests/slicing_test.py

diff --git a/moose/src/boolean/mod.rs b/moose/src/boolean/mod.rs
index 412923494..a6e2d1f7a 100644
--- a/moose/src/boolean/mod.rs
+++ b/moose/src/boolean/mod.rs
@@ -4,7 +4,7 @@ use crate::computation::*;
 use crate::error::{Error, Result};
 use crate::execution::Session;
 use crate::floatingpoint::FloatTensor;
-use crate::host::HostPlacement;
+use crate::host::{HostPlacement, SliceInfo};
 use crate::integer::AbstractUint64Tensor;
 use crate::kernels::*;
 use crate::replicated::ReplicatedPlacement;
@@ -281,6 +281,60 @@ impl IndexAxisOp {
     }
 }
 
+impl SliceOp {
+    pub(crate) fn bool_rep_kernel(
+        sess: &S,
+        plc: &ReplicatedPlacement,
+        info: SliceInfo,
+        x: BoolTensor,
+    ) -> Result>
+    where
+        ReplicatedPlacement: PlacementSlice,
+        ReplicatedPlacement: PlacementShare,
+    {
+        let x = match x {
+            BoolTensor::Host(v) => plc.share(sess, &v),
+            BoolTensor::Replicated(v) => v,
+        };
+        let z = plc.slice(sess, info, &x);
+        Ok(BoolTensor::Replicated(z))
+    }
+
+    pub(crate) fn bool_host_kernel(
+        sess: &S,
+        plc: &HostPlacement,
+        info: SliceInfo,
+        x: BoolTensor,
+    ) -> Result>
+    where
+        HostPlacement: PlacementSlice,
+        HostPlacement: PlacementReveal,
+    {
+        let x = match x {
+            BoolTensor::Replicated(v) => plc.reveal(sess, &v),
+            BoolTensor::Host(v) => v,
+        };
+        let z = plc.slice(sess, info, &x);
+        Ok(BoolTensor::Host(z))
+    }
+}
+
+impl LoadOp {
+    pub(crate) fn bool_kernel(
+        sess: &S,
+        plc: &HostPlacement,
+        key: m!(HostString),
+        query: m!(HostString),
+    ) -> Result>
+    where
+        HostString: KnownType,
+        HostPlacement: PlacementLoad,
+    {
+        let z = plc.load(sess, &key, &query);
+        Ok(BoolTensor::Host(z))
+    }
+}
+
 impl SaveOp {
     pub(crate) fn bool_kernel(
         sess: &S,
diff --git a/moose/src/fixedpoint/ops.rs
b/moose/src/fixedpoint/ops.rs index dba7beca8..8f1c0bd73 100644 --- a/moose/src/fixedpoint/ops.rs +++ b/moose/src/fixedpoint/ops.rs @@ -1161,6 +1161,65 @@ impl IndexAxisOp { } } +impl SliceOp { + pub(crate) fn fixed_host_kernel( + sess: &S, + plc: &HostPlacement, + info: SliceInfo, + x: FixedTensor, + ) -> Result> + where + HostPlacement: PlacementReveal, + HostPlacement: PlacementDemirror, + HostPlacement: PlacementSlice, + { + let x = match x { + FixedTensor::Replicated(v) => plc.reveal(sess, &v), + FixedTensor::Mirrored3(v) => plc.demirror(sess, &v), + FixedTensor::Host(v) => v, + }; + let z = plc.slice(sess, info, &x); + Ok(FixedTensor::Host(z)) + } + + pub(crate) fn fixed_rep_kernel( + sess: &S, + plc: &ReplicatedPlacement, + info: SliceInfo, + x: FixedTensor, + ) -> Result> + where + ReplicatedPlacement: PlacementShare, + ReplicatedPlacement: PlacementShare, + ReplicatedPlacement: PlacementSlice, + { + let x = match x { + FixedTensor::Host(v) => plc.share(sess, &v), + FixedTensor::Mirrored3(v) => plc.share(sess, &v), + FixedTensor::Replicated(v) => v, + }; + let z = plc.slice(sess, info, &x); + Ok(FixedTensor::Replicated(z)) + } + + pub(crate) fn repfixed_kernel( + sess: &S, + plc: &ReplicatedPlacement, + info: SliceInfo, + x: RepFixedTensor, + ) -> Result> + where + ReplicatedPlacement: PlacementSlice, + { + let y = plc.slice(sess, info, &x.tensor); + Ok(RepFixedTensor { + tensor: y, + fractional_precision: x.fractional_precision, + integral_precision: x.integral_precision, + }) + } +} + impl ShapeOp { pub(crate) fn host_fixed_kernel( sess: &S, diff --git a/moose/src/floatingpoint/ops.rs b/moose/src/floatingpoint/ops.rs index a0f8658f0..66258cece 100644 --- a/moose/src/floatingpoint/ops.rs +++ b/moose/src/floatingpoint/ops.rs @@ -4,7 +4,7 @@ use crate::computation::*; use crate::error::Error; use crate::error::Result; use crate::execution::Session; -use crate::host::HostPlacement; +use crate::host::{HostPlacement, SliceInfo}; use crate::kernels::*; use crate::mirrored::{Mir3Tensor, Mirrored3Placement}; use crate::types::*; @@ -707,3 +707,26 @@ impl MuxOp { Ok(FloatTensor::Host(z)) } } + +impl SliceOp { + pub(crate) fn float_host_kernel( + sess: &S, + plc: &HostPlacement, + slice: SliceInfo, + x: FloatTensor, + ) -> Result> + where + HostPlacement: PlacementSlice, + { + let x = match x { + FloatTensor::Host(v) => v, + FloatTensor::Mirrored3(_v) => { + return Err(Error::UnimplementedOperator( + "SliceOp @ Mirrored3Placement".to_string(), + )) + } + }; + let z = plc.slice(sess, slice, &x); + Ok(FloatTensor::Host(z)) + } +} diff --git a/moose/src/host/bitarray.rs b/moose/src/host/bitarray.rs index 988c12ee2..c52238a68 100644 --- a/moose/src/host/bitarray.rs +++ b/moose/src/host/bitarray.rs @@ -1,4 +1,5 @@ use crate::host::RawShape; +use crate::host::SliceInfo; use anyhow::anyhow; use bitvec::prelude::*; use ndarray::{prelude::*, RemoveAxis}; @@ -125,6 +126,11 @@ impl BitArrayRepr { dim: Arc::new(IxDyn(&[len])), } } + + pub(crate) fn slice(&self, _info: SliceInfo) -> anyhow::Result { + // TODO(Dragos) Implement this in future + Err(anyhow::anyhow!("slicing not implemented for BitArray yet")) + } } impl std::ops::BitXor for &BitArrayRepr { diff --git a/moose/src/host/ops.rs b/moose/src/host/ops.rs index f44fa99eb..3d1477dc3 100644 --- a/moose/src/host/ops.rs +++ b/moose/src/host/ops.rs @@ -473,31 +473,71 @@ impl AtLeast2DOp { } impl SliceOp { - pub(crate) fn host_kernel( - _sess: &S, + pub(crate) fn host_fixed_kernel( + sess: &S, + plc: &HostPlacement, + info: SliceInfo, + x: 
HostFixedTensor, + ) -> Result> + where + HostPlacement: PlacementSlice, + { + let tensor = plc.slice(sess, info, &x.tensor); + Ok(HostFixedTensor:: { + tensor, + fractional_precision: x.fractional_precision, + integral_precision: x.integral_precision, + }) + } + + pub(crate) fn host_bit_kernel( + sess: &S, plc: &HostPlacement, - slice_info: SliceInfo, + info: SliceInfo, + x: HostBitTensor, + ) -> Result + where + HostPlacement: PlacementPlace, + { + let x = plc.place(sess, x); + x.slice(info) + } + + pub(crate) fn host_generic_kernel( + sess: &S, + plc: &HostPlacement, + info: SliceInfo, + x: HostTensor, + ) -> Result> + where + HostPlacement: PlacementPlace>, + { + let x = plc.place(sess, x); + x.slice(info) + } + + pub(crate) fn host_ring_kernel( + sess: &S, + plc: &HostPlacement, + info: SliceInfo, x: HostRingTensor, ) -> Result> where T: Clone, + HostPlacement: PlacementPlace>, { - let slice_info = - ndarray::SliceInfo::, IxDyn, IxDyn>::from(slice_info); - let sliced = x.0.slice(slice_info).to_owned(); - Ok(HostRingTensor(sliced.to_shared(), plc.clone())) + let x = plc.place(sess, x); + x.slice(info) } pub(crate) fn shape_kernel( _sess: &S, plc: &HostPlacement, - slice_info: SliceInfo, + info: SliceInfo, x: HostShape, ) -> Result { - let slice = x.0.slice( - slice_info.0[0].start as usize, - slice_info.0[0].end.unwrap() as usize, - ); + let slice = + x.0.slice(info.0[0].start as usize, info.0[0].end.unwrap() as usize); Ok(HostShape(slice, plc.clone())) } } @@ -559,6 +599,21 @@ impl HostTensor { } } +impl HostTensor { + fn slice(&self, info: SliceInfo) -> Result> { + if info.0.len() != self.0.ndim() { + return Err(Error::InvalidArgument(format!( + "The input dimension of `info` must match the array to be sliced. Used slice info dim {}, tensor had dim {}", + info.0.len(), + self.0.ndim() + ))); + } + let info = ndarray::SliceInfo::, IxDyn, IxDyn>::from(info); + let result = self.0.slice(info); + Ok(HostTensor(result.to_owned().into_shared(), self.1.clone())) + } +} + impl HostRingTensor { fn index_axis(self, axis: usize, index: usize) -> Result> { if axis >= self.0.ndim() { @@ -581,6 +636,24 @@ impl HostRingTensor { } } +impl HostRingTensor { + fn slice(&self, info: SliceInfo) -> Result> { + if info.0.len() != self.0.ndim() { + return Err(Error::InvalidArgument(format!( + "The input dimension of `info` must match the array to be sliced. Used slice info dim {}, tensor had dim {}", + info.0.len(), + self.0.ndim() + ))); + } + let info = ndarray::SliceInfo::, IxDyn, IxDyn>::from(info); + let result = self.0.slice(info); + Ok(HostRingTensor( + result.to_owned().into_shared(), + self.1.clone(), + )) + } +} + impl HostBitTensor { fn index_axis(self, axis: usize, index: usize) -> Result { if axis >= self.0.ndim() { @@ -600,6 +673,21 @@ impl HostBitTensor { let result = self.0.index_axis(axis, index); Ok(HostBitTensor(result, self.1)) } + + fn slice(&self, info: SliceInfo) -> Result { + if info.0.len() != self.0.ndim() { + return Err(Error::InvalidArgument(format!( + "The input dimension of `info` must match the array to be sliced. 
Used slice info dim {}, tensor had dim {}", + info.0.len(), + self.0.ndim() + ))); + } + let result = self + .0 + .slice(info) + .map_err(|e| Error::KernelError(e.to_string()))?; + Ok(HostBitTensor(result, self.1.clone())) + } } impl IndexAxisOp { diff --git a/moose/src/integer/ops.rs b/moose/src/integer/ops.rs index 270e883be..4ab6eed92 100644 --- a/moose/src/integer/ops.rs +++ b/moose/src/integer/ops.rs @@ -2,7 +2,8 @@ use super::*; use crate::error::{Error, Result}; use crate::execution::Session; use crate::floatingpoint::FloatTensor; -use crate::host::HostPlacement; +use crate::host::{HostPlacement, SliceInfo}; +use crate::replicated::ReplicatedPlacement; use crate::types::HostString; impl ConstantOp { @@ -19,6 +20,22 @@ impl ConstantOp { } } +impl LoadOp { + pub(crate) fn u64_kernel( + sess: &S, + plc: &HostPlacement, + key: m!(HostString), + query: m!(HostString), + ) -> Result> + where + HostString: KnownType, + HostPlacement: PlacementLoad, + { + let z = plc.load(sess, &key, &query); + Ok(AbstractUint64Tensor::Host(z)) + } +} + impl SaveOp { pub fn u64_kernel( sess: &S, @@ -97,3 +114,44 @@ impl CastOp { Ok(AbstractUint64Tensor::Host(plc.cast(sess, &x))) } } + +impl SliceOp { + pub(crate) fn u64_rep_kernel( + sess: &S, + plc: &ReplicatedPlacement, + info: SliceInfo, + x: AbstractUint64Tensor, + ) -> Result> + where + ReplicatedPlacement: PlacementSlice, + { + let x = match x { + AbstractUint64Tensor::Host(_v) => { + return Err(Error::UnimplementedOperator( + "Cannot share a HostUint64Tensor to a replicated placement".to_string(), + )); + } + AbstractUint64Tensor::Replicated(v) => v, + }; + let z = plc.slice(sess, info, &x); + Ok(AbstractUint64Tensor::Replicated(z)) + } + + pub(crate) fn u64_host_kernel( + sess: &S, + plc: &HostPlacement, + info: SliceInfo, + x: AbstractUint64Tensor, + ) -> Result> + where + HostPlacement: PlacementSlice, + HostPlacement: PlacementReveal, + { + let x = match x { + AbstractUint64Tensor::Replicated(v) => plc.reveal(sess, &v), + AbstractUint64Tensor::Host(v) => v, + }; + let z = plc.slice(sess, info, &x); + Ok(AbstractUint64Tensor::Host(z)) + } +} diff --git a/moose/src/kernels/indexing.rs b/moose/src/kernels/indexing.rs index 550910a41..6d4474356 100644 --- a/moose/src/kernels/indexing.rs +++ b/moose/src/kernels/indexing.rs @@ -9,7 +9,6 @@ modelled_kernel! { PlacementIndexAxis::index_axis, IndexAxisOp{axis: usize, index: usize}, [ (HostPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_host_kernel), - (HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel), (HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel), (HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel), (HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel), @@ -21,6 +20,7 @@ modelled_kernel! 
{ (HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_float_kernel), (HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel), (HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel), + (HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel), (ReplicatedPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_rep_kernel), (ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel), (ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel), @@ -57,11 +57,39 @@ pub trait PlacementSlice { modelled_kernel! { PlacementSlice::slice, SliceOp{slice: SliceInfo}, [ - (HostPlacement, (Shape) -> Shape => [concrete] Self::logical_host_shape), + // runtime kernels (HostPlacement, (HostShape) -> HostShape => [runtime] Self::shape_kernel), - (HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_kernel), - (HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_kernel), - (ReplicatedPlacement, (Shape) -> Shape => [concrete] Self::logical_rep_shape), - (ReplicatedPlacement, (ReplicatedShape) -> ReplicatedShape => [concrete] Self::rep_kernel), + (HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::host_bit_kernel), + (HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_generic_kernel), + (HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_generic_kernel), + (HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel), + (HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel), + (HostPlacement, (HostUint64Tensor) -> HostUint64Tensor => [runtime] Self::host_generic_kernel), + // host lowering kernels + (HostPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_host_kernel), + (HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel), + (HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel), + (HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel), + (HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel), + (HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::host_fixed_kernel), + (HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::host_fixed_kernel), + (HostPlacement, (Shape) -> Shape => [concrete] Self::logical_host_shape), + (HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel), + (HostPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_host_kernel), + // replicated kernels + (ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_ring_kernel), + (ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_ring_kernel), + (ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_ring_kernel), + (ReplicatedPlacement, (ReplicatedShape) -> ReplicatedShape => [concrete] Self::rep_shape_kernel), + // replicated lowering kernels + (ReplicatedPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_rep_kernel), + (ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel), + (ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] 
Self::fixed_rep_kernel), + (ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel), + (ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel), + (ReplicatedPlacement, (ReplicatedUint64Tensor) -> ReplicatedUint64Tensor => [concrete] Self::rep_uint_kernel), + (ReplicatedPlacement, (Shape) -> Shape => [concrete] Self::logical_rep_shape), + (ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel), + (ReplicatedPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_rep_kernel), ] } diff --git a/moose/src/kernels/io.rs b/moose/src/kernels/io.rs index 57d14ed59..afd841331 100644 --- a/moose/src/kernels/io.rs +++ b/moose/src/kernels/io.rs @@ -144,8 +144,10 @@ modelled_kernel! { (HostPlacement, (HostString, HostString) -> HostUint64Tensor => [runtime] Self::session_specific_kernel), (HostPlacement, (HostString, HostString) -> HostFixed64Tensor => [runtime] Self::session_specific_kernel), (HostPlacement, (HostString, HostString) -> HostFixed128Tensor => [runtime] Self::session_specific_kernel), + (HostPlacement, (HostString, HostString) -> BooleanTensor => [hybrid] Self::bool_kernel), (HostPlacement, (HostString, HostString) -> Float32Tensor => [hybrid] Self::float_kernel), (HostPlacement, (HostString, HostString) -> Float64Tensor => [hybrid] Self::float_kernel), + (HostPlacement, (HostString, HostString) -> Uint64Tensor => [hybrid] Self::u64_kernel), (HostPlacement, (HostString, HostString) -> Tensor => [hybrid] custom |op| { use crate::logical::{AbstractTensor, TensorDType}; match op.sig.ret() { @@ -155,6 +157,12 @@ modelled_kernel! { Ty::Tensor(TensorDType::Float64) => Ok(Box::new(move |sess, plc, key, query| { Self::logical_kernel::<_, Float64Tensor>(sess, plc, key, query).map(AbstractTensor::Float64) })), + Ty::Tensor(TensorDType::Bool) => Ok(Box::new(move |sess, plc, key, query| { + Self::logical_kernel::<_, BooleanTensor>(sess, plc, key, query).map(AbstractTensor::Bool) + })), + Ty::Tensor(TensorDType::Uint64) => Ok(Box::new(move |sess, plc, key, query| { + Self::logical_kernel::<_, Uint64Tensor>(sess, plc, key, query).map(AbstractTensor::Uint64) + })), other => { return Err(Error::UnimplementedOperator( format!("Cannot load tensor of type {:?}", other))) diff --git a/moose/src/logical/ops.rs b/moose/src/logical/ops.rs index 702c80956..b560f6567 100644 --- a/moose/src/logical/ops.rs +++ b/moose/src/logical/ops.rs @@ -2223,6 +2223,57 @@ impl SliceOp { } } + pub(crate) fn logical_host_kernel< + S: Session, + Fixed64T, + Fixed128T, + Float32T, + Float64T, + BoolT, + Uint64T, + >( + sess: &S, + plc: &HostPlacement, + slice: SliceInfo, + x: AbstractTensor, + ) -> Result> + where + HostPlacement: PlacementSlice, + HostPlacement: PlacementSlice, + HostPlacement: PlacementSlice, + HostPlacement: PlacementSlice, + HostPlacement: PlacementSlice, + HostPlacement: PlacementSlice, + { + use AbstractTensor::*; + match x { + Float32(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Float32(result)) + } + Float64(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Float64(result)) + } + Fixed64(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Fixed64(result)) + } + Fixed128(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Fixed128(result)) + } + Bool(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Bool(result)) + } + Uint64(x) => { + let result = plc.slice(sess, slice, &x); + Ok(Uint64(result)) + } + } + } + pub(crate) fn 
logical_rep_shape( sess: &S, plc: &ReplicatedPlacement, @@ -2242,6 +2293,51 @@ impl SliceOp { } } } + + pub(crate) fn logical_rep_kernel< + S: Session, + Fixed64T, + Fixed128T, + Float32T, + Float64T, + BoolT, + Uint64T, + >( + sess: &S, + plc: &ReplicatedPlacement, + info: SliceInfo, + x: AbstractTensor, + ) -> Result> + where + ReplicatedPlacement: PlacementSlice, + ReplicatedPlacement: PlacementSlice, + ReplicatedPlacement: PlacementSlice, + ReplicatedPlacement: PlacementSlice, + { + use AbstractTensor::*; + match x { + Fixed64(x) => { + let result = plc.slice(sess, info, &x); + Ok(Fixed64(result)) + } + Fixed128(x) => { + let result = plc.slice(sess, info, &x); + Ok(Fixed128(result)) + } + Bool(x) => { + let result = plc.slice(sess, info, &x); + Ok(Bool(result)) + } + Uint64(x) => { + let result = plc.slice(sess, info, &x); + Ok(Uint64(result)) + } + Float32(_) | Float64(_) => Err(Error::UnimplementedOperator(format!( + "Missing rep slice for {:?}", + &x.ty_desc(), + ))), + } + } } impl ConstantOp { diff --git a/moose/src/replicated/ops.rs b/moose/src/replicated/ops.rs index 3a0692330..5b64685b7 100644 --- a/moose/src/replicated/ops.rs +++ b/moose/src/replicated/ops.rs @@ -371,7 +371,7 @@ impl DiagOp { } impl SliceOp { - pub(crate) fn rep_kernel( + pub(crate) fn rep_shape_kernel( sess: &S, plc: &ReplicatedPlacement, slice: SliceInfo, @@ -394,6 +394,48 @@ impl SliceOp { shapes: [new_shape0, new_shape1, new_shape2], }) } + + pub(crate) fn rep_ring_kernel( + sess: &S, + plc: &ReplicatedPlacement, + info: SliceInfo, + x: RepTensor, + ) -> Result> + where + HostPlacement: PlacementSlice, + { + let (player0, player1, player2) = plc.host_placements(); + let RepTensor { + shares: [[x00, x10], [x11, x21], [x22, x02]], + } = &x; + + let z00 = player0.slice(sess, info.clone(), x00); + let z10 = player0.slice(sess, info.clone(), x10); + + let z11 = player1.slice(sess, info.clone(), x11); + let z21 = player1.slice(sess, info.clone(), x21); + + let z22 = player2.slice(sess, info.clone(), x22); + let z02 = player2.slice(sess, info, x02); + + Ok(RepTensor { + shares: [[z00, z10], [z11, z21], [z22, z02]], + }) + } + + pub(crate) fn rep_uint_kernel( + sess: &S, + rep: &ReplicatedPlacement, + info: SliceInfo, + x: RepUintTensor, + ) -> Result> + where + ReplicatedPlacement: PlacementSlice, + { + Ok(RepUintTensor { + tensor: rep.slice(sess, info, &x.tensor), + }) + } } impl ShlDimOp { diff --git a/pymoose/examples/linear-regression/linreg_test.py b/pymoose/examples/linear-regression/linreg_test.py index 2fb8398b9..a85ca11ab 100644 --- a/pymoose/examples/linear-regression/linreg_test.py +++ b/pymoose/examples/linear-regression/linreg_test.py @@ -77,7 +77,7 @@ def my_comp( # the past. For now, we've decided to implement squeeze and unsqueeze # ops instead. # But we have a feeling this issue will continue to come up! 
- bias_shape = edsl.slice(edsl.shape(X), begin=0, end=1) + bias_shape = edsl.shape(X)[0:1] bias = edsl.ones(bias_shape, dtype=edsl.float64) reshaped_bias = edsl.expand_dims(bias, 1) X_b = edsl.concatenate([reshaped_bias, X], axis=1) diff --git a/pymoose/pymoose/computation/operations.py b/pymoose/pymoose/computation/operations.py index 901aebfed..2440a9b0b 100644 --- a/pymoose/pymoose/computation/operations.py +++ b/pymoose/pymoose/computation/operations.py @@ -224,6 +224,11 @@ class SliceOperation(Operation): end: int +@dataclass +class StridedSliceOperation(Operation): + slices: Optional[Tuple[slice]] + + @dataclass class BitwiseOrOperation(Operation): pass diff --git a/pymoose/pymoose/computation/utils.py b/pymoose/pymoose/computation/utils.py index e35a7b637..ab99214d3 100644 --- a/pymoose/pymoose/computation/utils.py +++ b/pymoose/pymoose/computation/utils.py @@ -50,6 +50,7 @@ ops.SaveOperation, ops.ShapeOperation, ops.SliceOperation, + ops.StridedSliceOperation, ops.SqueezeOperation, ops.SqrtOperation, ops.SubOperation, @@ -123,6 +124,13 @@ def _encode(val): "items": val.flatten().tolist(), "shape": list(val.shape), } + elif isinstance(val, slice): + return { + "__type__": "PySlice", + "start": val.start, + "step": val.step, + "stop": val.stop, + } raise NotImplementedError(f"{type(val)}") diff --git a/pymoose/pymoose/edsl/__init__.py b/pymoose/pymoose/edsl/__init__.py index 5e268e1ea..5b29b48b4 100644 --- a/pymoose/pymoose/edsl/__init__.py +++ b/pymoose/pymoose/edsl/__init__.py @@ -49,11 +49,12 @@ from pymoose.edsl.base import save from pymoose.edsl.base import shape from pymoose.edsl.base import sigmoid -from pymoose.edsl.base import slice +from pymoose.edsl.base import sliced from pymoose.edsl.base import softmax from pymoose.edsl.base import sqrt from pymoose.edsl.base import square from pymoose.edsl.base import squeeze +from pymoose.edsl.base import strided_slice from pymoose.edsl.base import sub from pymoose.edsl.base import sum from pymoose.edsl.base import transpose @@ -85,6 +86,7 @@ float64, FloatType, host_placement, + greater, identity, index_axis, int32, @@ -107,11 +109,12 @@ reshape, ring64, save, - slice, shape, + sliced, softmax, square, squeeze, + strided_slice, sigmoid, sub, sum, diff --git a/pymoose/pymoose/edsl/base.py b/pymoose/pymoose/edsl/base.py index 2c8cbb8bf..fc7a7ab78 100644 --- a/pymoose/pymoose/edsl/base.py +++ b/pymoose/pymoose/edsl/base.py @@ -10,6 +10,11 @@ from pymoose.computation import types as ty from pymoose.computation import values +try: # post python 3.10 + from types import EllipsisType +except ImportError: + EllipsisType = type(...) + CURRENT_PLACEMENT: List = [] _NUMPY_DTYPES_MAP = { np.uint32: dtypes.uint32, @@ -102,6 +107,51 @@ class Expression: def __hash__(self): return id(self) + def __getitem__(self, slice_spec): + # TODO explicitly construe placement from + # global placement context and/or self.placement? 
+        assert isinstance(self.vtype, (ty.TensorType, ty.ShapeType, ty.AesTensorType))
+        assert isinstance(slice_spec, (slice, EllipsisType, list, tuple))
+        if isinstance(self.vtype, (ty.TensorType, ty.AesTensorType)):
+
+            # turn single entry to a list of entries
+            if isinstance(slice_spec, (slice, EllipsisType)):
+                slice_spec = (slice_spec,)
+
+            assert isinstance(slice_spec, (list, tuple))
+            slice_rewrite = []
+            for cur_slice in slice_spec:
+                assert isinstance(cur_slice, (slice, EllipsisType))
+                if isinstance(cur_slice, EllipsisType):
+                    slice_rewrite.append(slice(None, None, None))
+                elif isinstance(cur_slice, slice):
+                    slice_rewrite.append(cur_slice)
+                else:
+                    raise ValueError(
+                        "Indexing with types other than Ellipsis and slice "
+                        "is not yet supported."
+                    )
+            return strided_slice(self, slices=slice_rewrite)
+        elif isinstance(self.vtype, ty.ShapeType):
+            if isinstance(slice_spec, (tuple, list)):
+                if len(slice_spec) > 2:
+                    raise ValueError(
+                        "Indexing ShapeType requires a simple slice, including only "
+                        "`start` & `stop` slice values."
+                    )
+                begin, end = slice_spec
+                assert isinstance(begin, int) and isinstance(end, int)
+            elif isinstance(slice_spec, slice):
+                if slice_spec.step is not None:
+                    raise ValueError(
+                        "Indexing ShapeType requires a simple slice, including only "
+                        "`start` & `stop` slice values."
+                    )
+                begin, end = slice_spec.start, slice_spec.stop
+            return sliced(self, begin, end)
+        else:
+            raise IndexError(f"Expression of vtype {self.vtype} is not slice-able.")
+
 
 @dataclass
 class AddNExpression(Expression):
@@ -339,6 +389,14 @@ def __hash__(self):
         return id(self)
 
 
+@dataclass
+class StridedSliceExpression(Expression):
+    slices: Optional[Tuple[slice]]
+
+    def __hash__(self):
+        return id(self)
+
+
 @dataclass
 class LessExpression(Expression):
     def __hash__(self):
@@ -364,7 +422,7 @@ def __hash__(self):
 
 
 def add_n(arrays, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     if not isinstance(arrays, (tuple, list)):
         raise ValueError(
             "Inputs to `add_n` must be array-like, found argument "
@@ -390,12 +448,12 @@ def add_n(arrays, placement=None):
 
 
 def identity(x, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return IdentityExpression(placement=placement, inputs=[x], vtype=x.vtype)
 
 
 def concatenate(arrays, axis=0, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     if not isinstance(arrays, (tuple, list)):
         raise ValueError(
             "Inputs to `concatenate` must be array-like, found argument "
@@ -424,7 +482,7 @@ def concatenate(arrays, axis=0, placement=None):
 
 
 def maximum(arrays, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     if not isinstance(arrays, (tuple, list)):
         raise ValueError(
             "Inputs to `concatenate` must be array-like, found argument "
@@ -451,7 +509,7 @@ def maximum(arrays, placement=None):
 
 
 def decrypt(key, ciphertext, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
 
     # key expr typecheck
     if not isinstance(key.vtype, ty.AesKeyType):
@@ -477,7 +535,7 @@ def decrypt(key, ciphertext, placement=None):
 
 
 def constant(value, dtype=None, vtype=None, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
 
     if
isinstance(value, np.ndarray): @@ -516,7 +574,7 @@ def constant(value, dtype=None, vtype=None, placement=None): def add(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "add") return BinaryOpExpression( op_name="add", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -526,7 +584,7 @@ def add(lhs, rhs, placement=None): def sub(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "sub") return BinaryOpExpression( op_name="sub", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -536,7 +594,7 @@ def sub(lhs, rhs, placement=None): def mul(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "mul") return BinaryOpExpression( op_name="mul", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -546,7 +604,7 @@ def mul(lhs, rhs, placement=None): def dot(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "dot") return BinaryOpExpression( op_name="dot", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -556,7 +614,7 @@ def dot(lhs, rhs, placement=None): def div(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "div") return BinaryOpExpression( op_name="div", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -566,7 +624,7 @@ def div(lhs, rhs, placement=None): def less(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return BinaryOpExpression( op_name="less", placement=placement, @@ -590,7 +648,7 @@ def greater(lhs, rhs, placement=None): def logical_or(lhs, rhs, placement=None): assert isinstance(lhs, Expression) assert isinstance(rhs, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = _assimilate_arg_vtypes(lhs.vtype, rhs.vtype, "or") return BinaryOpExpression( op_name="or", placement=placement, inputs=[lhs, rhs], vtype=vtype @@ -599,7 +657,7 @@ def logical_or(lhs, rhs, placement=None): def inverse(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = x.vtype if not isinstance(vtype, ty.TensorType): raise ValueError( @@ -624,7 +682,7 @@ def expand_dims(x, axis, placement=None): ) elif isinstance(axis, int): axis = [axis] - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return ExpandDimsExpression( placement=placement, inputs=[x], axis=axis, vtype=x.vtype ) @@ -632,45 +690,45 @@ def expand_dims(x, axis, 
placement=None): def squeeze(x, axis=None, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return SqueezeExpression(placement=placement, inputs=[x], axis=axis, vtype=x.vtype) def ones(shape, dtype, placement=None): assert isinstance(shape, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = ty.TensorType(dtype) return OnesExpression(placement=placement, inputs=[shape], vtype=vtype) def zeros(shape, dtype, placement=None): assert isinstance(shape, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) vtype = ty.TensorType(dtype) return ZerosExpression(placement=placement, inputs=[shape], vtype=vtype) def square(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return mul(x, x, placement=placement) def sum(x, axis=None, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return SumExpression(placement=placement, inputs=[x], axis=axis, vtype=x.vtype) def mean(x, axis=None, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return MeanExpression(placement=placement, inputs=[x], axis=axis, vtype=x.vtype) def exp(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return ExpExpression(placement=placement, inputs=[x], vtype=x.vtype) @@ -682,7 +740,7 @@ def sqrt(x, placement=None): def sigmoid(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return SigmoidExpression(placement=placement, inputs=[x], vtype=x.vtype) @@ -694,7 +752,7 @@ def relu(x, placement=None): def softmax(x, axis, upmost_index, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return SoftmaxExpression( placement=placement, inputs=[x], @@ -706,7 +764,7 @@ def softmax(x, axis, upmost_index, placement=None): def argmax(x, axis, upmost_index, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return ArgmaxExpression( placement=placement, inputs=[x], @@ -718,7 +776,7 @@ def argmax(x, axis, upmost_index, placement=None): def log(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return LogExpression( placement=placement, inputs=[x], @@ -728,13 +786,13 @@ def log(x, placement=None): def log2(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return Log2Expression(placement=placement, inputs=[x], vtype=x.vtype) def shape(x, placement=None): assert isinstance(x, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) return ShapeExpression(placement=placement, inputs=[x], vtype=ty.ShapeType()) @@ -751,25 +809,39 @@ def 
index_axis(x, axis, index, placement=None):
             f"{index} of type {type(index)}"
         )
 
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return IndexAxisExpression(
         placement=placement, inputs=[x], axis=axis, index=index, vtype=x.vtype
     )
 
 
-def slice(x, begin, end, placement=None):
+def sliced(x, begin, end, placement=None):
     assert isinstance(x, Expression)
     assert isinstance(begin, int)
     assert isinstance(end, int)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return SliceExpression(
         placement=placement, inputs=[x], begin=begin, end=end, vtype=x.vtype
     )
 
 
+def strided_slice(x, slices, placement=None):
+    assert isinstance(x, Expression)
+    assert isinstance(slices, (tuple, list))
+    placement = _materialize_placement_arg(placement)
+    for s in slices:
+        if not isinstance(s, slice):
+            raise ValueError(
+                "`slices` argument must be a list/tuple of slices, found " f"{type(s)}"
+            )
+    return StridedSliceExpression(
+        placement=placement, inputs=[x], slices=slices, vtype=x.vtype
+    )
+
+
 def transpose(x, axes=None, placement=None):
     assert isinstance(x, Expression)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return TransposeExpression(
         placement=placement, inputs=[x], axes=axes, vtype=x.vtype
     )
@@ -777,7 +849,7 @@ def transpose(x, axes=None, placement=None):
 
 def atleast_2d(x, to_column_vector=False, placement=None):
     assert isinstance(x, Expression)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return AtLeast2DExpression(
         placement=placement,
         inputs=[x],
@@ -793,13 +865,13 @@ def reshape(x, shape, placement=None):
             values.ShapeConstant(value=shape), vtype=ty.ShapeType(), placement=placement
         )
     assert isinstance(shape, Expression)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return ReshapeExpression(placement=placement, inputs=[x, shape], vtype=x.vtype)
 
 
 def abs(x, placement=None):
     assert isinstance(x, Expression)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     return AbsExpression(placement=placement, inputs=[x], vtype=x.vtype)
 
 
@@ -813,7 +885,7 @@ def mux(selector, x, y, placement=None):
     assert isinstance(y, Expression)
     assert isinstance(y.vtype, ty.TensorType), y.vtype
     assert y.vtype.dtype.is_fixedpoint, y.vtype.dtype
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     assert isinstance(placement, ReplicatedPlacementExpression)
     vtype = _assimilate_arg_vtypes(x.vtype, y.vtype, "mux")
     return MuxExpression(placement=placement, inputs=[selector, x, y], vtype=vtype)
@@ -821,7 +893,7 @@ def mux(selector, x, y, placement=None):
 
 def cast(x, dtype, placement=None):
     assert isinstance(x, Expression)
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
 
     if not isinstance(x.vtype, ty.TensorType):
         raise ValueError(
@@ -860,7 +932,7 @@ def cast(x, dtype, placement=None):
 
 
 def load(key, query="", dtype=None, vtype=None, placement=None):
-    placement = placement or get_current_placement()
+    placement = _materialize_placement_arg(placement)
     vtype = _maybe_lift_dtype_to_tensor_vtype(dtype, vtype)
     if isinstance(key, str):
         key = constant(key, placement=placement, vtype=ty.StringType())
@@ -893,7 +965,7 @@ def load(key, query="", dtype=None, vtype=None, placement=None):
 
 def save(key, value, placement=None):
assert isinstance(value, Expression) - placement = placement or get_current_placement() + placement = _materialize_placement_arg(placement) if isinstance(key, str): key = constant(key, placement=placement, vtype=ty.StringType()) elif isinstance(key, Argument) and key.vtype not in [ty.StringType(), None]: @@ -918,6 +990,28 @@ def __init__(self, func): self.func = func +def _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name): + lhs_dtype = lhs_vtype.dtype + rhs_dtype = rhs_vtype.dtype + if lhs_dtype != rhs_dtype: + raise ValueError( + f"Function `{fn_name}` expected arguments of similar dtype: " + f"found mismatched dtypes `{lhs_dtype}` and `{rhs_dtype}`." + ) + return lhs_vtype + + +def _assimilate_arg_vtypes(lhs_vtype, rhs_vtype, fn_name): + if isinstance(lhs_vtype, ty.TensorType) and isinstance(rhs_vtype, ty.TensorType): + return _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name) + if lhs_vtype != rhs_vtype: + raise ValueError( + f"Function `{fn_name}` expected arguments of similar type: " + f"found mismatched types `{lhs_vtype}` and `{rhs_vtype}`." + ) + return lhs_vtype + + def _check_tensor_type_arg_consistency(dtype, vtype): if isinstance(vtype, ty.TensorType) and vtype.dtype != dtype: raise ValueError( @@ -926,6 +1020,12 @@ def _check_tensor_type_arg_consistency(dtype, vtype): ) +def _materialize_placement_arg(plc): + plc = plc or get_current_placement() + assert isinstance(plc, PlacementExpression) + return plc + + def _maybe_lift_dtype_to_tensor_vtype(dtype, vtype): if dtype is None and vtype is None: return @@ -956,25 +1056,3 @@ def _interpret_numeric_value(value, vtype, fallback_vtype): "Cannot interpret numeric constant as non-numeric type {vtype}." ) return value, vtype - - -def _assimilate_arg_vtypes(lhs_vtype, rhs_vtype, fn_name): - if isinstance(lhs_vtype, ty.TensorType) and isinstance(rhs_vtype, ty.TensorType): - return _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name) - if lhs_vtype != rhs_vtype: - raise ValueError( - f"Function `{fn_name}` expected arguments of similar type: " - f"found mismatched types `{lhs_vtype}` and `{rhs_vtype}`." - ) - return lhs_vtype - - -def _assimilate_arg_dtypes(lhs_vtype, rhs_vtype, fn_name): - lhs_dtype = lhs_vtype.dtype - rhs_dtype = rhs_vtype.dtype - if lhs_dtype != rhs_dtype: - raise ValueError( - f"Function `{fn_name}` expected arguments of similar dtype: " - f"found mismatched dtypes `{lhs_dtype}` and `{rhs_dtype}`." 
- ) - return lhs_vtype diff --git a/pymoose/pymoose/edsl/tracer.py b/pymoose/pymoose/edsl/tracer.py index 0222ce35c..78381df10 100644 --- a/pymoose/pymoose/edsl/tracer.py +++ b/pymoose/pymoose/edsl/tracer.py @@ -680,6 +680,24 @@ def visit_SliceExpression(self, slice_expression): ) ) + def visit_StridedSliceExpression(self, slice_expression): + assert isinstance(slice_expression, expr.StridedSliceExpression) + (x_expression,) = slice_expression.inputs + x_operation = self.visit(x_expression) + placement = self.visit_placement_expression(slice_expression.placement) + return self.computation.add_operation( + ops.StridedSliceOperation( + placement_name=placement.name, + name=self.get_fresh_name("strided_slice"), + inputs={"x": x_operation.name}, + slices=slice_expression.slices, + signature=ops.OpSignature( + input_types={"x": x_operation.return_type}, + return_type=slice_expression.vtype, + ), + ) + ) + def visit_ShapeExpression(self, shape_expression): assert isinstance(shape_expression, expr.ShapeExpression) (x_expression,) = shape_expression.inputs diff --git a/pymoose/pymoose/predictors/linear_predictor.py b/pymoose/pymoose/predictors/linear_predictor.py index 9f51880ed..451ea362d 100644 --- a/pymoose/pymoose/predictors/linear_predictor.py +++ b/pymoose/pymoose/predictors/linear_predictor.py @@ -30,9 +30,7 @@ def post_transform(self, y): @classmethod def bias_trick(cls, x, plc, dtype): - bias_shape = edsl.slice( - edsl.shape(x, placement=plc), begin=0, end=1, placement=plc - ) + bias_shape = edsl.shape(x, placement=plc)[0:1] bias = edsl.ones(bias_shape, dtype=edsl.float64, placement=plc) reshaped_bias = edsl.expand_dims(bias, 1, placement=plc) return edsl.cast(reshaped_bias, dtype=dtype, placement=plc) diff --git a/pymoose/rust_integration_tests/slicing_test.py b/pymoose/rust_integration_tests/slicing_test.py new file mode 100644 index 000000000..a9eeebc7e --- /dev/null +++ b/pymoose/rust_integration_tests/slicing_test.py @@ -0,0 +1,187 @@ +import argparse +import logging + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from pymoose import edsl +from pymoose.computation import types as ty +from pymoose.logger import get_logger +from pymoose.testing import LocalMooseRuntime + + +def compile_and_run(traced_slice_comp, x_arg): + storage = { + "alice": {}, + "carole": {}, + "bob": {"x_arg": x_arg}, + } + + runtime = LocalMooseRuntime(storage_mapping=storage) + _ = runtime.evaluate_computation( + computation=traced_slice_comp, + role_assignment={"alice": "alice", "bob": "bob", "carole": "carole"}, + arguments={"x_uri": "x_arg"}, + ) + + x_sliced = runtime.read_value_from_storage("bob", "sliced") + return x_sliced + + +class SliceExample(parameterized.TestCase): + def _setup_comp(self, slice_spec, to_dtype): + bob = edsl.host_placement(name="bob") + + @edsl.computation + def my_comp( + x_uri: edsl.Argument(placement=bob, vtype=ty.StringType()), + ): + with bob: + x = edsl.load(x_uri, dtype=to_dtype)[slice_spec] + res = (edsl.save("sliced", x),) + + return res + + return my_comp + + @parameterized.parameters( + ( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + edsl.float32, + (slice(1, None, None), slice(1, None, None), slice(1, None, None)), + ), + ( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + edsl.float64, + (slice(None, None, None), slice(None, None, None), slice(1, None, None)), + ), + ( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + edsl.uint64, + (slice(None, None, None), slice(None, None, None), slice(1, 
None, None)),
+        ),
+    )
+    def test_slice_types_execute(self, x, to_dtype, slice_spec):
+        comp = self._setup_comp(slice_spec, to_dtype)
+        traced_slice_comp = edsl.trace(comp)
+
+        x_arg = np.array(x, dtype=to_dtype.numpy_dtype)
+        x_from_runtime = compile_and_run(traced_slice_comp, x_arg)
+
+        expected_npy = x_arg[slice_spec]
+        np.testing.assert_equal(x_from_runtime, expected_npy)
+
+    def test_basic(self):
+        def setup_basic_comp():
+            bob = edsl.host_placement(name="bob")
+
+            @edsl.computation
+            def my_comp(
+                x_uri: edsl.Argument(placement=bob, vtype=ty.StringType()),
+            ):
+                with bob:
+                    x = edsl.load(x_uri, dtype=edsl.float64)[1:, 1:2]
+                    res = (edsl.save("sliced", x),)
+                return res
+
+            return my_comp
+
+        comp = setup_basic_comp()
+        traced_slice_comp = edsl.trace(comp)
+        x_arg = np.array(
+            [[1, 23.0, 321, 30.321, 321], [32.0, 321, 5, 3.0, 32.0]], dtype=np.float64
+        )
+        x_from_runtime = compile_and_run(traced_slice_comp, x_arg)
+        np.testing.assert_equal(x_from_runtime, x_arg[1:, 1:2])
+
+    def test_basic_colons(self):
+        def setup_basic_comp():
+            bob = edsl.host_placement(name="bob")
+
+            @edsl.computation
+            def my_comp(
+                x_uri: edsl.Argument(placement=bob, vtype=ty.StringType()),
+            ):
+                with bob:
+                    x = edsl.load(x_uri, dtype=edsl.float64)[:, 2:4]
+                    res = (edsl.save("sliced", x),)
+                return res
+
+            return my_comp
+
+        comp = setup_basic_comp()
+        traced_slice_comp = edsl.trace(comp)
+        x_arg = np.array(
+            [[1, 23.0, 321, 30.321, 321], [32.0, 321, 5, 3.0, 32.0]], dtype=np.float64
+        )
+        x_from_runtime = compile_and_run(traced_slice_comp, x_arg)
+        np.testing.assert_equal(x_from_runtime, x_arg[:, 2:4])
+
+    def test_rep_basic(self):
+        def setup_basic_comp():
+            alice = edsl.host_placement(name="alice")
+            bob = edsl.host_placement(name="bob")
+            carole = edsl.host_placement(name="carole")
+            rep = edsl.replicated_placement(name="rep", players=[alice, bob, carole])
+
+            @edsl.computation
+            def my_comp(
+                x_uri: edsl.Argument(placement=bob, vtype=ty.StringType()),
+            ):
+                with bob:
+                    x = edsl.load(x_uri, dtype=edsl.float64)
+                    x_fixed = edsl.cast(x, dtype=edsl.fixed(8, 27))
+
+                with rep:
+                    x_sliced_rep = x_fixed[1:, 1:2]
+
+                with bob:
+                    x_sliced_host = edsl.cast(x_sliced_rep, dtype=edsl.float64)
+                    res = edsl.save("sliced", x_sliced_host)
+                return res
+
+            return my_comp
+
+        comp = setup_basic_comp()
+        traced_slice_comp = edsl.trace(comp)
+        x_arg = np.array(
+            [[1, 23.0, 321, 30.321, 321], [32.0, 321, 5, 3.0, 32.0]], dtype=np.float64
+        )
+        x_from_runtime = compile_and_run(traced_slice_comp, x_arg)
+        np.testing.assert_equal(x_from_runtime, x_arg[1:, 1:2])
+
+    def test_shape_slice(self):
+        alice = edsl.host_placement("alice")
+
+        @edsl.computation
+        def my_comp(x: edsl.Argument(alice, edsl.float64)):
+            with alice:
+                x_shape = edsl.shape(x)
+                sliced_shape = x_shape[1:3]
+                ones_res = edsl.ones(sliced_shape, edsl.float64)
+                res = edsl.save("ones", ones_res)
+            return res
+
+        traced_slice_comp = edsl.trace(my_comp)
+        x_arg = np.ones([4, 3, 5], dtype=np.float64)
+        runtime = LocalMooseRuntime(storage_mapping={"alice": {}})
+        _ = runtime.evaluate_computation(
+            computation=traced_slice_comp,
+            role_assignment={"alice": "alice"},
+            arguments={"x": x_arg},
+        )
+
+        y_ones = runtime.read_value_from_storage("alice", "ones")
+        assert y_ones.shape == (3, 5)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="slicing example")
+    parser.add_argument("--verbose", action="store_true")
+    args = parser.parse_args()
+
+    if args.verbose:
+        get_logger().setLevel(level=logging.DEBUG)
+
+    absltest.main()
diff
--git a/pymoose/src/computation.rs b/pymoose/src/computation.rs index bfe4ce9d1..de9cd1544 100644 --- a/pymoose/src/computation.rs +++ b/pymoose/src/computation.rs @@ -31,6 +31,7 @@ enum PyOperation { ShapeOperation(PyShapeOperation), IndexAxisOperation(PyIndexAxisOperation), SliceOperation(PySliceOperation), + StridedSliceOperation(PyStridedSliceOperation), OnesOperation(PyOnesOperation), ZerosOperation(PyZerosOperation), ConcatenateOperation(PyConcatenateOperation), @@ -118,6 +119,13 @@ struct PyOpSignature { return_type: PyValueType, } +#[derive(Deserialize, Debug)] +struct PySlice { + start: Option, + step: Option, + stop: Option, +} + #[derive(Deserialize, Debug)] struct PyAbsOperation { name: String, @@ -253,6 +261,15 @@ struct PySliceOperation { end: u32, } +#[derive(Deserialize, Debug)] +struct PyStridedSliceOperation { + name: String, + inputs: Inputs, + placement_name: String, + signature: PyOpSignature, + slices: Vec, +} + #[derive(Deserialize, Debug)] struct PyOnesOperation { name: String, @@ -979,6 +996,33 @@ impl TryFrom for Computation { name: op.name.clone(), placement: map_placement(&placements, &op.placement_name)?, }), + StridedSliceOperation(op) => { + let slices: Vec<_> = op + .slices + .iter() + .map(|slice| SliceInfoElem { + start: slice.start.unwrap_or(0), + step: slice.step, + end: slice.stop, + }) + .collect(); + Ok(Operation { + kind: SliceOp { + sig: map_signature( + &op.signature, + &placements, + &op.placement_name, + &["x"], + )?, + slice: SliceInfo(slices), + } + .into(), + inputs: map_inputs(&op.inputs, &["x"]) + .with_context(|| format!("Failed at op {:?}", op))?, + name: op.name.clone(), + placement: map_placement(&placements, &op.placement_name)?, + }) + } OnesOperation(op) => Ok(Operation { kind: OnesOp { sig: map_signature(
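
Usage sketch (not part of the patch): the snippet below summarizes the Python surface this change adds, distilled from the slicing_test.py file introduced above. Running it assumes the same LocalMooseRuntime/storage setup as that test's compile_and_run helper; the placement name "bob" and the storage keys are test fixtures, not library requirements.

    import numpy as np

    from pymoose import edsl
    from pymoose.computation import types as ty

    bob = edsl.host_placement(name="bob")

    @edsl.computation
    def slice_comp(x_uri: edsl.Argument(placement=bob, vtype=ty.StringType())):
        with bob:
            x = edsl.load(x_uri, dtype=edsl.float64)
            # numpy-style indexing on an Expression lowers to a StridedSliceOperation
            y = x[1:, 1:2]
            # equivalent explicit form of the same slice
            y_explicit = edsl.strided_slice(x, slices=[slice(1, None), slice(1, 2)])
            # ShapeType only supports simple begin/end slicing, which lowers to
            # the begin/end SliceOperation (edsl.sliced)
            head_shape = edsl.shape(x)[0:1]
            ones = edsl.ones(head_shape, dtype=edsl.float64)
            res = (edsl.save("sliced", y), edsl.save("ones", ones))
        return res

The same __getitem__ syntax works on fixed-point tensors under a replicated placement (see test_rep_basic above); in that case SliceOp lowers to slicing each replicated share on its host placement, as in rep_ring_kernel.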