Skip to content

Commit

Permalink
Better slicing (#1059)
Browse files Browse the repository at this point in the history
* added better slice draft

* added slicing at the logical level

* first slice example works

* from sliced to slice again

* rename better_slice to strided_slice

* small cleanup of edsl/base.py

* first stab at pythonic/numpy-style slicing of Expressions

* replaced old slice with strided slice

* added slice for host fixed point at the logical layer

* added boolean kernels, but still need to do the implementation on the bit array level

* added rep kernels

* added slice support for uint64, added few missing kernels for loading constants

* more tests for slicing

* add ShapeType handling in Expression.__getitem__

* added missing kernel

* linting

* fixed flake8 issues

Co-authored-by: jvmncs <[email protected]>
Co-authored-by: jvmncs <[email protected]>
  • Loading branch information
3 people authored Apr 28, 2022
1 parent 92abffb commit 6b6b89e
Show file tree
Hide file tree
Showing 19 changed files with 892 additions and 89 deletions.
56 changes: 55 additions & 1 deletion moose/src/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::computation::*;
use crate::error::{Error, Result};
use crate::execution::Session;
use crate::floatingpoint::FloatTensor;
use crate::host::HostPlacement;
use crate::host::{HostPlacement, SliceInfo};
use crate::integer::AbstractUint64Tensor;
use crate::kernels::*;
use crate::replicated::ReplicatedPlacement;
Expand Down Expand Up @@ -281,6 +281,60 @@ impl IndexAxisOp {
}
}

impl SliceOp {
pub(crate) fn bool_rep_kernel<S: Session, HostT, RepT>(
sess: &S,
plc: &ReplicatedPlacement,
info: SliceInfo,
x: BoolTensor<HostT, RepT>,
) -> Result<BoolTensor<HostT, RepT>>
where
ReplicatedPlacement: PlacementSlice<S, RepT, RepT>,
ReplicatedPlacement: PlacementShare<S, HostT, RepT>,
{
let x = match x {
BoolTensor::Host(v) => plc.share(sess, &v),
BoolTensor::Replicated(v) => v,
};
let z = plc.slice(sess, info, &x);
Ok(BoolTensor::Replicated(z))
}

pub(crate) fn bool_host_kernel<S: Session, HostT, RepT>(
sess: &S,
plc: &HostPlacement,
info: SliceInfo,
x: BoolTensor<HostT, RepT>,
) -> Result<BoolTensor<HostT, RepT>>
where
HostPlacement: PlacementSlice<S, HostT, HostT>,
HostPlacement: PlacementReveal<S, RepT, HostT>,
{
let x = match x {
BoolTensor::Replicated(v) => plc.reveal(sess, &v),
BoolTensor::Host(v) => v,
};
let z = plc.slice(sess, info, &x);
Ok(BoolTensor::Host(z))
}
}

impl LoadOp {
pub(crate) fn bool_kernel<S: Session, HostT, RepT>(
sess: &S,
plc: &HostPlacement,
key: m!(HostString),
query: m!(HostString),
) -> Result<BoolTensor<HostT, RepT>>
where
HostString: KnownType<S>,
HostPlacement: PlacementLoad<S, m!(HostString), m!(HostString), HostT>,
{
let z = plc.load(sess, &key, &query);
Ok(BoolTensor::Host(z))
}
}

impl SaveOp {
pub(crate) fn bool_kernel<S: Session, HostT, RepT>(
sess: &S,
Expand Down
59 changes: 59 additions & 0 deletions moose/src/fixedpoint/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,65 @@ impl IndexAxisOp {
}
}

impl SliceOp {
pub(crate) fn fixed_host_kernel<S: Session, HostFixedT, MirFixedT, RepFixedT>(
sess: &S,
plc: &HostPlacement,
info: SliceInfo,
x: FixedTensor<HostFixedT, MirFixedT, RepFixedT>,
) -> Result<FixedTensor<HostFixedT, MirFixedT, RepFixedT>>
where
HostPlacement: PlacementReveal<S, RepFixedT, HostFixedT>,
HostPlacement: PlacementDemirror<S, MirFixedT, HostFixedT>,
HostPlacement: PlacementSlice<S, HostFixedT, HostFixedT>,
{
let x = match x {
FixedTensor::Replicated(v) => plc.reveal(sess, &v),
FixedTensor::Mirrored3(v) => plc.demirror(sess, &v),
FixedTensor::Host(v) => v,
};
let z = plc.slice(sess, info, &x);
Ok(FixedTensor::Host(z))
}

pub(crate) fn fixed_rep_kernel<S: Session, HostFixedT, MirFixedT, RepFixedT>(
sess: &S,
plc: &ReplicatedPlacement,
info: SliceInfo,
x: FixedTensor<HostFixedT, MirFixedT, RepFixedT>,
) -> Result<FixedTensor<HostFixedT, MirFixedT, RepFixedT>>
where
ReplicatedPlacement: PlacementShare<S, HostFixedT, RepFixedT>,
ReplicatedPlacement: PlacementShare<S, MirFixedT, RepFixedT>,
ReplicatedPlacement: PlacementSlice<S, RepFixedT, RepFixedT>,
{
let x = match x {
FixedTensor::Host(v) => plc.share(sess, &v),
FixedTensor::Mirrored3(v) => plc.share(sess, &v),
FixedTensor::Replicated(v) => v,
};
let z = plc.slice(sess, info, &x);
Ok(FixedTensor::Replicated(z))
}

pub(crate) fn repfixed_kernel<S: Session, RepRingT>(
sess: &S,
plc: &ReplicatedPlacement,
info: SliceInfo,
x: RepFixedTensor<RepRingT>,
) -> Result<RepFixedTensor<RepRingT>>
where
ReplicatedPlacement: PlacementSlice<S, RepRingT, RepRingT>,
{
let y = plc.slice(sess, info, &x.tensor);
Ok(RepFixedTensor {
tensor: y,
fractional_precision: x.fractional_precision,
integral_precision: x.integral_precision,
})
}
}

impl ShapeOp {
pub(crate) fn host_fixed_kernel<S: Session, HostFixedT, MirFixedT, RepFixedT, HostShapeT>(
sess: &S,
Expand Down
25 changes: 24 additions & 1 deletion moose/src/floatingpoint/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::computation::*;
use crate::error::Error;
use crate::error::Result;
use crate::execution::Session;
use crate::host::HostPlacement;
use crate::host::{HostPlacement, SliceInfo};
use crate::kernels::*;
use crate::mirrored::{Mir3Tensor, Mirrored3Placement};
use crate::types::*;
Expand Down Expand Up @@ -707,3 +707,26 @@ impl MuxOp {
Ok(FloatTensor::Host(z))
}
}

impl SliceOp {
pub(crate) fn float_host_kernel<S: Session, HostFloatT, MirroredT>(
sess: &S,
plc: &HostPlacement,
slice: SliceInfo,
x: FloatTensor<HostFloatT, MirroredT>,
) -> Result<FloatTensor<HostFloatT, MirroredT>>
where
HostPlacement: PlacementSlice<S, HostFloatT, HostFloatT>,
{
let x = match x {
FloatTensor::Host(v) => v,
FloatTensor::Mirrored3(_v) => {
return Err(Error::UnimplementedOperator(
"SliceOp @ Mirrored3Placement".to_string(),
))
}
};
let z = plc.slice(sess, slice, &x);
Ok(FloatTensor::Host(z))
}
}
6 changes: 6 additions & 0 deletions moose/src/host/bitarray.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::host::RawShape;
use crate::host::SliceInfo;
use anyhow::anyhow;
use bitvec::prelude::*;
use ndarray::{prelude::*, RemoveAxis};
Expand Down Expand Up @@ -125,6 +126,11 @@ impl BitArrayRepr {
dim: Arc::new(IxDyn(&[len])),
}
}

pub(crate) fn slice(&self, _info: SliceInfo) -> anyhow::Result<BitArrayRepr> {
// TODO(Dragos) Implement this in future
Err(anyhow::anyhow!("slicing not implemented for BitArray yet"))
}
}

impl std::ops::BitXor for &BitArrayRepr {
Expand Down
112 changes: 100 additions & 12 deletions moose/src/host/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,31 +473,71 @@ impl AtLeast2DOp {
}

impl SliceOp {
pub(crate) fn host_kernel<S: RuntimeSession, T>(
_sess: &S,
pub(crate) fn host_fixed_kernel<S: Session, HostRingT>(
sess: &S,
plc: &HostPlacement,
info: SliceInfo,
x: HostFixedTensor<HostRingT>,
) -> Result<HostFixedTensor<HostRingT>>
where
HostPlacement: PlacementSlice<S, HostRingT, HostRingT>,
{
let tensor = plc.slice(sess, info, &x.tensor);
Ok(HostFixedTensor::<HostRingT> {
tensor,
fractional_precision: x.fractional_precision,
integral_precision: x.integral_precision,
})
}

pub(crate) fn host_bit_kernel<S: RuntimeSession>(
sess: &S,
plc: &HostPlacement,
slice_info: SliceInfo,
info: SliceInfo,
x: HostBitTensor,
) -> Result<HostBitTensor>
where
HostPlacement: PlacementPlace<S, HostBitTensor>,
{
let x = plc.place(sess, x);
x.slice(info)
}

pub(crate) fn host_generic_kernel<S: RuntimeSession, T: Clone>(
sess: &S,
plc: &HostPlacement,
info: SliceInfo,
x: HostTensor<T>,
) -> Result<HostTensor<T>>
where
HostPlacement: PlacementPlace<S, HostTensor<T>>,
{
let x = plc.place(sess, x);
x.slice(info)
}

pub(crate) fn host_ring_kernel<S: RuntimeSession, T>(
sess: &S,
plc: &HostPlacement,
info: SliceInfo,
x: HostRingTensor<T>,
) -> Result<HostRingTensor<T>>
where
T: Clone,
HostPlacement: PlacementPlace<S, HostRingTensor<T>>,
{
let slice_info =
ndarray::SliceInfo::<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn>::from(slice_info);
let sliced = x.0.slice(slice_info).to_owned();
Ok(HostRingTensor(sliced.to_shared(), plc.clone()))
let x = plc.place(sess, x);
x.slice(info)
}

pub(crate) fn shape_kernel<S: RuntimeSession>(
_sess: &S,
plc: &HostPlacement,
slice_info: SliceInfo,
info: SliceInfo,
x: HostShape,
) -> Result<HostShape> {
let slice = x.0.slice(
slice_info.0[0].start as usize,
slice_info.0[0].end.unwrap() as usize,
);
let slice =
x.0.slice(info.0[0].start as usize, info.0[0].end.unwrap() as usize);
Ok(HostShape(slice, plc.clone()))
}
}
Expand Down Expand Up @@ -559,6 +599,21 @@ impl<T: LinalgScalar> HostTensor<T> {
}
}

impl<T: Clone> HostTensor<T> {
fn slice(&self, info: SliceInfo) -> Result<HostTensor<T>> {
if info.0.len() != self.0.ndim() {
return Err(Error::InvalidArgument(format!(
"The input dimension of `info` must match the array to be sliced. Used slice info dim {}, tensor had dim {}",
info.0.len(),
self.0.ndim()
)));
}
let info = ndarray::SliceInfo::<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn>::from(info);
let result = self.0.slice(info);
Ok(HostTensor(result.to_owned().into_shared(), self.1.clone()))
}
}

impl<T: Clone> HostRingTensor<T> {
fn index_axis(self, axis: usize, index: usize) -> Result<HostRingTensor<T>> {
if axis >= self.0.ndim() {
Expand All @@ -581,6 +636,24 @@ impl<T: Clone> HostRingTensor<T> {
}
}

impl<T: Clone> HostRingTensor<T> {
fn slice(&self, info: SliceInfo) -> Result<HostRingTensor<T>> {
if info.0.len() != self.0.ndim() {
return Err(Error::InvalidArgument(format!(
"The input dimension of `info` must match the array to be sliced. Used slice info dim {}, tensor had dim {}",
info.0.len(),
self.0.ndim()
)));
}
let info = ndarray::SliceInfo::<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn>::from(info);
let result = self.0.slice(info);
Ok(HostRingTensor(
result.to_owned().into_shared(),
self.1.clone(),
))
}
}

impl HostBitTensor {
fn index_axis(self, axis: usize, index: usize) -> Result<HostBitTensor> {
if axis >= self.0.ndim() {
Expand All @@ -600,6 +673,21 @@ impl HostBitTensor {
let result = self.0.index_axis(axis, index);
Ok(HostBitTensor(result, self.1))
}

fn slice(&self, info: SliceInfo) -> Result<HostBitTensor> {
if info.0.len() != self.0.ndim() {
return Err(Error::InvalidArgument(format!(
"The input dimension of `info` must match the array to be sliced. Used slice info dim {}, tensor had dim {}",
info.0.len(),
self.0.ndim()
)));
}
let result = self
.0
.slice(info)
.map_err(|e| Error::KernelError(e.to_string()))?;
Ok(HostBitTensor(result, self.1.clone()))
}
}

impl IndexAxisOp {
Expand Down
Loading

0 comments on commit 6b6b89e

Please sign in to comment.