
Merge branch 'master' into drop16
mcabbott authored Jan 22, 2025
2 parents 5ad23d7 + 572eb2a commit 4471a73
Showing 13 changed files with 87 additions and 81 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.7.0"
version = "0.7.3"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -57,7 +57,7 @@ Requires = "1.1"
SpecialFunctions = "1.6, 2"
Statistics = "1"
Tracker = "0.2"
ZygoteRules = "0.2.5"
ZygoteRules = "0.2.7"
julia = "1.10"

[extras]
47 changes: 20 additions & 27 deletions docs/src/limitations.md
@@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutating
```julia
function f!(x)
    x .= 2 .* x
-
    return x
end
```
@@ -42,43 +41,36 @@ Stacktrace:
...
```
We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
-- setting values (`x .= ...`)
-- appending/popping values (`push!(x, v)` / `pop!(x)`)
-- calling mutating functions (`mul!(C, A, B)`)
+- setting values (`x[i] = val` or `x .= values`)
+- appending/popping values (`push!(x, v)` or `pop!(x)`)
+- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`)
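To see what these bullet points mean in practice, here is a minimal sketch (the input values are illustrative): the mutating version of a computation triggers the error above, while its non-mutating equivalent differentiates cleanly.
```julia
using Zygote

# Mutating version: broadcast assignment `.=` writes into x and triggers the error.
Zygote.gradient(x -> sum(x .= 2 .* x), [1.0, 2.0])
# ERROR: Mutating arrays is not supported ...

# Non-mutating version: build the result directly instead of writing into x.
Zygote.gradient(x -> sum(2 .* x), [1.0, 2.0])  # ([2.0, 2.0],)
```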

!!! warning

    Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.

```julia
-function g!(x, y)
-    x .= 2 .* y
-
+function g_inner!(x, y)
+    for i in eachindex(x, y)
+        x[i] = 2 * y[i]
+    end
    return x
end
-g(y) = g!(similar(y), y)
-```
-Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.

-Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
-```julia
-function g!(x, y)
-    x .= 2 .* y
-
-    return x
+function g_outer(y)
+    z = similar(y)
+    g_inner!(z, y)
+    return z
end
+```
+Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package.
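A minimal sketch of the failure this describes, using the definitions above (output abbreviated):
```julia
using Zygote

# g_outer allocates z itself, yet the writes inside g_inner! still
# count as array mutation from Zygote's point of view:
Zygote.gradient(y -> sum(g_outer(y)), [1.0, 2.0, 3.0])
# ERROR: Mutating arrays is not supported ...
```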

-function g(y)
-    x = Zygote.Buffer(y) # Buffer supports syntax like similar
-    g!(x, y)
-    return copy(x) # this step makes the Buffer immutable (w/o actually copying)
-end
+How can you solve this problem?
+* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it.
+* Write a custom rule, defining `rrule(::typeof(g), y)` using what you know about `g` to derive the right expression.
+* Use another AD package instead of Zygote for part of the calculation. Replacing `g(y)` with `Zygote.forwarddiff(g, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them.

-julia> gradient(rand(3)) do y
-         sum(g(y))
-       end
-([2.0, 2.0, 2.0],)
-```
+Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended.
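A minimal sketch of the first and third options applied to `g_outer` (the input values are illustrative):
```julia
using Zygote

# Option 1: rewrite without mutation, here via broadcasting.
g_better(y) = 2 .* y
Zygote.gradient(y -> sum(g_better(y)), [1.0, 2.0, 3.0])  # ([2.0, 2.0, 2.0],)

# Option 3: leave g_outer as-is, but outsource it to ForwardDiff.
Zygote.gradient(y -> sum(Zygote.forwarddiff(g_outer, y)), [1.0, 2.0, 3.0])  # ([2.0, 2.0, 2.0],)
```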

## Try-catch statements

@@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the following
2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)

-Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
+Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension.
+If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value.

Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
11 changes: 5 additions & 6 deletions src/compiler/chainrules.jl
@@ -1,11 +1,10 @@
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
# Zygote rules here?
-function unthunk_tangent end
-@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
-@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
-@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
-@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
-unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
+@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
+@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
+ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
@non_differentiable unthunk_tangent(::IdDict)
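For context, `unthunk_tangent` recursively forces ChainRules thunks, which are lazily delayed pieces of gradient computation. A minimal sketch of the underlying primitive (values are illustrative):
```julia
using ChainRulesCore: @thunk, unthunk

t = @thunk begin           # nothing is computed yet
    println("computing now")
    [2.0, 4.0]
end
unthunk(t)                 # prints "computing now" and returns [2.0, 4.0]
```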


4 changes: 2 additions & 2 deletions src/compiler/interface.jl
@@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
  y, back = pullback(f, args...)
  grad = back(sensitivity(y))
-  return _project_all(args, grad)
+  return _project_all(args, unthunk_tangent(grad))
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -218,7 +218,7 @@ function withgradient(f, args...)
  else
    back(sensitivity(y))
  end
-  results = _project_all(args, grad)
+  results = _project_all(args, unthunk_tangent(grad))
  (val=y, grad=results)
end
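With this change, `gradient` and `withgradient` force any remaining thunks before handing results back, so callers always see plain numbers and arrays. A small usage sketch (values are illustrative):
```julia
using Zygote

res = Zygote.withgradient(x -> sum(x)^2, [1.0, 2.0])
res.val   # 9.0
res.grad  # ([6.0, 6.0],), plain arrays with no AbstractThunk wrappers
```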

34 changes: 21 additions & 13 deletions src/compiler/reverse.jl
@@ -299,6 +299,7 @@ end

function adjoint(pr::Primal)
  ir, sigs = adjointcfg(pr)
+  catch_blocks = falses(length(blocks(pr.ir)))
  for b in reverse(blocks(pr.ir))
    rb = block(ir, b.id)
    grads = Dict()
@@ -309,12 +310,13 @@ function adjoint(pr::Primal)
grad(sigs[b.id][i], arguments(rb)[i])
end

-    has_leave = false
-
    # Backprop through statements
    for v in reverse(keys(b))
      ex = b[v].expr
-      has_leave |= isexpr(ex, :leave)
+
+      if isexpr(ex, :catch)
+        catch_blocks[first(ex.args)] = true
+      end

      if haskey(pr.pullbacks, v)
        g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)),
@@ -338,16 +340,6 @@ function adjoint(pr::Primal)
end
end

-    # This is corresponds to a catch blocks which technically
-    # has predecessors but they are not modelled in the IRTools CFG.
-    # We put an error message at the beginning of said block.
-    if has_leave && isempty(predecessors(b)) && b.id != 1
-      _, f_stmt = first(b)
-      li = pr.ir.lines[f_stmt.line]
-      pushfirst!(rb, stmt(xcall(Base, :error,
-        "Can't differentiate function execution in catch block at $(li.file):$(li.line).")))
-    end

if b.id > 1 # Backprop through (predecessor) branch arguments
gs = grad.(arguments(b))
for br in branches(rb)
@@ -368,6 +360,22 @@ function adjoint(pr::Primal)
branches(rb)[1].args[1] = Δ
end
end

+  for (id, is_catch) in enumerate(catch_blocks)
+    is_catch || continue
+
+    b = block(pr.ir, id)
+    rb = block(ir, id)
+    err_message = if isempty(b)
+      "Can't differentiate function execution in catch block"
+    else
+      _, f_stmt = first(b)
+      li = pr.ir.lines[f_stmt.line]
+      "Can't differentiate function execution in catch block at $(li.file):$(li.line)."
+    end
+    pushfirst!(rb, stmt(xcall(Base, :error, err_message)))
+  end

return ir
end
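The visible effect of this rework: `try`/`catch` differentiates fine as long as the catch branch never actually runs, and fails with a targeted error when it does. A minimal sketch (the function name is illustrative, assuming Zygote's current try/catch support):
```julia
using Zygote

clamped_sqrt(x) = try
    sqrt(x)
catch
    zero(x)
end

Zygote.gradient(clamped_sqrt, 4.0)   # (0.25,); the catch branch never executed
Zygote.gradient(clamped_sqrt, -1.0)  # ERROR: Can't differentiate function execution in catch block ...
```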

11 changes: 0 additions & 11 deletions src/lib/array.jl
@@ -337,17 +337,6 @@ end
end

# Reductions
-@adjoint function sum(xs::AbstractArray; dims = :)
-  if dims === (:)
-    sum(xs), Δ -> (Fill(Δ, size(xs)),)
-  else
-    sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,)
-  end
-end
-
-@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
-  sum(xs, dims = dims), Δ -> (nothing,)
-end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
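Deleting these hand-written adjoints presumably hands `sum` over to the generic ChainRules rule; the user-facing gradient should be unchanged. A quick sanity check:
```julia
using Zygote

Zygote.gradient(sum, [1.0, 2.0, 3.0])  # every element gets gradient 1.0
```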
5 changes: 0 additions & 5 deletions src/lib/broadcast.jl
@@ -365,11 +365,6 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
@adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
T(xs), Δ -> (convert(Array, Δ), )

-@adjoint function sum(xs::AbstractGPUArray; dims = :)
-  placeholder = similar(xs)
-  sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
-end

# Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray)
3 changes: 3 additions & 0 deletions src/lib/lib.jl
@@ -32,6 +32,9 @@ end
accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

+accum(x::Nothing, y::AbstractThunk) = y
+accum(x::AbstractThunk, y::Nothing) = x

accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))
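The two new methods let `nothing` (Zygote's strong zero) pass a thunk through untouched rather than forcing it. A minimal sketch (values are illustrative):
```julia
using Zygote
using ChainRulesCore: @thunk

t = @thunk [1.0, 2.0]
Zygote.accum(nothing, t) === t  # true, the thunk survives still unforced
Zygote.accum(t, nothing) === t  # true
```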
28 changes: 28 additions & 0 deletions test/chainrules.jl
@@ -428,3 +428,31 @@
@test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
@test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
end

@testset "Lazy" begin
custom_add(x, y) = x + y
function ChainRulesCore.rrule(::typeof(custom_add), x, y)
function pullback(Δ)
return NoTangent(), unthunk(Δ), @thunk(error("Should not compute."))
end
custom_add(x, y), pullback
end

x, y = 1f0, 1f0
Zygote.gradient(x) do x
sum(custom_add(x, y))
end
end

@testset "No thunks in the gradient" begin
struct CustomDense
w::Matrix{Float32}
end
(d::CustomDense)(x) = d.w * x

layers = [CustomDense(rand(Float32, 3, 3))]
x = ones(Float32, 3)
g = gradient(layers -> sum(layers[1](x)), layers)[1]
@test g[1] isa NamedTuple
@test g[1].w isa Array
end
12 changes: 2 additions & 10 deletions test/compiler.jl
@@ -309,11 +309,7 @@ end
@test res == 12.
@test_throws ErrorException pull(1.)
err = try pull(1.) catch ex; ex end
if VERSION >= v"1.11"
@test_broken occursin("Can't differentiate function execution in catch block", string(err))
else
@test occursin("Can't differentiate function execution in catch block", string(err))
end
@test occursin("Can't differentiate function execution in catch block", string(err))
end

@testset "try/catch/else" begin
@@ -339,9 +335,5 @@ end
@test_throws ErrorException pull(1.)

err = try pull(1.) catch ex; ex end
if VERSION >= v"1.11"
@test_broken occursin("Can't differentiate function execution in catch block", string(err))
else
@test occursin("Can't differentiate function execution in catch block", string(err))
end
@test occursin("Can't differentiate function execution in catch block", string(err))
end
2 changes: 1 addition & 1 deletion test/features.jl
@@ -542,7 +542,7 @@ end
y1 = [3.0]
y2 = (Mut(y1),)
y3 = (Imm(y1),)
-  @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41
+  @test_skip gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41... and with https://github.com/FluxML/Zygote.jl/pull/1453
@test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0]
@test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),)
@test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0]
2 changes: 1 addition & 1 deletion test/gradcheck.jl
@@ -178,7 +178,7 @@ end

# Ensure that nothings work with non-numeric types.
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
-  @test back([nothing]) === nothing
+  @test back([nothing]) == nothing
end

@testset "view" begin
5 changes: 2 additions & 3 deletions test/lib/array.jl
@@ -129,9 +129,8 @@ end
@testset "dictionary comprehension" begin
d = Dict(1 => 5, 2 => 6)
g = gradient(d -> sum([v^2 for (_,v) in d]), d)[1]
-  @test g isa Dict{Int, Int}
-  @test g == Dict(1 => 10, 2 => 12)
-
+  @test g isa Dict{Int, Float64}
+  @test g == Dict(1 => 10.0, 2 => 12.0)

w = randn(5)
function f_generator(w)
