Skip to content

Commit

Permalink
added equality checking and type cast node
Browse files Browse the repository at this point in the history
  • Loading branch information
Eben Kadile committed Feb 27, 2024
1 parent 3dbe406 commit 7ba4270
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl Context {
// operations mean we need to go deeper
Operation::Add(a, b)
| Operation::Mul(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::LessThanEq(a, b)
Expand All @@ -54,6 +55,9 @@ impl Context {
self.autodiff(on_false, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::TypeCast(node, ty) => {
self.autodiff(node, modification_limit)
}
// finally a Diff node, lets distribute it
Operation::Diff(outer, outer_param) => {
let outer_node = &self.nodes[outer];
Expand Down Expand Up @@ -146,10 +150,12 @@ impl Context {
self.autodiff(input, modification_limit - 1)
}

Operation::LessThan(_, _)
Operation::Equal(_, _)
| Operation::LessThan(_, _)
| Operation::GreaterThan(_, _)
| Operation::LessThanEq(_, _)
| Operation::GreaterThanEq(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),
Operation::TypeCast(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),

Operation::Select {
pred,
Expand Down
32 changes: 32 additions & 0 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use serde_json::de;
use slotmap::SlotMap;
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet, VecDeque};
Expand Down Expand Up @@ -47,6 +48,7 @@ impl Context {
Operation::Diff(_, _) => Err(CompileError::DiffNode(input_node.callsite.clone()))?,
Operation::Mul(node1, node2)
| Operation::Add(node1, node2)
| Operation::Equal(node1, node2)
| Operation::LessThan(node1, node2)
| Operation::GreaterThan(node1, node2)
| Operation::LessThanEq(node1, node2)
Expand All @@ -62,6 +64,13 @@ impl Context {
self.get_dependent_nodes(node1, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(node2, dep_nodes, constants, parameters)
}
Operation::TypeCast(node, _) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Select {
pred,
on_true,
Expand Down Expand Up @@ -233,6 +242,19 @@ impl Context {
}
}

Operation::Equal(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.eq(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::LessThan(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
Expand Down Expand Up @@ -306,6 +328,16 @@ impl Context {
covered_ops.insert(*dependent_op);
}
}
Operation::TypeCast(node, ty) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].convert(ty.primitive_type())?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/core/graph/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl Context {
Operation::Diff(a, b) => format!("Diff ({}) {}", self.to_string(a), self.to_string(b)),
Operation::Add(a, b) => format!("Add ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Mul(a, b) => format!("Mul ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Equal(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::LessThan(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::GreaterThan(a, b) => format!("GreaterThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::LessThanEq(a, b) => format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b)),
Expand All @@ -83,6 +84,7 @@ impl Context {
self.to_string(on_true),
self.to_string(on_false)
),
Operation::TypeCast(a, ty) => format!("TypeCast {} {}", self.to_string(a), ty),
}
}
}
47 changes: 47 additions & 0 deletions src/core/graph/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,42 @@ impl Context {
}
}

pub fn eq<A: Into<NodeIdentifier> + Copy, B: Into<NodeIdentifier> + Copy>(
&mut self,
a: A,
b: B,
) -> Result<NodeIdentifier> {
let a = a.into();
let b = b.into();
let node_a = &self.nodes[a];
let node_b = &self.nodes[b];

if node_a.dtype != node_b.dtype {
Err(ContextError::IncompatibleOperandTypes(
node_a.dtype,
node_b.dtype,
callsite!(1),
))
} else {
match node_a.shape.broadcast(&node_b.shape) {
None => Err(ContextError::IncompatibleOperandShapes(
node_a.shape.clone(),
node_b.shape.clone(),
callsite!(1),
)),
Some(s) => {
let node = Node {
callsite: callsite!(1),
shape: s,
operation: Operation::Equal(a, b),
dtype: xla::ElementType::Pred,
};
Ok(self.nodes.insert(node))
}
}
}
}

pub fn lt<A: Into<NodeIdentifier> + Copy, B: Into<NodeIdentifier> + Copy>(
&mut self,
a: A,
Expand Down Expand Up @@ -259,4 +295,15 @@ impl Context {

self.maximum(const_zero, a)
}

pub fn type_cast<A: Into<NodeIdentifier> + Copy>(&mut self, a: A, dtype: xla::ElementType) -> NodeIdentifier {
let a = a.into();
let a_shape = self.nodes[a].shape.clone();
self.nodes.insert(Node {
callsite: callsite!(1),
shape: a_shape,
operation: Operation::TypeCast(a, dtype),
dtype: dtype
})
}
}
3 changes: 3 additions & 0 deletions src/core/graph/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ pub enum Operation {
Add(NodeIdentifier, NodeIdentifier),
Mul(NodeIdentifier, NodeIdentifier),

Equal(NodeIdentifier, NodeIdentifier),
LessThan(NodeIdentifier, NodeIdentifier),
GreaterThan(NodeIdentifier, NodeIdentifier),
LessThanEq(NodeIdentifier, NodeIdentifier),
GreaterThanEq(NodeIdentifier, NodeIdentifier),

Select{ pred: NodeIdentifier, on_true: NodeIdentifier, on_false: NodeIdentifier },

TypeCast(NodeIdentifier, xla::ElementType),
}

impl Display for Operation {
Expand Down

0 comments on commit 7ba4270

Please sign in to comment.