From 7ba42703fccf767ba711e3fba4f064d757a9c950 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Tue, 27 Feb 2024 14:50:34 +0100 Subject: [PATCH] added equality checking and type cast node --- src/core/graph/autodiff.rs | 8 ++++++- src/core/graph/compile.rs | 32 +++++++++++++++++++++++++ src/core/graph/context.rs | 2 ++ src/core/graph/math.rs | 47 +++++++++++++++++++++++++++++++++++++ src/core/graph/operation.rs | 3 +++ 5 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 0e55a73..d49ee4a 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -41,6 +41,7 @@ impl Context { // operations mean we need to go deeper Operation::Add(a, b) | Operation::Mul(a, b) + | Operation::Equal(a, b) | Operation::LessThan(a, b) | Operation::GreaterThan(a, b) | Operation::LessThanEq(a, b) @@ -54,6 +55,9 @@ impl Context { self.autodiff(on_false, modification_limit - (r as usize)) .map(|v| v || r) } + Operation::TypeCast(node, ty) => { + self.autodiff(node, modification_limit) + } // finally a Diff node, lets distribute it Operation::Diff(outer, outer_param) => { let outer_node = &self.nodes[outer]; @@ -146,10 +150,12 @@ impl Context { self.autodiff(input, modification_limit - 1) } - Operation::LessThan(_, _) + Operation::Equal(_, _) + | Operation::LessThan(_, _) | Operation::GreaterThan(_, _) | Operation::LessThanEq(_, _) | Operation::GreaterThanEq(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())), + Operation::TypeCast(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())), Operation::Select { pred, diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index 6357499..b3702b4 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -1,4 +1,5 @@ use super::*; +use serde_json::de; use slotmap::SlotMap; use smallvec::SmallVec; use std::collections::{HashMap, HashSet, VecDeque}; @@ -47,6 +48,7 @@ impl Context { Operation::Diff(_, _) => Err(CompileError::DiffNode(input_node.callsite.clone()))?, Operation::Mul(node1, node2) | Operation::Add(node1, node2) + | Operation::Equal(node1, node2) | Operation::LessThan(node1, node2) | Operation::GreaterThan(node1, node2) | Operation::LessThanEq(node1, node2) @@ -62,6 +64,13 @@ impl Context { self.get_dependent_nodes(node1, dep_nodes, constants, parameters)?; self.get_dependent_nodes(node2, dep_nodes, constants, parameters) } + Operation::TypeCast(node, _) => { + dep_nodes + .entry(node) + .or_insert(Vec::new()) + .push(this_node_id); + self.get_dependent_nodes(node, dep_nodes, constants, parameters) + } Operation::Select { pred, on_true, @@ -233,6 +242,19 @@ impl Context { } } + Operation::Equal(node1, node2) => { + if xla_op_slotmap.contains_key(unda_xla_map[&node1]) + && xla_op_slotmap.contains_key(unda_xla_map[&node1]) + { + let xla_op = xla_op_slotmap[unda_xla_map[&node1]] + .eq(&xla_op_slotmap[unda_xla_map[&node2]])?; + let xla_id = xla_op_slotmap.insert(xla_op); + unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } + Operation::LessThan(node1, node2) => { if xla_op_slotmap.contains_key(unda_xla_map[&node1]) && xla_op_slotmap.contains_key(unda_xla_map[&node1]) @@ -306,6 +328,16 @@ impl Context { covered_ops.insert(*dependent_op); } } + Operation::TypeCast(node, ty) => { + if xla_op_slotmap.contains_key(unda_xla_map[&node]) { + let xla_op = + xla_op_slotmap[unda_xla_map[&node]].convert(ty.primitive_type())?; + let xla_id = xla_op_slotmap.insert(xla_op); + unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } } } } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 5bd6725..de34ada 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -69,6 +69,7 @@ impl Context { Operation::Diff(a, b) => format!("Diff ({}) {}", self.to_string(a), self.to_string(b)), Operation::Add(a, b) => format!("Add ({}) ({})", self.to_string(a), self.to_string(b)), Operation::Mul(a, b) => format!("Mul ({}) ({})", self.to_string(a), self.to_string(b)), + Operation::Equal(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)), Operation::LessThan(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)), Operation::GreaterThan(a, b) => format!("GreaterThan ({}) ({})", self.to_string(a), self.to_string(b)), Operation::LessThanEq(a, b) => format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b)), @@ -83,6 +84,7 @@ impl Context { self.to_string(on_true), self.to_string(on_false) ), + Operation::TypeCast(a, ty) => format!("TypeCast {} {}", self.to_string(a), ty), } } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 7d49bfb..7f9fc72 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -75,6 +75,42 @@ impl Context { } } + pub fn eq + Copy, B: Into + Copy>( + &mut self, + a: A, + b: B, + ) -> Result { + let a = a.into(); + let b = b.into(); + let node_a = &self.nodes[a]; + let node_b = &self.nodes[b]; + + if node_a.dtype != node_b.dtype { + Err(ContextError::IncompatibleOperandTypes( + node_a.dtype, + node_b.dtype, + callsite!(1), + )) + } else { + match node_a.shape.broadcast(&node_b.shape) { + None => Err(ContextError::IncompatibleOperandShapes( + node_a.shape.clone(), + node_b.shape.clone(), + callsite!(1), + )), + Some(s) => { + let node = Node { + callsite: callsite!(1), + shape: s, + operation: Operation::Equal(a, b), + dtype: xla::ElementType::Pred, + }; + Ok(self.nodes.insert(node)) + } + } + } + } + pub fn lt + Copy, B: Into + Copy>( &mut self, a: A, @@ -259,4 +295,15 @@ impl Context { self.maximum(const_zero, a) } + + pub fn type_cast + Copy>(&mut self, a: A, dtype: xla::ElementType) -> NodeIdentifier { + let a = a.into(); + let a_shape = self.nodes[a].shape.clone(); + self.nodes.insert(Node { + callsite: callsite!(1), + shape: a_shape, + operation: Operation::TypeCast(a, dtype), + dtype: dtype + }) + } } diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index a3058ff..c6b20bf 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -11,12 +11,15 @@ pub enum Operation { Add(NodeIdentifier, NodeIdentifier), Mul(NodeIdentifier, NodeIdentifier), + Equal(NodeIdentifier, NodeIdentifier), LessThan(NodeIdentifier, NodeIdentifier), GreaterThan(NodeIdentifier, NodeIdentifier), LessThanEq(NodeIdentifier, NodeIdentifier), GreaterThanEq(NodeIdentifier, NodeIdentifier), Select{ pred: NodeIdentifier, on_true: NodeIdentifier, on_false: NodeIdentifier }, + + TypeCast(NodeIdentifier, xla::ElementType), } impl Display for Operation {