Skip to content

Commit

Permalink
Merge pull request #65 from BradenEverson/common_term_extraction
Browse files Browse the repository at this point in the history
Implement Common Term Extraction and make tests
  • Loading branch information
BradenEverson authored Mar 25, 2024
2 parents 89b6858 + 83d31ce commit c583aec
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ impl Context {
//while self.foldconsts(a, 1)? {
// println!("{}", self.to_string(a));
//}

for a in returns.iter() {
self.extract_subterms(&returns, usize::MAX)?;
/*for a in returns.iter() {
self.extract_subterms(*a, usize::MAX)?;
}
}*/
//while self.extract_subterms(a, 1)? {
// println!("{}", self.to_string(a));
//}
Expand Down
2 changes: 1 addition & 1 deletion src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ impl Context {
self.dependent_nodes[&node].to_vec()
}

fn replace_index(
pub(crate) fn replace_index(
&mut self,
to_remove: NodeIdentifier,
rep_with: NodeIdentifier,
Expand Down
82 changes: 73 additions & 9 deletions src/core/graph/subterm.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,90 @@
use super::*;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

impl Context {
/// Traverses graph context, building a hashmap of Node -> NodeIdentifier pairs
/// If a duplicate Node is found, we can reference the other NodeIdentifier with
/// the already existant node instead of having duplicates
/// make sure to update entry for the modified node, as the hash will change.
/// do not include callsite when calculating the hash.
pub(crate) fn extract_subterms<A: Into<NodeIdentifier> + Copy>(
pub(crate) fn extract_subterms(
&mut self,
_input: A,
outputs: &[NodeIdentifier],
modification_limit: usize,
) -> Result<bool> {
if modification_limit == 0 {
return Ok(true);
}
let mut _node_map: HashMap<String, NodeIdentifier> = HashMap::new();
for (mut _identifier, _node) in self.nodes.iter_mut() {
//TODO: Build a HashMap out of all nodes, check if a node already 'exists'
//If node exists, remove all references to its NodeIdentifier and replace with the
//prexisting NodeIdentifier
let mut node_map: HashMap<Node, NodeIdentifier> = HashMap::new();

let mut modifications = 0;
let mut changed = false;

let mut to_visit: Vec<NodeIdentifier> = outputs.to_vec();
let mut visited: HashSet<NodeIdentifier> = HashSet::new();

while let Some(node_id) = to_visit.pop() {
if visited.contains(&node_id) || modifications >= modification_limit {
continue;
}

if node_map.contains_key(&self.nodes[node_id]) {
self.replace_index(node_id, node_map[&self.nodes[node_id]])?;
modifications += 1;
changed = true;
} else {
node_map.insert(self.nodes[node_id].clone(), node_id);
}

visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
| Operation::TileInDim { node: a, n_tiles: _, dim: _ }
| Operation::Reshape(a)
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ReduceMean { node, dim: _ }
| Operation::ReduceSum { node, dim: _ } => {
to_visit.push(node);
}
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
}
Operation::Constant(_) | Operation::Parameter(_) => {}
}
}

//Recursive recall if we changed something and modifications are still available
if !changed {
return Ok(false);
} else {
return Ok(changed || self.extract_subterms(outputs, modification_limit - modifications)?);
}
Ok(false)
}
}
30 changes: 30 additions & 0 deletions src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_cte_happened() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b");

let a_p_b = ctx.add(a, b).expect("a + b");
let a_p_b_again = ctx.add(a, b).expect("a + b again");

let res = ctx.mul(a_p_b, a_p_b_again).expect("(a + b) * (a + b)");
let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE");

assert!(subterm_extract);
}

#[test]
fn test_cte_no_false_positives() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let b = ctx.parameter("b", [], xla::ElementType::F32).expect("b");

let a_p_b = ctx.add(a, b).expect("a + b");

let c = ctx.parameter("c", [], xla::ElementType::F32).expect("c");
let res = ctx.mul(a_p_b, c).expect("(a+b) * c");
let subterm_extract = ctx.extract_subterms(&[res], 10).expect("CTE");

assert!(!subterm_extract);
}

#[test]
fn test_hash_node() {
let mut ctx = Context::new();
Expand Down

0 comments on commit c583aec

Please sign in to comment.