Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create new test for CTE and prevent it from trying to replace a node if it already exists in the slotmap under the same NodeIdentifier #68

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/core/graph/constant.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use std::{path::Path, hash::Hash};
use std::{path::Path};
use xla::FromRawBytes;

#[derive(Debug, Clone)]
Expand Down
4 changes: 2 additions & 2 deletions src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use super::*;
impl Context {
fn collect_deps(&self, node: NodeIdentifier) -> Vec<NodeIdentifier> {
if self.dependent_nodes.contains_key(&node) {
return self.dependent_nodes[&node].to_vec();
self.dependent_nodes[&node].to_vec()
} else {
return vec![];
vec![]
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/graph/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Hash for Operation {

impl PartialEq for Operation {
fn eq(&self, other: &Self) -> bool {
return match (&self, &other) {
match (&self, &other) {
//Order not matering. Ex: 1 + 2 equals 2 + 1, but 1 / 2 doesnt equal 2 /1 so we can
//check these separately
(&Self::Mul(a, b), &Self::Mul(c, d))
Expand Down
76 changes: 38 additions & 38 deletions src/core/graph/subterm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,56 +28,56 @@ impl Context {
continue;
}

if node_map.contains_key(&self.nodes[node_id]) {
if node_map.contains_key(&self.nodes[node_id]) && node_map[&self.nodes[node_id]] != node_id {
self.replace_index(node_id, node_map[&self.nodes[node_id]])?;
modifications += 1;
changed = true;
} else {
node_map.insert(self.nodes[node_id].clone(), node_id);
}

visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
| Operation::TileInDim { node: a, n_tiles: _, dim: _ }
| Operation::Reshape(a)
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ReduceMean { node, dim: _ }
| Operation::ReduceSum { node, dim: _ } => {
to_visit.push(node);
}
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
}
Operation::Constant(_) | Operation::Parameter(_) => {}
}
Operation::Constant(_) | Operation::Parameter(_) => {}
node_map.insert(self.nodes[node_id].clone(), node_id);
}

}

//Recursive recall if we changed something and modifications are still available
Expand Down
30 changes: 30 additions & 0 deletions src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_large_cte() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let two = ctx.scalar(2, xla::ElementType::F32).expect("2");

let a_2 = ctx.pow(a, two).expect("a^2");
let a_21 = ctx.pow(a, two).expect("a^2");
let a_22 = ctx.pow(a, two).expect("a^2");
let a_23 = ctx.pow(a, two).expect("a^2");
let a_24 = ctx.pow(a, two).expect("a^2");
let a_25 = ctx.pow(a, two).expect("a^2");
let a_26 = ctx.pow(a, two).expect("a^2");
let a_27 = ctx.pow(a, two).expect("a^2");


let sum1 = ctx.add(a_2, a_21).expect("a^2 + a^2");
let sum2 = ctx.add(a_22, a_23).expect("a^2 + a^2");
let sum3 = ctx.add(a_24, a_25).expect("a^2 + a^2");
let sum4 = ctx.add(a_26, a_27).expect("a^2 + a^2");

let nest_sum1 = ctx.add(sum1, sum2).expect("(a^2 + a^2) + (a^2 + a^2)");
let nest_sum2 = ctx.add(sum3, sum4).expect("(a^2 + a^2) + (a^2 + a^2)");

let res = ctx.add(nest_sum1, nest_sum2).expect("((a^2 + a^2) + (a^2 + a^2)) + ((a^2 + a^2) + (a^2 + a^2))");
let subterm_extract = ctx.extract_subterms(&[res], usize::MAX).expect("CTE");

assert!(subterm_extract);
}

#[test]
fn test_cte_happened() {
let mut ctx = Context::new();
Expand Down
Loading