-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test accepting thunks #449
Conversation
Codecov Report
@@ Coverage Diff @@
## master #449 +/- ##
==========================================
- Coverage 98.52% 98.48% -0.04%
==========================================
Files 21 21
Lines 2095 2175 +80
==========================================
+ Hits 2064 2142 +78
- Misses 31 33 +2
Continue to review full report at Codecov.
|
src/rulesets/Base/sort.jl
Outdated
@@ -26,7 +26,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) | |||
return Δxs | |||
end | |||
|
|||
Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!) | |||
Δxs = InplaceableThunk(@thunk(sort_add!(zero(xs))), sort_add!) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems wrong.
The zero of the primal doesn't have to exist.
(If it does, it is probably also the zero of the differential)
Consider xs::Vector{DateTime}
then Δys
is either a Vector{<:Period}
or a Vector{<:Tangent{DateTime}}
In this case it is bugged.
JuliaLang/julia#41348
but if it was a a Vector{<:Period}
it wouldn't be
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only 1 comment of any real meaning.
Once addressed merge when ready
Δfactors = ΔF.factors | ||
Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent()) | ||
factors = F.factors | ||
∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors | ||
∂factors = eltypeA <: Real ? real(Δfactors) : Δfactors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check that this works right.
I suspect you need to do
_lu_pullback(ΔF::Tangent, m, n, :Type{eltypeA}, pivot, F) where eltypeA
or julia will not specialize on this and it will result in it being type-unstable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does seem to work right
julia> y, pb = rrule(lu, randn(T, m, n), pivot)
julia> @code_warntype pb(ChainRulesTestUtils.rand_tangent(y))
Variables
#self#::ChainRules.var"#lu_pullback#1571"{Matrix{Float64}, Val{true}, LU{Float64, Matrix{Float64}}, Int64, Int64}
ȳ::Tangent{LU{Float64, Matrix{Float64}}, NamedTuple{(:factors, :ipiv, :info), Tuple{Matrix{Float64}, Vector{NoTangent}, NoTangent}}}
Body::Tuple{NoTangent, Matrix{Float64}, NoTangent}
1 ─ %1 = Core.getfield(#self#, :m)::Int64
│ %2 = Core.getfield(#self#, :n)::Int64
│ %3 = Core.getfield(#self#, :A)::Matrix{Float64}
│ %4 = ChainRules.eltype(%3)::Core.Const(Float64)
│ %5 = Core.getfield(#self#, :pivot)::Core.Const(Val{true}())
│ %6 = Core.getfield(#self#, :F)::LU{Float64, Matrix{Float64}}
│ %7 = ChainRules._lu_pullback(ȳ, %1, %2, %4, %5, %6)::Core.PartialStruct(Tuple{NoTangent, Matrix{Float64}, NoTangent}, Any[Core.Const(NoTangent()), Matrix{Float64}, Core.Const(NoTangent())])
└── return %7
but also the @inferred
tests should have caught a regression, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, good. It constant folds
Yes they should have
Closes #408
Needs JuliaDiff/ChainRulesCore.jl#371
to do:
remove examples where the thunks can be unthunked more than once