Skip to content

Commit

Permalink
Removed outer match for replace and created a == b situation
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Mar 5, 2024
1 parent 3b28b3e commit 23b8a35
Showing 1 changed file with 127 additions and 60 deletions.
187 changes: 127 additions & 60 deletions src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,75 +26,138 @@ impl Context {

for dep_node in deps {
match self.nodes[dep_node].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::Equal(a, b)
| Operation::NotEqual(a, b)
| Operation::LessThan(a, b)
| Operation::LessThanEq(a, b) => {
Operation::Add(a, b) => {
if a == b {
match self.nodes[dep_node].operation {
Operation::Add(_, _) => self.nodes[dep_node].operation = Operation::Add(rep_with, rep_with),
Operation::Sub(_, _) => self.nodes[dep_node].operation = Operation::Sub(rep_with, rep_with),
Operation::Mul(_, _) => self.nodes[dep_node].operation = Operation::Mul(rep_with, rep_with),
Operation::GreaterThan(_, _) => self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, rep_with),
Operation::GreaterThanEq(_, _) => self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, rep_with),
Operation::Equal(_, _) => self.nodes[dep_node].operation = Operation::Equal(rep_with, rep_with),
Operation::NotEqual(_, _) => self.nodes[dep_node].operation = Operation::NotEqual(rep_with, rep_with),
Operation::LessThan(_, _) => self.nodes[dep_node].operation = Operation::LessThan(rep_with, rep_with),
Operation::LessThanEq(_, _) => self.nodes[dep_node].operation = Operation::LessThanEq(rep_with, rep_with),
_ => unreachable!("Only add, sub, mul, gt, gte, e, ne, lt and lte can be reached here")
}
self.nodes[dep_node].operation = Operation::Add(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::Add(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::Add(a, rep_with);
changed = true;
}

} if a == to_remove {
match self.nodes[dep_node].operation {
Operation::Add(_, _) => self.nodes[dep_node].operation = Operation::Add(rep_with, b),
Operation::Sub(_, _) => self.nodes[dep_node].operation = Operation::Sub(rep_with, b),
Operation::Mul(_, _) => self.nodes[dep_node].operation = Operation::Mul(rep_with, b),
Operation::GreaterThan(_, _) => self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, b),
Operation::GreaterThanEq(_, _) => self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, b),
Operation::Equal(_, _) => self.nodes[dep_node].operation = Operation::Equal(rep_with, b),
Operation::NotEqual(_, _) => self.nodes[dep_node].operation = Operation::NotEqual(rep_with, b),
Operation::LessThan(_, _) => self.nodes[dep_node].operation = Operation::LessThan(rep_with, b),
Operation::LessThanEq(_, _) => self.nodes[dep_node].operation = Operation::LessThanEq(rep_with, b),
_ => unreachable!("Only add, sub, mul, gt, gte, e, ne, lt and lte can be reached here")
}
},
Operation::Sub(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::Sub(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::Sub(rep_with, b);
changed = true;
} else if b == to_remove {
match self.nodes[dep_node].operation {
Operation::Add(_, _) => self.nodes[dep_node].operation = Operation::Add(a, rep_with),
Operation::Sub(_, _) => self.nodes[dep_node].operation = Operation::Sub(a, rep_with),
Operation::Mul(_, _) => self.nodes[dep_node].operation = Operation::Mul(a, rep_with),
Operation::GreaterThan(_, _) => self.nodes[dep_node].operation = Operation::GreaterThan(a, rep_with),
Operation::GreaterThanEq(_, _) => self.nodes[dep_node].operation = Operation::GreaterThanEq(a, rep_with),
Operation::Equal(_, _) => self.nodes[dep_node].operation = Operation::Equal(a, rep_with),
Operation::NotEqual(_, _) => self.nodes[dep_node].operation = Operation::NotEqual(a, rep_with),
Operation::LessThan(_, _) => self.nodes[dep_node].operation = Operation::LessThan(a, rep_with),
Operation::LessThanEq(_, _) => self.nodes[dep_node].operation = Operation::LessThanEq(a, rep_with),
_ => unreachable!("Only add, sub, mul, gt, gte, e, ne, lt and lte can be reached here")
}
self.nodes[dep_node].operation = Operation::Sub(a, rep_with);
changed = true;
}

},
Operation::Mul(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::Mul(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::Mul(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::Mul(a, rep_with);
changed = true;
}

},
Operation::GreaterThan(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::GreaterThan(a, rep_with);
changed = true;
}
},
},

Operation::GreaterThanEq(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::GreaterThanEq(a, rep_with);
changed = true;
}
},
Operation::Equal(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::Equal(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::Equal(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::Equal(a, rep_with);
changed = true;
}
},
Operation::NotEqual(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::NotEqual(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::NotEqual(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::NotEqual(a, rep_with);
changed = true;
}
},
Operation::LessThan(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::LessThan(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::LessThan(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::LessThan(a, rep_with);
changed = true;
}
},

Operation::LessThanEq(a, b) => {
if a == b {
self.nodes[dep_node].operation = Operation::LessThanEq(rep_with, rep_with);
changed = true;
} else if a == to_remove {
self.nodes[dep_node].operation = Operation::LessThanEq(rep_with, b);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::LessThanEq(a, rep_with);
changed = true;
}
},
Operation::Constant(_)
| Operation::Parameter(_) => {
unreachable!("Constants or Parameters cannot depend on nodes");
},
Operation::StopGradient(a)
| Operation::Neg(a)
| Operation::ZerosLike(a) => {
Operation::StopGradient(a) => {
if a == to_remove {
match self.nodes[dep_node].operation {
Operation::Neg(_) => self.nodes[dep_node].operation = Operation::Neg(rep_with),
Operation::ZerosLike(_) => self.nodes[dep_node].operation = Operation::ZerosLike(rep_with),
Operation::StopGradient(_) => self.nodes[dep_node].operation = Operation::StopGradient(rep_with),
_ => unreachable!("Only Neg, StopGradient and ZerosLike get this far")
}
self.nodes[dep_node].operation = Operation::StopGradient(rep_with);
changed = true;
}
},
Operation::Neg(a) => {
if a == to_remove {
self.nodes[dep_node].operation = Operation::Neg(rep_with);
changed = true;
}
},
Operation::ZerosLike(a) => {
if a == to_remove {
self.nodes[dep_node].operation = Operation::ZerosLike(rep_with);
changed = true;
}
},
Expand All @@ -121,12 +184,16 @@ impl Context {
}
},
Operation::ReduceMax { node, dim, keepdims } => {
changed = true;
self.nodes[dep_node].operation = Operation::ReduceMax { node: rep_with, dim, keepdims }
if node == to_remove {
changed = true;
self.nodes[dep_node].operation = Operation::ReduceMax { node: rep_with, dim, keepdims }
}
},
Operation::SliceInDim { node, start, stop, stride, dim } => {
changed = true;
self.nodes[dep_node].operation = Operation::SliceInDim { node: rep_with, start, stop, stride, dim }
if node == to_remove {
changed = true;
self.nodes[dep_node].operation = Operation::SliceInDim { node: rep_with, start, stop, stride, dim }
}
}
}
}
Expand Down

0 comments on commit 23b8a35

Please sign in to comment.