-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Const Folding for Mul, Add and Sub #42
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Merge Eben's changes into const_fold
Bring autodiff into const_fold
Closed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Nodes:
is_zero()
method, I have it working for all of the normal primative types, but I'll have to check with y'all how bf16 and the complex numbers should be tested.is_const()
method that returns a Literal, used in const folding to iterate through the literal and check that each element is 1 for folding multiplication.Const Folding:
replace_index()
that takes in twoNodeIdentifier
s and replaces all references to the first one with references to the second. This effectively writes the first node out of the compute graph and sets the correct dependents all at the same time, so it's very nice. If we ever add more operations though we'll have to add it to the huge match tree, there's probably a better way with macros or something to do this but it works for now.fold_consts()
: Builds a queue of Nodes connected to the node we start at and traverses through them while making sure we don't pass the mod limit. For every node traveled, check the operation. If the operation is mul, sub or add, we can do const folding. For sub and add, if one node is a constant zero we usereplace_index()
and remove all references to the current node and replace it with the non-zero node. For multiplication, if one node is zeroes, we replace the current node with the zeroed node. If one node is ones, replace the current node with the non-one node. This will work as long as any dimension altering multiplication is handled by a different operation,mat_mul
will be its own operation, correct?Implemented a few unit tests to check that
fold_consts()
returns true when it should and doesn't when it shouldn't. Also compiled several compute graphs in tests for all of the use cases I've talked about. Feel free to let me know what you all think!