Skip to content

Commit

Permalink
Merge pull request #68 from BradenEverson/cte_test
Browse files Browse the repository at this point in the history
Create new test for CTE and prevent it from trying to replace a node if it already exists in the slotmap under the same NodeIdentifier
  • Loading branch information
BradenEverson authored Mar 28, 2024
2 parents ed8729d + 53e99b3 commit fe6af89
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/core/graph/constant.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use std::{path::Path, hash::Hash};
use std::{path::Path};
use xla::FromRawBytes;

#[derive(Debug, Clone)]
Expand Down
4 changes: 2 additions & 2 deletions src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use super::*;
impl Context {
fn collect_deps(&self, node: NodeIdentifier) -> Vec<NodeIdentifier> {
if self.dependent_nodes.contains_key(&node) {
return self.dependent_nodes[&node].to_vec();
self.dependent_nodes[&node].to_vec()
} else {
return vec![];
vec![]
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/graph/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Hash for Operation {

impl PartialEq for Operation {
fn eq(&self, other: &Self) -> bool {
return match (&self, &other) {
match (&self, &other) {
//Order not matering. Ex: 1 + 2 equals 2 + 1, but 1 / 2 doesnt equal 2 /1 so we can
//check these separately
(&Self::Mul(a, b), &Self::Mul(c, d))
Expand Down
76 changes: 38 additions & 38 deletions src/core/graph/subterm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,56 +28,56 @@ impl Context {
continue;
}

if node_map.contains_key(&self.nodes[node_id]) {
if node_map.contains_key(&self.nodes[node_id]) && node_map[&self.nodes[node_id]] != node_id {
self.replace_index(node_id, node_map[&self.nodes[node_id]])?;
modifications += 1;
changed = true;
} else {
node_map.insert(self.nodes[node_id].clone(), node_id);
}

visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
| Operation::TileInDim { node: a, n_tiles: _, dim: _ }
| Operation::Reshape(a)
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ReduceMean { node, dim: _ }
| Operation::ReduceSum { node, dim: _ } => {
to_visit.push(node);
}
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
}
Operation::Constant(_) | Operation::Parameter(_) => {}
}
Operation::Constant(_) | Operation::Parameter(_) => {}
node_map.insert(self.nodes[node_id].clone(), node_id);
}

}

//Recursive recall if we changed something and modifications are still available
Expand Down
30 changes: 30 additions & 0 deletions src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_large_cte() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let two = ctx.scalar(2, xla::ElementType::F32).expect("2");

let a_2 = ctx.pow(a, two).expect("a^2");
let a_21 = ctx.pow(a, two).expect("a^2");
let a_22 = ctx.pow(a, two).expect("a^2");
let a_23 = ctx.pow(a, two).expect("a^2");
let a_24 = ctx.pow(a, two).expect("a^2");
let a_25 = ctx.pow(a, two).expect("a^2");
let a_26 = ctx.pow(a, two).expect("a^2");
let a_27 = ctx.pow(a, two).expect("a^2");


let sum1 = ctx.add(a_2, a_21).expect("a^2 + a^2");
let sum2 = ctx.add(a_22, a_23).expect("a^2 + a^2");
let sum3 = ctx.add(a_24, a_25).expect("a^2 + a^2");
let sum4 = ctx.add(a_26, a_27).expect("a^2 + a^2");

let nest_sum1 = ctx.add(sum1, sum2).expect("(a^2 + a^2) + (a^2 + a^2)");
let nest_sum2 = ctx.add(sum3, sum4).expect("(a^2 + a^2) + (a^2 + a^2)");

let res = ctx.add(nest_sum1, nest_sum2).expect("((a^2 + a^2) + (a^2 + a^2)) + ((a^2 + a^2) + (a^2 + a^2))");
let subterm_extract = ctx.extract_subterms(&[res], usize::MAX).expect("CTE");

assert!(subterm_extract);
}

#[test]
fn test_cte_happened() {
let mut ctx = Context::new();
Expand Down

0 comments on commit fe6af89

Please sign in to comment.