diff --git a/src/core/graph/tests_cpu.rs b/src/core/graph/tests_cpu.rs index 999e26f..b387f8c 100644 --- a/src/core/graph/tests_cpu.rs +++ b/src/core/graph/tests_cpu.rs @@ -68,6 +68,36 @@ mod tests { create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32); create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32); + #[test] + fn test_cte_happened() { + let mut ctx = Context::new(); + let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a"); + let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b"); + + let a_p_b = ctx.add(a, b).expect("a + b"); + let a_p_b_again = ctx.add(a, b).expect("a + b again"); + + let res = ctx.mul(a_p_b, a_p_b_again).expect("(a + b) * (a + b)"); + let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE"); + + assert!(subterm_extract); + } + + #[test] + fn test_cte_no_false_positives() { + let mut ctx = Context::new(); + let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a"); + let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b"); + + let a_p_b = ctx.add(a, b).expect("a + b"); + + let c = ctx.parameter("c", [], xla::ElementType::F32).expect("c"); + let res = ctx.mul(a_p_b, c).expect("(a+b) * c"); + let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE"); + + assert!(!subterm_extract); + } + #[test] fn test_hash_node() { let mut ctx = Context::new();