Skip to content

Commit

Permalink
Merge pull request #42 from BradenEverson/const_fold
Browse files Browse the repository at this point in the history
Implement Const Folding for Mul, Add and Sub
  • Loading branch information
BradenEverson authored Mar 5, 2024
2 parents 2d55fb8 + 562c5d6 commit 23a9a76
Show file tree
Hide file tree
Showing 5 changed files with 641 additions and 40 deletions.
1 change: 0 additions & 1 deletion src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ impl Context {
self.nodes[with_respect_to].callsite.clone(),
));
}

if node == with_respect_to {
return self.scalar(1, wrt_dtype);
}
Expand Down
6 changes: 3 additions & 3 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ pub enum CompileError {
}

impl Context {
pub fn compile<A: Into<NodeIdentifier> + Copy, const N: usize>(
pub fn compile<const N: usize>(
&mut self,
name: &str,
returns: [A; N],
returns: [NodeIdentifier; N],
client: &xla::PjRtClient,
) -> Result<xla::PjRtLoadedExecutable> {
// TODO: gate debug mode behind a feature flag
Expand All @@ -33,7 +33,7 @@ impl Context {
}

for a in returns.iter() {
self.foldconsts(*a, usize::MAX)?;
self.fold_consts(*a, usize::MAX)?;
}
//while self.foldconsts(a, 1)? {
// println!("{}", self.to_string(a));
Expand Down
Loading

0 comments on commit 23a9a76

Please sign in to comment.