Skip to content

Commit

Permalink
Removed outer match branch
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Mar 4, 2024
1 parent 253bc2b commit 93539f6
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ pub enum CompileError {
}

impl Context {
pub fn compile<A: Into<NodeIdentifier> + Copy, const N: usize>(
pub fn compile<const N: usize>(
&mut self,
name: &str,
returns: [A; N],
returns: [NodeIdentifier; N],
client: &xla::PjRtClient,
) -> Result<xla::PjRtLoadedExecutable> {
// TODO: gate debug mode behind a feature flag
Expand Down
138 changes: 70 additions & 68 deletions src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ impl Context {
.collect::<Vec<NodeIdentifier>>()
}

fn replace_index<A: Into<NodeIdentifier> + Copy>(
fn replace_index(
&mut self,
to_remove: A,
rep_with: A
to_remove: NodeIdentifier,
rep_with: NodeIdentifier
) -> Result<bool> {
let mut changed = false;

Expand Down Expand Up @@ -117,9 +117,9 @@ impl Context {
/// with a Constant of the result of the operation. All existing references to
/// the old node will still point to it once its replaced, and this process is
/// repeated until there are no more nodes whose inputs are all constants.
pub(crate) fn fold_consts<A: Into<NodeIdentifier> + Copy>(
pub(crate) fn fold_consts(
&mut self,
input: A,
input: NodeIdentifier,
modification_limit: usize,
) -> Result<bool> {
if modification_limit == 0 {
Expand All @@ -137,69 +137,18 @@ impl Context {
continue;
}
match self.nodes[node_id].operation {
Operation::Add(a, b) | Operation::Sub(a, b) | Operation::Mul(a, b) => {
if self.nodes.contains_key(a.into()) && self.nodes.contains_key(b.into()) {

match self.nodes[node_id].operation {
Operation::Add(_, _) | Operation::Sub(_, _) => {
if self.nodes[a].is_zero()? {
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;

} else if self.nodes[b].is_zero()? {
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
}
},
Operation::Mul(_, _) => {
if self.nodes[a].is_zero()? {
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
} else if let Some(literal) = self.nodes[a].is_const() {
//Check for mul by 1
let floating_literal: Vec<f32> = literal.convert(xla::PrimitiveType::F32)?.to_vec()?;
let mut all_one = true;
floating_literal.iter().for_each(|elem| {
if elem != &1f32 {
all_one = false;
}
});
if all_one {
//a is all ones, replace node_id with a
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;
}
} else if self.nodes[b].is_zero()?{
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;
} else if let Some(literal) = self.nodes[b].is_const() {
//Check for mul by 1
let floating_literal: Vec<f32> = literal.convert(xla::PrimitiveType::F32)?.to_vec()?;
let mut all_one = true;
floating_literal.iter().for_each(|elem| {
if elem != &1f32 {
all_one = false;
}
});
if all_one {
//b is all ones, replace node_id with a
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
}
}
},
_ => {
unreachable!("Cannot fold parameters of a node that isn't mul, add or sub for now")
}

};

Operation::Add(a, b)
| Operation::Sub(a, b) => {
if self.nodes[a].is_zero()? {
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;

} else if self.nodes[b].is_zero()? {
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
}
modifications += 1;
//Enqueue the dependent nodes to check both of them for constant
//mul/adding
Expand All @@ -212,7 +161,60 @@ impl Context {
if self.nodes.get(b.into()).unwrap().is_const().is_none() {
to_visit.push(b.into());
}
},
Operation::Mul(a, b) => {
if self.nodes[a].is_zero()? {
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
} else if let Some(literal) = self.nodes[a].is_const() {
//Check for mul by 1
let floating_literal: Vec<f32> = literal.convert(xla::PrimitiveType::F32)?.to_vec()?;
let mut all_one = true;
floating_literal.iter().for_each(|elem| {
if elem != &1f32 {
all_one = false;
}
});
if all_one {
//a is all ones, replace node_id with a
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;
}
} else if self.nodes[b].is_zero()?{
self.replace_index(node_id, b)?;
modifications += 1;
changed = true;
} else if let Some(literal) = self.nodes[b].is_const() {
//Check for mul by 1
let floating_literal: Vec<f32> = literal.convert(xla::PrimitiveType::F32)?.to_vec()?;
let mut all_one = true;
floating_literal.iter().for_each(|elem| {
if elem != &1f32 {
all_one = false;
}
});
if all_one {
//b is all ones, replace node_id with a
self.replace_index(node_id, a)?;
modifications += 1;
changed = true;
}
}
modifications += 1;
//Enqueue the dependent nodes to check both of them for constant
//mul/adding

//TODO: Once we create a new Node based on the constant propegation,
//use insert_with_key to 'replace existant node'
if self.nodes.get(a.into()).unwrap().is_const().is_none() {
to_visit.push(a.into());
}
if self.nodes.get(b.into()).unwrap().is_const().is_none() {
to_visit.push(b.into());
}

},
Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
Expand Down

0 comments on commit 93539f6

Please sign in to comment.