Skip to content

Commit

Permalink
Merge pull request #37 from Ebanflo42/ebanflo/xla/logic_ops
Browse files Browse the repository at this point in the history
Fixed shape checking, added comparison and select operations
  • Loading branch information
BradenEverson authored Feb 29, 2024
2 parents 0e0c847 + 215632d commit da45e34
Show file tree
Hide file tree
Showing 10 changed files with 767 additions and 44 deletions.
100 changes: 95 additions & 5 deletions src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use super::*;

impl Context {
/// Inserts a new node whose operation is `Operation::StopGradient(node)` and
/// returns the identifier of that new node.
///
/// The new node reuses the shape (cloned) and dtype of `node`; its callsite is
/// recorded via `callsite!(1)`.
pub fn stop_gradient(&mut self, node: NodeIdentifier) -> NodeIdentifier {
    // Read the wrapped node's metadata up front so the insertion below
    // is a plain move of an already-built `Node`.
    let wrapped = &self.nodes[node];
    let shape = wrapped.shape.clone();
    let dtype = wrapped.dtype;

    let stop_node = Node {
        callsite: callsite!(1),
        shape,
        operation: Operation::StopGradient(node),
        dtype,
    };
    self.nodes.insert(stop_node)
}

pub fn diff(&mut self, node: NodeIdentifier, with_respect_to: Parameter) -> NodeIdentifier {
self.nodes.insert(Node {
callsite: callsite!(1),
Expand Down Expand Up @@ -28,17 +37,31 @@ impl Context {
// leaf nodes mean no further processing
Operation::Constant(_) => Ok(false),
Operation::Parameter(_) => Ok(false),
Operation::StopGradient(_) => Ok(false),
// operations mean we need to go deeper
Operation::Add(a, b) => {
Operation::Add(a, b)
| Operation::Mul(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::LessThanEq(a, b)
| Operation::GreaterThanEq(a, b) => {
let r = self.autodiff(a, modification_limit)?;
self.autodiff(b, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::Mul(a, b) => {
let r = self.autodiff(a, modification_limit)?;
self.autodiff(b, modification_limit - (r as usize))
Operation::Select {
pred: _,
on_true,
on_false,
} => {
let r = self.autodiff(on_true, modification_limit)?;
self.autodiff(on_false, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::TypeCast(node, ty) => self.autodiff(node, modification_limit),
Operation::SliceInDim { node, .. } => self.autodiff(node, modification_limit),
Operation::ZerosLike(node) => self.autodiff(node, modification_limit),
// finally a Diff node, let's distribute it
Operation::Diff(outer, outer_param) => {
let outer_node = &self.nodes[outer];
Expand All @@ -56,13 +79,14 @@ impl Context {
// derivative of a parameter with respect to itself is one, and otherwise zero
self.nodes[input].operation = Operation::Constant(ConstantBinding {
value: xla::Literal::scalar(
(outer == outer_param.into()) as u32 as f32,
(outer == outer_param.into()) as u32,
)
.convert(outer_dtype)?,
});
self.nodes[input].shape = [].into();
Ok(true)
}
Operation::StopGradient(_) => Ok(false),
Operation::Add(a, b) => {
// derivative of a sum is the sum of derivatives
// Diff (Sum a b) x = Sum (Diff a x) (Diff b x)
Expand Down Expand Up @@ -129,6 +153,72 @@ impl Context {
// rerun autodiff on the node we replaced
self.autodiff(input, modification_limit - 1)
}

Operation::Equal(_, _)
| Operation::LessThan(_, _)
| Operation::GreaterThan(_, _)
| Operation::LessThanEq(_, _)
| Operation::GreaterThanEq(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),
Operation::TypeCast(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),

Operation::Select {
pred,
on_true,
on_false,
} => {
// derivative of select is select of derivatives
let diff_true_node = Node {
// propagate original Diff callsite to the new Diff node
callsite: input_node.callsite.clone(),
shape: self.nodes[on_true].shape.clone(),
operation: Operation::Diff(on_true, outer_param),
dtype: self.nodes[on_true].dtype,
};
let diff_false_node = Node {
// propagate original Diff callsite to the new Diff node
callsite: input_node.callsite.clone(),
shape: self.nodes[on_false].shape.clone(),
operation: Operation::Diff(on_false, outer_param),
dtype: self.nodes[on_false].dtype,
};
// propagate original Select callsite to the new Select node
self.nodes[input].callsite = outer_node.callsite.clone();
let diff_true = self.nodes.insert(diff_true_node);
let diff_false = self.nodes.insert(diff_false_node);

self.nodes[input].operation = Operation::Select {
pred: pred,
on_true: diff_true,
on_false: diff_false,
};
// rerun autodiff on the node we replaced
self.autodiff(input, modification_limit - 1)
}

/*
Operation::SliceInDim { node, start, stop, stride, dim } => {
let diff_node = Node {
callsite: input_node.callsite.clone(),
shape: self.nodes[node].shape.clone(),
operation: Operation::Diff(node, outer_param),
dtype: self.nodes[node].dtype
};
self.nodes[input].callsite = outer_node.callsite.clone();
let diff_node = self.nodes.insert(diff_node);
let zero_node = self.nodes.insert(Node {
callsite: input_node.callsite.clone(),
shape: node.shape.clone(),
operation: Operation::ZerosLike(node)
});
}
*/
Operation::SliceInDim { node, start, stop, stride, dim } => panic!("Differentiating SliceInDim not yet supported, xla-rs must implement scatter operation."),

Operation::ZerosLike(node) => {
self.nodes[input].operation = Operation::ZerosLike(node);
self.autodiff(input, modification_limit - 1)
}

Operation::Diff(inner, _) => {
// derivative of a derivative, apply the inner one first then try again on the outer.
let r = self.autodiff(inner, modification_limit)?;
Expand Down
186 changes: 184 additions & 2 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use serde_json::de;
use slotmap::SlotMap;
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet, VecDeque};
Expand Down Expand Up @@ -37,8 +38,21 @@ impl Context {
parameters.insert(this_node_id);
Ok(())
}
Operation::StopGradient(node) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Diff(_, _) => Err(CompileError::DiffNode(input_node.callsite.clone()))?,
Operation::Mul(node1, node2) | Operation::Add(node1, node2) => {
Operation::Mul(node1, node2)
| Operation::Add(node1, node2)
| Operation::Equal(node1, node2)
| Operation::LessThan(node1, node2)
| Operation::GreaterThan(node1, node2)
| Operation::LessThanEq(node1, node2)
| Operation::GreaterThanEq(node1, node2) => {
dep_nodes
.entry(node1)
.or_insert(Vec::new())
Expand All @@ -50,6 +64,48 @@ impl Context {
self.get_dependent_nodes(node1, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(node2, dep_nodes, constants, parameters)
}
Operation::TypeCast(node, _) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::SliceInDim{ node, .. } => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Select {
pred,
on_true,
on_false,
} => {
dep_nodes
.entry(pred)
.or_insert(Vec::new())
.push(this_node_id);
dep_nodes
.entry(on_true)
.or_insert(Vec::new())
.push(this_node_id);
dep_nodes
.entry(on_false)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(pred, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(on_true, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(on_false, dep_nodes, constants, parameters)
}
Operation::ZerosLike(node) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
}
}

Expand Down Expand Up @@ -167,6 +223,12 @@ impl Context {
}
Operation::Constant(_) => unreachable!("Constants can't depend on other nodes"),
Operation::Diff(_, _) => Err(CompileError::DiffNode(node.callsite.clone()))?,
Operation::StopGradient(node) => {
let xla_id = unda_xla_map[&node];
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}

Operation::Mul(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
Expand All @@ -193,11 +255,131 @@ impl Context {
covered_ops.insert(*dependent_op);
}
}

Operation::Equal(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.eq(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::LessThan(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.lt(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::GreaterThan(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.gt(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::LessThanEq(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.le(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::GreaterThanEq(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.ge(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::Select {
pred,
on_true,
on_false,
} => {
if unda_xla_map.contains_key(&pred)
&& unda_xla_map.contains_key(&on_true)
&& unda_xla_map.contains_key(&on_false)
&& xla_op_slotmap.contains_key(unda_xla_map[&pred])
&& xla_op_slotmap.contains_key(unda_xla_map[&on_true])
&& xla_op_slotmap.contains_key(unda_xla_map[&on_false])
{
let xla_op = xla_op_slotmap[unda_xla_map[&pred]].select(
&xla_op_slotmap[unda_xla_map[&on_true]],
&xla_op_slotmap[unda_xla_map[&on_false]],
)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::TypeCast(node, ty) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].convert(ty.primitive_type())?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::SliceInDim{ node, start, stop, stride, dim } => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].slice_in_dim(start, stop, stride, dim)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::ZerosLike(node) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].zeros_like()?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
}
}
}

let xla_return_vec: Vec<&xla::XlaOp> = returns.into_iter().map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]]).collect();
let xla_return_vec: Vec<&xla::XlaOp> = returns
.into_iter()
.map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]])
.collect();
let xla_return_tuple = builder.tuple(&xla_return_vec.as_slice())?;

let xla_computation = xla_return_tuple.build()?;
Expand Down
Loading

0 comments on commit da45e34

Please sign in to comment.