-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement backend for Tracker #44
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
00ed3f0
Add Tracker backend
sethaxen 89463aa
Test Tracker backend
sethaxen b3ed818
Add Tracker as a test dependency
sethaxen 3933e41
Disallow nesting Tracker
sethaxen fafa239
Remove primal_value implementations
sethaxen 03892f6
Overload derivative and gradient
sethaxen 6b7f2f1
Test error when nesting Tracker
sethaxen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using .Tracker: Tracker | ||
|
||
""" | ||
TrackerBackend | ||
|
||
AD backend that uses reverse mode with Tracker.jl. | ||
""" | ||
struct TrackerBackend <: AbstractReverseMode end | ||
|
||
function second_lowest(::TrackerBackend) | ||
return throw(ArgumentError("Tracker backend does not support nested differentiation.")) | ||
end | ||
|
||
@primitive function pullback_function(ba::TrackerBackend, f, xs...) | ||
value, back = Tracker.forward(f, xs...) | ||
function pullback(ws) | ||
if ws isa Tuple && !(value isa Tuple) | ||
@assert length(ws) == 1 | ||
map(Tracker.data, back(ws[1])) | ||
else | ||
map(Tracker.data, back(ws)) | ||
end | ||
end | ||
return pullback | ||
end | ||
|
||
function derivative(ba::TrackerBackend, f, xs::Number...) | ||
return Tracker.data.(Tracker.gradient(f, xs...)) | ||
end | ||
|
||
function gradient(ba::TrackerBackend, f, xs::AbstractVector...) | ||
return Tracker.data.(Tracker.gradient(f, xs...)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
using AbstractDifferentiation | ||
using Test | ||
using Tracker | ||
|
||
@testset "TrackerBackend" begin | ||
backends = [@inferred(AD.TrackerBackend())] | ||
@testset for backend in backends | ||
@testset "errors when nested" begin | ||
@test_throws ArgumentError AD.second_lowest(backend) | ||
@test_throws ArgumentError AD.hessian(backend, sum, randn(3)) | ||
end | ||
@testset "Derivative" begin | ||
test_derivatives(backend) | ||
end | ||
@testset "Gradient" begin | ||
test_gradients(backend) | ||
end | ||
@testset "Jacobian" begin | ||
test_jacobians(backend) | ||
end | ||
@testset "jvp" begin | ||
test_jvp(backend) | ||
end | ||
@testset "j′vp" begin | ||
test_j′vp(backend) | ||
end | ||
@testset "Lazy Derivative" begin | ||
test_lazy_derivatives(backend) | ||
end | ||
@testset "Lazy Gradient" begin | ||
test_lazy_gradients(backend) | ||
end | ||
@testset "Lazy Jacobian" begin | ||
test_lazy_jacobians(backend) | ||
end | ||
end | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the best way I could think of to signal that Tracker should not be used for higher order AD.