diff --git a/Project.toml b/Project.toml index 61dc74a41..748290911 100644 --- a/Project.toml +++ b/Project.toml @@ -38,8 +38,8 @@ ZygoteTrackerExt = "Tracker" [compat] AbstractFFTs = "1.3.1" -ChainRules = "1.44.1" -ChainRulesCore = "1.9" +ChainRules = "1.72.2" +ChainRulesCore = "1.25.1" ChainRulesTestUtils = "1" Colors = "0.12, 0.13" DiffRules = "1.4" diff --git a/src/Zygote.jl b/src/Zygote.jl index 64564ad7f..218ab1348 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,11 +3,12 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular +import ZygoteRules import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty, literal_getfield, unthunk_tangent using ChainRulesCore -using ChainRules: ChainRules, rrule, unthunk, canonicalize +using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize using IRTools using MacroTools, Requires using MacroTools: @forward diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7b070f730..c3bb9e208 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,3 +1,14 @@ +# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from +# Zygote rules here? +function unthunk_tangent end +@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@non_differentiable unthunk_tangent(::IdDict) + + struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} context::CTX end @@ -107,7 +118,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally. """ @inline wrap_chainrules_output(x) = x -@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks @inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing @@ -261,7 +271,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs _pullback(config.context, f_args...) end - ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args) + ad_pullback(Δ) = zygote2differential( + pb(wrap_chainrules_output(unthunk_tangent(Δ))), + f_args) return y, ad_pullback end diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 80fd9b477..8f251d761 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -37,7 +37,13 @@ end _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing -tailmemaybe(x::Tuple) = Base.tail(x) +tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x)) + +# unthunking is essentially an identity operation on a lazy value, but +# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make +# nested AD work, so define +@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),) + """ pullback(f, args...) @@ -351,6 +357,9 @@ function copy!(x::AbstractVector, ps::Params) x end +_maybe_unthunk(x::AbstractThunk) = unthunk(x) +_maybe_unthunk(x) = x + """ Grads(...) @@ -385,7 +394,7 @@ end function Base.getindex(gs::Grads, x) isbits(x) && error("Only reference types can be differentiated with `Params`.") - return gs.grads[x] + return _maybe_unthunk(gs.grads[x]) end """ @@ -468,7 +477,7 @@ function pullback(f, ps::Params) cache(cx)[p] = nothing end back(Δ) - Grads(cx.cache, ps) # TODO make a copy + Grads(_maybe_unthunk(cx.cache), ps) end end diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 060ac970d..bf2783028 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk, insertafter!, finish, expand!, prune!, substitute!, substitute, block, block!, branch!, return!, stmt, meta + +# TODO: Temporary, to be removed when ChainRulesCore rrules are required to +# support thunks as an input and all instances of _adjoint_keepthunks in +# Zygote have been replaces by rrules: +macro _adjoint_keepthunks(ex) + ZygoteRules.gradm(ex, false, true) +end +macro _adjoint_keepthunks!(ex) + ZygoteRules.gradm(ex, true, true) +end + + @inline tuple_va(N, xs) = xs @inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...) @inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N)) diff --git a/src/lib/array.jl b/src/lib/array.jl index 7d3f37839..9cddce775 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -237,7 +237,7 @@ function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator) x̄ = reconstruct_if_dict(x̄, _keys) # return a dictionary if needed (nothing, (f = f̄, iter = x̄),) end - y, collect_pullback + y, collect_pullback ∘ unthunk_tangent end collect_if_dict(x::Dict) = collect(x), collect(keys(x)) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 504ef614d..2fdb5e243 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -53,7 +53,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)}) end -function unbroadcast(x::AbstractArray, x̄) +function unbroadcast(x::AbstractArray, maybethunked_x̄) + x̄ = unthunk_tangent(maybethunked_x̄) N = ndims(x̄) if length(x) == length(x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 1d9678807..179951033 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue) return x end +accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) +accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) + +accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) +accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) +accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) + # Core functions -@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,) +@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,) -@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing +@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing -@adjoint ifelse(cond::Bool, t, f) = +@_adjoint_keepthunks ifelse(cond::Bool, t, f) = ifelse(cond, t, f), Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ) -@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) +@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) accum_param(::Context{false}, _, Δ) = Δ @generated function accum_param(cx::Context, x, Δ) @@ -70,11 +77,11 @@ end unwrap(x) = x -@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) +@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) unwrap(ref, x) = x -@adjoint unwrap(ref, x) = unwrap(x), function (x̄) +@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄) accum_global(__context__, ref, x̄) (accum_param(__context__, x, x̄),) end @@ -88,7 +95,7 @@ function global_set(ref, val) end end -@adjoint! function global_set(ref, x) +@_adjoint_keepthunks! function global_set(ref, x) global_set(ref, x), function (x̄) gs = cache(__context__) x̄ = accum(get(gs, ref, nothing), x̄) @@ -101,9 +108,9 @@ end using Base: tail -@adjoint tuple(xs...) = xs, identity +@_adjoint_keepthunks tuple(xs...) = xs, identity -@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} +@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} val = xs[i] function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -112,7 +119,7 @@ using Base: tail val, back end -@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N +@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N val = xs[i] function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -121,10 +128,10 @@ end return val, back end -@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N = +@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N = (xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing)) -@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N +@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N val = xs[r] function back(Δ) dxs = ntuple(Val(length(xs))) do x @@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, : end # Needed for iteration lowering -@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N = +@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) -@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} = +@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} = (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) -@adjoint function Base.first(xs::Tuple) +@_adjoint_keepthunks function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) first(xs), Δ -> ((Δ, drest...),) end -@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) +@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) _empty(x) = length(x) _empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x) @@ -188,7 +195,7 @@ end unapply(t, xs) = _unapply(t, xs)[1] -@adjoint! function Core._apply(f, args...) +@_adjoint_keepthunks! function Core._apply(f, args...) y, back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) y, function (Δ) @@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1] end end -@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) +@_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...) y, back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) y, function (Δ) @@ -223,7 +230,7 @@ end @generated pair(::Val{k}, v, _=nothing) where k = :($k = v,) @generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) -@adjoint function literal_getfield(x, ::Val{f}) where f +@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f val = getfield(x, f) function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x) end end - -@adjoint! function setfield!(x, f, val) +@_adjoint_keepthunks! function setfield!(x, f, val) y = setfield!(x, f, val) g = grad_mut(__context__, x) y, function (_) @@ -290,13 +296,13 @@ end Jnew{T}(g) where T = Jnew{T,typeof(g)}(g) -@adjoint! function __new__(T, args...) +@_adjoint_keepthunks! function __new__(T, args...) x = __new__(T, args...) g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) x, Jnew{T,typeof(g),false}(g) end -@adjoint! function __splatnew__(T, args) +@_adjoint_keepthunks! function __splatnew__(T, args) x = __splatnew__(T, args) g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) x, Jnew{T,typeof(g),true}(g) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 561163c6b..66b3681f6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -20,14 +20,14 @@ function ngradient(f, xs::AbstractArray...) return grads end -function gradcheck(f, xs...) +function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5) grad_zygote = gradient(f, xs...) grad_finite_difference = ngradient(f, xs...) - return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5)) + return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)) end -gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) -gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) +gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kwargs...) +gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...) # utilities for using gradcheck with complex matrices _splitreim(A) = (real(A),) @@ -160,8 +160,8 @@ end @test gradient(y, x, z) == ([1, 1, 2], nothing) # https://github.com/FluxML/Zygote.jl/issues/376 - _, back = Zygote._pullback(x->x[1]*im, randn(2)) - @test back(1.0)[2] == real([-im, 0]) == [0, 0] + _, back = Zygote.pullback(x -> x[1] * im, randn(2)) + @test back(1.0)[1] == real([-im, 0]) == [0, 0] # _droplike @test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],) @@ -949,8 +949,8 @@ end _hermsymtype(::Type{<:Symmetric}) = Symmetric _hermsymtype(::Type{<:Hermitian}) = Hermitian -function _gradtest_hermsym(f, ST, A) - gradtest(_splitreim(collect(A))...) do (args...) +function _gradtest_hermsym(f, ST, A; kwargs...) + gradtest(_splitreim(collect(A))...; kwargs...) do (args...) B = f(ST(_joinreim(_dropimaggrad.(args)...))) return sum(_splitreim(B)) end diff --git a/test/interface.jl b/test/interface.jl index 23afdfb1b..3328233e8 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -269,5 +269,4 @@ end @test sgs[d.b] ≈ fill(1.f0, size(d.b)) end - end