-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from BradenEverson/common_term_extraction
Implement Common Term Extraction and make tests
- Loading branch information
Showing
4 changed files
with
107 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,90 @@ | ||
use super::*; | ||
use std::collections::HashMap; | ||
use std::collections::{HashMap, HashSet}; | ||
|
||
impl Context { | ||
/// Traverses graph context, building a hashmap of Node -> NodeIdentifier pairs | ||
/// If a duplicate Node is found, we can reference the other NodeIdentifier with | ||
/// the already existant node instead of having duplicates | ||
/// make sure to update entry for the modified node, as the hash will change. | ||
/// do not include callsite when calculating the hash. | ||
pub(crate) fn extract_subterms<A: Into<NodeIdentifier> + Copy>( | ||
pub(crate) fn extract_subterms( | ||
&mut self, | ||
_input: A, | ||
outputs: &[NodeIdentifier], | ||
modification_limit: usize, | ||
) -> Result<bool> { | ||
if modification_limit == 0 { | ||
return Ok(true); | ||
} | ||
let mut _node_map: HashMap<String, NodeIdentifier> = HashMap::new(); | ||
for (mut _identifier, _node) in self.nodes.iter_mut() { | ||
//TODO: Build a HashMap out of all nodes, check if a node already 'exists' | ||
//If node exists, remove all references to its NodeIdentifier and replace with the | ||
//prexisting NodeIdentifier | ||
let mut node_map: HashMap<Node, NodeIdentifier> = HashMap::new(); | ||
|
||
let mut modifications = 0; | ||
let mut changed = false; | ||
|
||
let mut to_visit: Vec<NodeIdentifier> = outputs.to_vec(); | ||
let mut visited: HashSet<NodeIdentifier> = HashSet::new(); | ||
|
||
while let Some(node_id) = to_visit.pop() { | ||
if visited.contains(&node_id) || modifications >= modification_limit { | ||
continue; | ||
} | ||
|
||
if node_map.contains_key(&self.nodes[node_id]) { | ||
self.replace_index(node_id, node_map[&self.nodes[node_id]])?; | ||
modifications += 1; | ||
changed = true; | ||
} else { | ||
node_map.insert(self.nodes[node_id].clone(), node_id); | ||
} | ||
|
||
visited.insert(node_id); | ||
//Add operation nodes to the queue | ||
match self.nodes[node_id].operation { | ||
Operation::Add(a, b) | ||
| Operation::Sub(a, b) | ||
| Operation::Mul(a, b) | ||
| Operation::Div(a, b) | ||
| Operation::NotEqual(a, b) | ||
| Operation::Equal(a, b) | ||
| Operation::LessThan(a, b) | ||
| Operation::GreaterThan(a, b) | ||
| Operation::GreaterThanEq(a, b) | ||
| Operation::LessThanEq(a, b) | ||
| Operation::MatMul(a, b) | ||
| Operation::Pow(a, b) => { | ||
to_visit.push(a); | ||
to_visit.push(b); | ||
} | ||
Operation::Neg(a) | ||
| Operation::StopGradient(a) | ||
| Operation::Log(a) | ||
| Operation::Exp(a) | ||
| Operation::TypeCast(a, _) | ||
| Operation::Transpose(a, _) | ||
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ } | ||
| Operation::TileInDim { node: a, n_tiles: _, dim: _ } | ||
| Operation::Reshape(a) | ||
| Operation::ZerosLike(a) => { | ||
to_visit.push(a); | ||
} | ||
Operation::ReduceMax { node, dim: _ } | ||
| Operation::ReduceMean { node, dim: _ } | ||
| Operation::ReduceSum { node, dim: _ } => { | ||
to_visit.push(node); | ||
} | ||
Operation::Select { pred, on_true, on_false } => { | ||
to_visit.push(pred); | ||
to_visit.push(on_true); | ||
to_visit.push(on_false); | ||
} | ||
Operation::Constant(_) | Operation::Parameter(_) => {} | ||
} | ||
} | ||
|
||
//Recursive recall if we changed something and modifications are still available | ||
if !changed { | ||
return Ok(false); | ||
} else { | ||
return Ok(changed || self.extract_subterms(outputs, modification_limit - modifications)?); | ||
} | ||
Ok(false) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters