From c8db49a21ada10af8a6206cbdeabf00f731b039a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:38:39 -0500 Subject: [PATCH 1/8] Update limitations.md --- docs/src/limitations.md | 47 ++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/docs/src/limitations.md b/docs/src/limitations.md index 4e9012ced..14a8a8a8e 100644 --- a/docs/src/limitations.md +++ b/docs/src/limitations.md @@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutatin ```julia function f!(x) x .= 2 .* x - return x end ``` @@ -42,43 +41,36 @@ Stacktrace: ... ``` We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include: -- setting values (`x .= ...`) -- appending/popping values (`push!(x, v)` / `pop!(x)`) -- calling mutating functions (`mul!(C, A, B)`) +- setting values (`x[i] = val` or `x .= values`) +- appending/popping values (`push!(x, v)` or `pop!(x)`) +- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`) !!! warning Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use. ```julia -function g!(x, y) - x .= 2 .* y - +function g_inner!(x, y) + for i in eachindex(x, y) + x[i] = 2 * y[i] + end return x end -g(y) = g!(similar(y), y) -``` -Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package. 
- -Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above. -```julia -function g!(x, y) - x .= 2 .* y - return x +function g_outer(y) + z = similar(y) + g_inner!(z, y) + return z end +``` +Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package. -function g(y) - x = Zygote.Buffer(y) # Buffer supports syntax like similar - g!(x, y) - return copy(x) # this step makes the Buffer immutable (w/o actually copying) -end +How can you solve this problem? +* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it. +* Write a custom rule, defining `rrule(::typeof(g_outer), y)` using what you know about `g_outer` to derive the right expression. +* Use another AD package instead of Zygote for part of the calculation. Replacing `g_outer(y)` with `Zygote.forwarddiff(g_outer, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them. -julia> gradient(rand(3)) do y - sum(g(y)) - end -([2.0, 2.0, 2.0],) -``` +Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended. ## Try-catch statements @@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the f 2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) 3.
open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues) -Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above. +Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension. +If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`. 
```julia From 55b3947745ba5401ca1bd45798c2d75995919382 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:34:10 -0500 Subject: [PATCH 2/8] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 845fcb1ca..725d9a1f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.7.0" +version = "0.7.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 4bfc545384effff9686beb92620e85ebee66233f Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 4 Jan 2025 15:50:50 -0800 Subject: [PATCH 3/8] Remove redundant sum() rules (#1453) * Remove GPU sum() rule * Try removing Fill sum rule too * Remove bool rule too and correct test * Update test/lib/array.jl * skip failure on CPU ci? * Update gradcheck.jl * Update structures.jl * let's risk one more round of CI why not --------- Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/lib/array.jl | 11 ----------- src/lib/broadcast.jl | 5 ----- test/features.jl | 2 +- test/gradcheck.jl | 2 +- test/lib/array.jl | 5 ++--- 5 files changed, 4 insertions(+), 21 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 9cddce775..75441d7e7 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -337,17 +337,6 @@ end end # Reductions -@adjoint function sum(xs::AbstractArray; dims = :) - if dims === (:) - sum(xs), Δ -> (Fill(Δ, size(xs)),) - else - sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,) - end -end - -@adjoint function sum(xs::AbstractArray{Bool}; dims = :) - sum(xs, dims = dims), Δ -> (nothing,) -end function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray) return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 2fdb5e243..ad815e88c 100644 --- a/src/lib/broadcast.jl +++ 
b/src/lib/broadcast.jl @@ -365,11 +365,6 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} = T(xs), Δ -> (convert(Array, Δ), ) - @adjoint function sum(xs::AbstractGPUArray; dims = :) - placeholder = similar(xs) - sum(xs, dims = dims), Δ -> (placeholder .= Δ,) - end - # Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray) diff --git a/test/features.jl b/test/features.jl index e7ca22316..09478d959 100644 --- a/test/features.jl +++ b/test/features.jl @@ -542,7 +542,7 @@ end y1 = [3.0] y2 = (Mut(y1),) y3 = (Imm(y1),) - @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41 + @test_skip gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41... and with https://github.com/FluxML/Zygote.jl/pull/1453 @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0] @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),) @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0] diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 66b3681f6..054ed240c 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -178,7 +178,7 @@ end # Ensure that nothings work with non-numeric types. 
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1]) - @test back([nothing]) === nothing + @test back([nothing]) == nothing end @testset "view" begin diff --git a/test/lib/array.jl b/test/lib/array.jl index 7be38a9be..b1e89d6db 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -129,9 +129,8 @@ end @testset "dictionary comprehension" begin d = Dict(1 => 5, 2 => 6) g = gradient(d -> sum([v^2 for (_,v) in d]), d)[1] - @test g isa Dict{Int, Int} - @test g == Dict(1 => 10, 2 => 12) - + @test g isa Dict{Int, Float64} + @test g == Dict(1 => 10.0, 2 => 12.0) w = randn(5) function f_generator(w) From d6c10fe44df62fc7f738599ceb879de32529185f Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:50:21 +0100 Subject: [PATCH 4/8] Improve catch block identification --- src/compiler/reverse.jl | 34 +++++++++++++++++++++------------- test/compiler.jl | 12 ++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index bf2783028..2b961ae4c 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -299,6 +299,7 @@ end function adjoint(pr::Primal) ir, sigs = adjointcfg(pr) + catch_blocks = falses(length(blocks(pr.ir))) for b in reverse(blocks(pr.ir)) rb = block(ir, b.id) grads = Dict() @@ -309,12 +310,13 @@ function adjoint(pr::Primal) grad(sigs[b.id][i], arguments(rb)[i]) end - has_leave = false - # Backprop through statements for v in reverse(keys(b)) ex = b[v].expr - has_leave |= isexpr(ex, :leave) + + if isexpr(ex, :catch) + catch_blocks[first(ex.args)] = true + end if haskey(pr.pullbacks, v) g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), @@ -338,16 +340,6 @@ function adjoint(pr::Primal) end end - # This is corresponds to a catch blocks which technically - # has predecessors but they are not modelled in the IRTools CFG. - # We put an error message at the beginning of said block. 
- if has_leave && isempty(predecessors(b)) && b.id != 1 - _, f_stmt = first(b) - li = pr.ir.lines[f_stmt.line] - pushfirst!(rb, stmt(xcall(Base, :error, - "Can't differentiate function execution in catch block at $(li.file):$(li.line)."))) - end - if b.id > 1 # Backprop through (predecessor) branch arguments gs = grad.(arguments(b)) for br in branches(rb) @@ -368,6 +360,22 @@ function adjoint(pr::Primal) branches(rb)[1].args[1] = Δ end end + + for (id, is_catch) in enumerate(catch_blocks) + is_catch || continue + + b = block(pr.ir, id) + rb = block(ir, id) + err_message = if isempty(b) + "Can't differentiate function execution in catch block" + else + _, f_stmt = first(b) + li = pr.ir.lines[f_stmt.line] + "Can't differentiate function execution in catch block at $(li.file):$(li.line)." + end + pushfirst!(rb, stmt(xcall(Base, :error, err_message))) + end + return ir end diff --git a/test/compiler.jl b/test/compiler.jl index 07d498ecb..3b5b0018a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -319,11 +319,7 @@ end @test res == 12. @test_throws ErrorException pull(1.) err = try pull(1.) catch ex; ex end - if VERSION >= v"1.11" - @test_broken occursin("Can't differentiate function execution in catch block", string(err)) - else - @test occursin("Can't differentiate function execution in catch block", string(err)) - end + @test occursin("Can't differentiate function execution in catch block", string(err)) end if VERSION >= v"1.8" @@ -351,9 +347,5 @@ end @test_throws ErrorException pull(1.) err = try pull(1.) 
catch ex; ex end - if VERSION >= v"1.11" - @test_broken occursin("Can't differentiate function execution in catch block", string(err)) - else - @test occursin("Can't differentiate function execution in catch block", string(err)) - end + @test occursin("Can't differentiate function execution in catch block", string(err)) end From a38a4a565b925723386d311938226e6e29a106c2 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 18 Jan 2025 14:17:23 -0800 Subject: [PATCH 5/8] v0.7.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 725d9a1f7..9c284634b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.7.1" +version = "0.7.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 1959fe719c067fc89d8341c14f7bd43e094c9258 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 20 Jan 2025 16:11:47 -0800 Subject: [PATCH 6/8] Remove Molly.jl from Downstream.yml CI Molly.jl is no longer using Zygote, so remove it from reverse CI. 
--- .github/workflows/Downstream.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 9e8dcb0af..b5367bb3a 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -25,7 +25,6 @@ jobs: - {user: SciML, repo: DiffEqFlux.jl, group: Layers} - {user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE} - {user: SciML, repo: NeuralPDE.jl, group: NNPDE} - - {user: JuliaMolSim, repo: Molly.jl, group: Zygote} steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 From 1b111d8a30d790b99d61574a549fa050c912af7f Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 21 Jan 2025 09:00:17 +0200 Subject: [PATCH 7/8] Unthunk tangents (if any) before returning gradient (#1551) --- Project.toml | 2 +- src/compiler/chainrules.jl | 11 +++++------ src/compiler/interface.jl | 4 ++-- src/lib/lib.jl | 3 +++ test/chainrules.jl | 28 ++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 9c284634b..260eb6903 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.1" SpecialFunctions = "1.6, 2" Statistics = "1" Tracker = "0.2" -ZygoteRules = "0.2.5" +ZygoteRules = "0.2.7" julia = "1.6" [extras] diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c3bb9e208..e0e09a63b 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,11 +1,10 @@ # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from # Zygote rules here? 
-function unthunk_tangent end -@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) -@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x -@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x -@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) -unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) @non_differentiable unthunk_tangent(::IdDict) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 8f251d761..a5da774a8 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_all(args, grad) + return _project_all(args, unthunk_tangent(grad)) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -218,7 +218,7 @@ function withgradient(f, args...) 
else back(sensitivity(y)) end - results = _project_all(args, grad) + results = _project_all(args, unthunk_tangent(grad)) (val=y, grad=results) end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 179951033..90e596d95 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -40,6 +40,9 @@ end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) +accum(x::Nothing, y::AbstractThunk) = y +accum(x::AbstractThunk, y::Nothing) = x + accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) diff --git a/test/chainrules.jl b/test/chainrules.jl index 3017a9e18..ed8e98b94 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -428,3 +428,31 @@ end @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]] @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible end + +@testset "Lazy" begin + custom_add(x, y) = x + y + function ChainRulesCore.rrule(::typeof(custom_add), x, y) + function pullback(Δ) + return NoTangent(), unthunk(Δ), @thunk(error("Should not compute.")) + end + custom_add(x, y), pullback + end + + x, y = 1f0, 1f0 + Zygote.gradient(x) do x + sum(custom_add(x, y)) + end +end + +@testset "No thunks in the gradient" begin + struct CustomDense + w::Matrix{Float32} + end + (d::CustomDense)(x) = d.w * x + + layers = [CustomDense(rand(Float32, 3, 3))] + x = ones(Float32, 3) + g = gradient(layers -> sum(layers[1](x)), layers)[1] + @test g[1] isa NamedTuple + @test g[1].w isa Array +end From 4c0e7f50e4de7788fe0b4f909b395042edd5d4cf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 21 Jan 2025 08:00:44 +0100 Subject: [PATCH 8/8] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/Project.toml b/Project.toml index 260eb6903..6a75ef54a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.7.2" +version = "0.7.3" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"