Skip to content

Commit

Permalink
Merge pull request #39 from Ebanflo42/graph/backprop
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson authored Mar 1, 2024
2 parents da45e34 + ea510fd commit 76bb657
Show file tree
Hide file tree
Showing 10 changed files with 731 additions and 476 deletions.
341 changes: 143 additions & 198 deletions src/core/graph/autodiff.rs

Large diffs are not rendered by default.

219 changes: 92 additions & 127 deletions src/core/graph/compile.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/core/graph/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value }),
dtype,
});
self.constants.push(node_id);
Ok(node_id)
}

Expand All @@ -73,6 +74,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value }),
dtype: T::TY,
});
self.constants.push(node_id);
Ok(node_id)
}

Expand All @@ -94,6 +96,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value: reshaped }),
dtype: T::TY,
});
self.constants.push(node_id);
Ok(node_id)
}

Expand All @@ -108,6 +111,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value: l }),
dtype: t,
});
self.constants.push(node_id);
Ok(node_id)
}

Expand All @@ -134,6 +138,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value: new_value }),
dtype: self.nodes[const_id].dtype,
});
self.constants.push(node_id);
Ok(node_id)
}

Expand All @@ -153,6 +158,7 @@ impl Context {
operation: Operation::Constant(ConstantBinding { value: new_value }),
dtype: new_type,
});
self.constants.push(node_id);
Ok(node_id)
}
}
37 changes: 29 additions & 8 deletions src/core/graph/context.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::collections::HashMap;

use super::*;
use slotmap::SlotMap;

/// XLA computation graph context.
// TODO: rename this to something meaningful
pub struct Context {
pub(crate) nodes: SlotMap<NodeIdentifier, Node>,
pub(crate) param_indices: Vec<NodeIdentifier>,
pub(crate) constants: Vec<NodeIdentifier>,
pub(crate) parameters: Vec<NodeIdentifier>,
pub(crate) dependent_nodes: HashMap<NodeIdentifier, Vec<NodeIdentifier>>,
}

impl Default for Context {
Expand Down Expand Up @@ -44,7 +48,10 @@ pub enum ContextError {
MultipleReturns(),

#[error("Operation is not differentiable, to use it as a constant in a differentiable computation, wrap it with Context::stop_gradient.")]
NonDifferentiableError(Callsite),
NonDifferentiableOpError(Callsite),

#[error("Type is not differentiable, differentiable types are F16, Bf16, F32, F64, C64, C128")]
NonDifferentiableTypeError(Callsite),
}

pub type Result<T> = std::result::Result<T, ContextError>;
Expand All @@ -53,7 +60,9 @@ impl Context {
pub fn new() -> Self {
Self {
nodes: SlotMap::with_key(),
param_indices: Vec::new(),
constants: Vec::new(),
parameters: Vec::new(),
dependent_nodes: HashMap::new(),
}
}

Expand All @@ -64,14 +73,18 @@ impl Context {
Operation::Constant(a) => format!("Constant {} {}", input_node.shape, a),
Operation::Parameter(a) => format!("Parameter {} {}", input_node.shape, a),
Operation::StopGradient(a) => {
format!("StopGradient {} {}", input_node.shape, self.to_string(a))
format!("StopGradient ({})", self.to_string(a))
}
Operation::Diff(a, b) => format!("Diff ({}) {}", self.to_string(a), self.to_string(b)),
Operation::Add(a, b) => format!("Add ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Sub(a, b) => format!("Sub ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Mul(a, b) => format!("Mul ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Neg(a) => format!("Neg ({})", self.to_string(a)),
Operation::Equal(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
Operation::NotEqual(a, b) => {
format!("NotEqual ({}) ({})", self.to_string(a), self.to_string(b))
}
Operation::LessThan(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
Expand All @@ -98,18 +111,26 @@ impl Context {
self.to_string(on_true),
self.to_string(on_false)
),
Operation::TypeCast(a, ty) => format!("TypeCast {} {}", self.to_string(a), ty),
Operation::TypeCast(a, ty) => format!("TypeCast ({}) {}", self.to_string(a), ty),
Operation::SliceInDim {
node,
start,
stop,
stride,
dim,
} => format!(
"SliceInDim {} {} {} {} {}",
"SliceInDim ({}) {} {} {} {}",
self.to_string(node), start, stop, stride, dim
),
Operation::ZerosLike(node) => format!("ZerosLike {}", self.to_string(node))
Operation::ZerosLike(node) => format!("ZerosLike {}", self.to_string(node)),
Operation::ReduceMax {
node,
dim,
keepdims,
} => format!(
"SliceInDim {} {} {}",
self.to_string(node), dim, keepdims
),
}
}
}
62 changes: 40 additions & 22 deletions src/core/graph/logic.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use super::*;

impl Context {

pub fn select<A: Into<NodeIdentifier> + Copy, B: Into<NodeIdentifier> + Copy, C: Into<NodeIdentifier> + Copy>(
/// TODO: Typecheck pred
pub fn select(
&mut self,
pred: A,
on_true: B,
on_false: C
pred: NodeIdentifier,
on_true: NodeIdentifier,
on_false: NodeIdentifier,
) -> Result<NodeIdentifier> {
let pred = pred.into();
let pred = self.stop_gradient(pred);
let on_true = on_true.into();
let on_false = on_false.into();
let node_pred = &self.nodes[pred];
Expand All @@ -28,25 +29,42 @@ impl Context {
node_false.shape.clone(),
callsite!(1),
)),
Some(s) => {
match s.broadcast(&node_pred.shape) {
None => Err(ContextError::IncompatibleOperandShapes(
s,
node_pred.shape.clone(),
callsite!(1),
)),
Some(sh) => {
let node = Node {
callsite: callsite!(1),
shape: sh,
operation: Operation::Select{ pred: pred, on_true: on_true, on_false: on_false },
dtype: node_true.dtype,
};
Ok(self.nodes.insert(node))
Some(s) => match s.broadcast(&node_pred.shape) {
None => Err(ContextError::IncompatibleOperandShapes(
s,
node_pred.shape.clone(),
callsite!(1),
)),
Some(sh) => {
let node = Node {
callsite: callsite!(1),
shape: sh,
operation: Operation::Select {
pred: pred,
on_true: on_true,
on_false: on_false,
},
dtype: node_true.dtype,
};
let node_id = self.nodes.insert(node);
self.dependent_nodes
.entry(pred)
.or_insert(Vec::new())
.push(node_id);
self.dependent_nodes
.entry(on_true)
.or_insert(Vec::new())
.push(node_id);
if on_true != on_false {
self.dependent_nodes
.entry(on_false)
.or_insert(Vec::new())
.push(node_id);
}
Ok(node_id)
}
}
},
}
}
}
}
}
Loading

0 comments on commit 76bb657

Please sign in to comment.