Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test accepting thunks #449

Merged
merged 18 commits into from
Jun 28, 2021
Merged

Test accepting thunks #449

merged 18 commits into from
Jun 28, 2021

Conversation

mzgubic
Copy link
Member

@mzgubic mzgubic commented Jun 17, 2021

Closes #408

Needs JuliaDiff/ChainRulesCore.jl#371

to do:
remove examples where the thunks can be unthunked more than once

  • Base
  • Core
  • LinearAlgebra
  • Random
  • Statistics

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 17, 2021
@mzgubic mzgubic changed the title Accept thunks Test accepting thunks Jun 17, 2021
@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 18, 2021
@codecov-commenter
Copy link

codecov-commenter commented Jun 18, 2021

Codecov Report

Merging #449 (08013e6) into master (16fbc76) will decrease coverage by 0.03%.
The diff coverage is 97.23%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #449      +/-   ##
==========================================
- Coverage   98.52%   98.48%   -0.04%     
==========================================
  Files          21       21              
  Lines        2095     2175      +80     
==========================================
+ Hits         2064     2142      +78     
- Misses         31       33       +2     
Impacted Files Coverage Δ
src/rulesets/LinearAlgebra/factorization.jl 97.07% <92.00%> (-1.10%) ⬇️
src/rulesets/Base/array.jl 100.00% <100.00%> (ø)
src/rulesets/Base/arraymath.jl 98.47% <100.00%> (+0.11%) ⬆️
src/rulesets/Base/base.jl 100.00% <100.00%> (ø)
src/rulesets/Base/evalpoly.jl 97.84% <100.00%> (+0.02%) ⬆️
src/rulesets/Base/fastmath_able.jl 98.26% <100.00%> (+0.01%) ⬆️
src/rulesets/Base/mapreduce.jl 99.01% <100.00%> (+0.01%) ⬆️
src/rulesets/Base/sort.jl 100.00% <100.00%> (ø)
src/rulesets/LinearAlgebra/blas.jl 94.33% <100.00%> (+0.10%) ⬆️
src/rulesets/LinearAlgebra/dense.jl 98.95% <100.00%> (+0.07%) ⬆️
... and 6 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 16fbc76...08013e6. Read the comment docs.

@@ -26,7 +26,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
return Δxs
end

Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!)
Δxs = InplaceableThunk(@thunk(sort_add!(zero(xs))), sort_add!)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong.
The zero of the primal doesn't have to exist.
(If it does, it is probably also the zero of the differential)

Consider xs::Vector{DateTime}
then Δys is either a Vector{<:Period} or a Vector{<:Tangent{DateTime}}

In this case it is bugged.
JuliaLang/julia#41348
but if it was a a Vector{<:Period} it wouldn't be

test/runtests.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only 1 comment of any real meaning.
Once addressed merge when ready

Δfactors = ΔF.factors
Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent())
factors = F.factors
∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors
∂factors = eltypeA <: Real ? real(Δfactors) : Δfactors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that this works right.
I suspect you need to do

_lu_pullback(ΔF::Tangent, m, n, :Type{eltypeA}, pivot, F) where eltypeA
or julia will not specialize on this and it will result in it being type-unstable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does seem to work right

julia> y, pb = rrule(lu, randn(T, m, n), pivot)

julia> @code_warntype pb(ChainRulesTestUtils.rand_tangent(y))
Variables
  #self#::ChainRules.var"#lu_pullback#1571"{Matrix{Float64}, Val{true}, LU{Float64, Matrix{Float64}}, Int64, Int64}
  ȳ::Tangent{LU{Float64, Matrix{Float64}}, NamedTuple{(:factors, :ipiv, :info), Tuple{Matrix{Float64}, Vector{NoTangent}, NoTangent}}}

Body::Tuple{NoTangent, Matrix{Float64}, NoTangent}
1%1 = Core.getfield(#self#, :m)::Int64%2 = Core.getfield(#self#, :n)::Int64%3 = Core.getfield(#self#, :A)::Matrix{Float64}%4 = ChainRules.eltype(%3)::Core.Const(Float64)
│   %5 = Core.getfield(#self#, :pivot)::Core.Const(Val{true}())%6 = Core.getfield(#self#, :F)::LU{Float64, Matrix{Float64}}%7 = ChainRules._lu_pullback(ȳ, %1, %2, %4, %5, %6)::Core.PartialStruct(Tuple{NoTangent, Matrix{Float64}, NoTangent}, Any[Core.Const(NoTangent()), Matrix{Float64}, Core.Const(NoTangent())])
└──      return %7

but also the @inferred tests should have caught a regression, right?

Copy link
Member

@oxinabox oxinabox Jun 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, good. It constant folds
Yes they should have

@mzgubic mzgubic merged commit 87458d0 into master Jun 28, 2021
@mzgubic mzgubic deleted the mz/accept_thunks branch June 28, 2021 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

frules/rrules should support receiving Thunks
3 participants