Skip to content

Commit

Permalink
added select and zeros like ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Ebanflo42 committed Feb 28, 2024
1 parent 7ba4270 commit 215632d
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 9 deletions.
37 changes: 33 additions & 4 deletions src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,18 @@ impl Context {
self.autodiff(b, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::Select{ pred: _, on_true, on_false } => {
Operation::Select {
pred: _,
on_true,
on_false,
} => {
let r = self.autodiff(on_true, modification_limit)?;
self.autodiff(on_false, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::TypeCast(node, ty) => {
self.autodiff(node, modification_limit)
}
Operation::TypeCast(node, ty) => self.autodiff(node, modification_limit),
Operation::SliceInDim { node, .. } => self.autodiff(node, modification_limit),
Operation::ZerosLike(node) => self.autodiff(node, modification_limit),
// finally a Diff node, lets distribute it
Operation::Diff(outer, outer_param) => {
let outer_node = &self.nodes[outer];
Expand Down Expand Up @@ -190,6 +194,31 @@ impl Context {
// rerun autodiff on the node we replaced
self.autodiff(input, modification_limit - 1)
}

/*
Operation::SliceInDim { node, start, stop, stride, dim } => {
let diff_node = Node {
callsite: input_node.callsite.clone(),
shape: self.nodes[node].shape.clone(),
operation: Operation::Diff(node, outer_param),
dtype: self.nodes[node].dtype
};
self.nodes[input].callsite = outer_node.callsite.clone();
let diff_node = self.nodes.insert(diff_node);
let zero_node = self.nodes.insert(Node {
callsite: input_node.callsite.clone(),
shape: node.shape.clone(),
operation: Operation::ZerosLike(node)
});
}
*/
Operation::SliceInDim { node, start, stop, stride, dim } => panic!("Differentiating SliceInDim not yet supported, xla-rs must implement scatter operation."),

Operation::ZerosLike(node) => {
self.nodes[input].operation = Operation::ZerosLike(node);
self.autodiff(input, modification_limit - 1)
}

Operation::Diff(inner, _) => {
// derivative of a derivative, apply the inner one first then try again on the outer.
let r = self.autodiff(inner, modification_limit)?;
Expand Down
34 changes: 34 additions & 0 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ impl Context {
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::SliceInDim{ node, .. } => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Select {
pred,
on_true,
Expand All @@ -92,6 +99,13 @@ impl Context {
self.get_dependent_nodes(on_true, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(on_false, dep_nodes, constants, parameters)
}
Operation::ZerosLike(node) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
}
}

Expand Down Expand Up @@ -338,6 +352,26 @@ impl Context {
covered_ops.insert(*dependent_op);
}
}
Operation::SliceInDim{ node, start, stop, stride, dim } => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].slice_in_dim(start, stop, stride, dim)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::ZerosLike(node) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].zeros_like()?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
}
}
}
Expand Down
35 changes: 30 additions & 5 deletions src/core/graph/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,25 @@ impl Context {
Operation::Diff(a, b) => format!("Diff ({}) {}", self.to_string(a), self.to_string(b)),
Operation::Add(a, b) => format!("Add ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Mul(a, b) => format!("Mul ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Equal(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::LessThan(a, b) => format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::GreaterThan(a, b) => format!("GreaterThan ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::LessThanEq(a, b) => format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::GreaterThanEq(a, b) => format!("GreaterThanEq ({}) ({})", self.to_string(a), self.to_string(b)),
Operation::Equal(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
Operation::LessThan(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
Operation::GreaterThan(a, b) => format!(
"GreaterThan ({}) ({})",
self.to_string(a),
self.to_string(b)
),
Operation::LessThanEq(a, b) => {
format!("LessThanEq ({}) ({})", self.to_string(a), self.to_string(b))
}
Operation::GreaterThanEq(a, b) => format!(
"GreaterThanEq ({}) ({})",
self.to_string(a),
self.to_string(b)
),
Operation::Select {
pred,
on_true,
Expand All @@ -85,6 +99,17 @@ impl Context {
self.to_string(on_false)
),
Operation::TypeCast(a, ty) => format!("TypeCast {} {}", self.to_string(a), ty),
Operation::SliceInDim {
node,
start,
stop,
stride,
dim,
} => format!(
"SliceInDim {} {} {} {} {}",
self.to_string(node), start, stop, stride, dim
),
Operation::ZerosLike(node) => format!("ZerosLike {}", self.to_string(node))
}
}
}
28 changes: 28 additions & 0 deletions src/core/graph/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,4 +306,32 @@ impl Context {
dtype: dtype
})
}

pub fn slice_in_dim<A: Into<NodeIdentifier> + Copy>(&mut self, a: A, start: i64, stop: i64, stride: i64, dim: i64) {
let a = a.into();
let mut s = Shape::new();
for d in (0..self.nodes[a].shape.ndims()).rev() {
if d as i64 == dim {
s.sizes.push(1)
} else {
s.sizes.push(self.nodes[a].shape.sizes[d])
}
}
self.nodes.insert(Node {
callsite: callsite!(1),
shape: s,
operation: Operation::SliceInDim { node: a, start: start, stop: stop, stride: stride, dim: dim },
dtype: self.nodes[a].dtype
});
}

pub fn zeros_like<A: Into<NodeIdentifier> + Copy>(&mut self, a: A) {
let a = a.into();
self.nodes.insert(Node {
callsite: callsite!(1),
shape: self.nodes[a].shape.clone(),
operation: Operation::ZerosLike(a),
dtype: self.nodes[a].dtype
});
}
}
4 changes: 4 additions & 0 deletions src/core/graph/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pub enum Operation {
Select{ pred: NodeIdentifier, on_true: NodeIdentifier, on_false: NodeIdentifier },

TypeCast(NodeIdentifier, xla::ElementType),

SliceInDim{ node: NodeIdentifier, start: i64, stop: i64, stride: i64, dim: i64 },

ZerosLike(NodeIdentifier),
}

impl Display for Operation {
Expand Down
21 changes: 21 additions & 0 deletions src/core/graph/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,25 @@ mod tests {
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]);
}

#[test]
fn test_slice_in_dim() {
let mut ctx = Context::new();

let test_const = ctx.const_from_npy("test2.npy").expect("test_const");
let relu = ctx.relu(test_const).expect("relu");

let client = xla::PjRtClient::gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(&name, [relu], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]);
}
}

0 comments on commit 215632d

Please sign in to comment.