Skip to content

Commit

Permalink
Add tests(and test false positives cause I was nervous lol)
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Mar 24, 2024
1 parent 9d18fe3 commit 83d31ce
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_cte_happened() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b");

let a_p_b = ctx.add(a, b).expect("a + b");
let a_p_b_again = ctx.add(a, b).expect("a + b again");

let res = ctx.mul(a_p_b, a_p_b_again).expect("(a + b) * (a + b)");
let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE");

assert!(subterm_extract);
}

#[test]
fn test_cte_no_false_positives() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b");

let a_p_b = ctx.add(a, b).expect("a + b");

let c = ctx.parameter("c", [], xla::ElementType::F32).expect("c");
let res = ctx.mul(a_p_b, c).expect("(a+b) * c");
let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE");

assert!(!subterm_extract);
}

#[test]
fn test_hash_node() {
let mut ctx = Context::new();
Expand Down

0 comments on commit 83d31ce

Please sign in to comment.