From 215632d063a72b42e2734143600ab7c9f8a69a39 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Wed, 28 Feb 2024 12:56:53 +0100 Subject: [PATCH] added select and zeros like ops --- src/core/graph/autodiff.rs | 37 +++++++++++++++++++++++++++++++++---- src/core/graph/compile.rs | 34 ++++++++++++++++++++++++++++++++++ src/core/graph/context.rs | 35 ++++++++++++++++++++++++++++++----- src/core/graph/math.rs | 28 ++++++++++++++++++++++++++++ src/core/graph/operation.rs | 4 ++++ src/core/graph/tests.rs | 21 +++++++++++++++++++++ 6 files changed, 150 insertions(+), 9 deletions(-) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index d49ee4a..1c7297a 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -50,14 +50,18 @@ impl Context { self.autodiff(b, modification_limit - (r as usize)) .map(|v| v || r) } - Operation::Select{ pred: _, on_true, on_false } => { + Operation::Select { + pred: _, + on_true, + on_false, + } => { let r = self.autodiff(on_true, modification_limit)?; self.autodiff(on_false, modification_limit - (r as usize)) .map(|v| v || r) } - Operation::TypeCast(node, ty) => { - self.autodiff(node, modification_limit) - } + Operation::TypeCast(node, ty) => self.autodiff(node, modification_limit), + Operation::SliceInDim { node, .. } => self.autodiff(node, modification_limit), + Operation::ZerosLike(node) => self.autodiff(node, modification_limit), // finally a Diff node, lets distribute it Operation::Diff(outer, outer_param) => { let outer_node = &self.nodes[outer]; @@ -190,6 +194,31 @@ impl Context { // rerun autodiff on the node we replaced self.autodiff(input, modification_limit - 1) } + + /* + Operation::SliceInDim { node, start, stop, stride, dim } => { + let diff_node = Node { + callsite: input_node.callsite.clone(), + shape: self.nodes[node].shape.clone(), + operation: Operation::Diff(node, outer_param), + dtype: self.nodes[node].dtype + }; + self.nodes[input].callsite = outer_node.callsite.clone(); + let diff_node = self.nodes.insert(diff_node); + let zero_node = self.nodes.insert(Node { + callsite: input_node.callsite.clone(), + shape: node.shape.clone(), + operation: Operation::ZerosLike(node) + }); + } + */ + Operation::SliceInDim { node, start, stop, stride, dim } => panic!("Differentiating SliceInDim not yet supported, xla-rs must implement scatter operation."), + + Operation::ZerosLike(node) => { + self.nodes[input].operation = Operation::ZerosLike(node); + self.autodiff(input, modification_limit - 1) + } + Operation::Diff(inner, _) => { // derivative of a derivative, apply the inner one first then try again on the outer. let r = self.autodiff(inner, modification_limit)?; diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index b3702b4..773ebd4 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -71,6 +71,13 @@ impl Context { .push(this_node_id); self.get_dependent_nodes(node, dep_nodes, constants, parameters) } + Operation::SliceInDim{ node, .. } => { + dep_nodes + .entry(node) + .or_insert(Vec::new()) + .push(this_node_id); + self.get_dependent_nodes(node, dep_nodes, constants, parameters) + } Operation::Select { pred, on_true, @@ -92,6 +99,13 @@ impl Context { self.get_dependent_nodes(on_true, dep_nodes, constants, parameters)?; self.get_dependent_nodes(on_false, dep_nodes, constants, parameters) } + Operation::ZerosLike(node) => { + dep_nodes + .entry(node) + .or_insert(Vec::new()) + .push(this_node_id); + self.get_dependent_nodes(node, dep_nodes, constants, parameters) + } } } @@ -338,6 +352,26 @@ impl Context { covered_ops.insert(*dependent_op); } } + Operation::SliceInDim{ node, start, stop, stride, dim } => { + if xla_op_slotmap.contains_key(unda_xla_map[&node]) { + let xla_op = + xla_op_slotmap[unda_xla_map[&node]].slice_in_dim(start, stop, stride, dim)?; + let xla_id = xla_op_slotmap.insert(xla_op); + unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } + Operation::ZerosLike(node) => { + if xla_op_slotmap.contains_key(unda_xla_map[&node]) { + let xla_op = + xla_op_slotmap[unda_xla_map[&node]].zeros_like()?; + let xla_id = xla_op_slotmap.insert(xla_op); + unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } } } } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index de34ada..a2b3360 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -69,11 +69,25 @@ impl Context { Operation::Diff(a, b) => format!("Diff ({}) {}", self.to_string(a), self.to_string(b)), Operation::Add(a, b) => format!("Add ({}) ({})", self.to_string(a), self.to_string(b)), Operation::Mul(a, b) => format!("Mul ({}) ({})", self.to_string(a), self.to_string(b)), - Operation::Equal(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)), - Operation::LessThan(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)), - Operation::GreaterThan(a, b) => format!("GreaterThan ({}) ({})", self.to_string(a), self.to_string(b)), - Operation::LessThanEq(a, b) => format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b)), - Operation::GreaterThanEq(a, b) => format!("GreaterThanEq ({}) ({})", self.to_string(a), self.to_string(b)), + Operation::Equal(a, b) => { + format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)) + } + Operation::LessThan(a, b) => { + format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)) + } + Operation::GreaterThan(a, b) => format!( + "GreaterThan ({}) ({})", + self.to_string(a), + self.to_string(b) + ), + Operation::LessThanEq(a, b) => { + format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b)) + } + Operation::GreaterThanEq(a, b) => format!( + "GreaterThanEq ({}) ({})", + self.to_string(a), + self.to_string(b) + ), Operation::Select { pred, on_true, @@ -85,6 +99,17 @@ impl Context { self.to_string(on_false) ), Operation::TypeCast(a, ty) => format!("TypeCast {} {}", self.to_string(a), ty), + Operation::SliceInDim { + node, + start, + stop, + stride, + dim, + } => format!( + "SliceInDim {} {} {} {} {}", + self.to_string(node), start, stop, stride, dim + ), + Operation::ZerosLike(node) => format!("ZerosLike {}", self.to_string(node)) } } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 7f9fc72..a03d37b 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -306,4 +306,32 @@ impl Context { dtype: dtype }) } + + pub fn slice_in_dim + Copy>(&mut self, a: A, start: i64, stop: i64, stride: i64, dim: i64) { + let a = a.into(); + let mut s = Shape::new(); + for d in (0..self.nodes[a].shape.ndims()).rev() { + if d as i64 == dim { + s.sizes.push(1) + } else { + s.sizes.push(self.nodes[a].shape.sizes[d]) + } + } + self.nodes.insert(Node { + callsite: callsite!(1), + shape: s, + operation: Operation::SliceInDim { node: a, start: start, stop: stop, stride: stride, dim: dim }, + dtype: self.nodes[a].dtype + }); + } + + pub fn zeros_like + Copy>(&mut self, a: A) { + let a = a.into(); + self.nodes.insert(Node { + callsite: callsite!(1), + shape: self.nodes[a].shape.clone(), + operation: Operation::ZerosLike(a), + dtype: self.nodes[a].dtype + }); + } } diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index c6b20bf..d2bbbea 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -20,6 +20,10 @@ pub enum Operation { Select{ pred: NodeIdentifier, on_true: NodeIdentifier, on_false: NodeIdentifier }, TypeCast(NodeIdentifier, xla::ElementType), + + SliceInDim{ node: NodeIdentifier, start: i64, stop: i64, stride: i64, dim: i64 }, + + ZerosLike(NodeIdentifier), } impl Display for Operation { diff --git a/src/core/graph/tests.rs b/src/core/graph/tests.rs index b4bfd7c..3edf6dc 100644 --- a/src/core/graph/tests.rs +++ b/src/core/graph/tests.rs @@ -177,4 +177,25 @@ mod tests { println!("{:?}", rust_result); assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]); } + + #[test] + fn test_slice_in_dim() { + let mut ctx = Context::new(); + + let test_const = ctx.const_from_npy("test2.npy").expect("test_const"); + let relu = ctx.relu(test_const).expect("relu"); + + let client = xla::PjRtClient::gpu(0.7, false).expect("client"); + let name = "test"; + let executable = ctx.compile(&name, [relu], &client).expect("executable"); + + let device_result = executable.execute::(&[]).expect("execute"); + let host_result = device_result[0][0] + .to_literal_sync() + .expect("to_literal_sync"); + let untupled_result = host_result.to_tuple1().expect("untuple"); + let rust_result = untupled_result.to_vec::().expect("to_vec"); + println!("{:?}", rust_result); + assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]); + } }