diff --git a/src/lib/array.jl b/src/lib/array.jl
index 6d914d272..633aeaf06 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -168,7 +168,8 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
 
 # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
 # So we keep axes(x) to restore gradient dx to its full length & correct shape.
-_tryaxes(x) = axes(x)
+_tryaxes(x) = (s = Base.IteratorSize(x); s isa Base.HasShape ? axes(x) : s isa Base.HasLength ? (Base.OneTo(length(x)),) : throw(ArgumentError("iterator size must be finite")))
+_tryaxes(x::AbstractArray) = axes(x)
 _tryaxes(x::Tuple) = Val(length(x))
 _tryaxes(x::Number) = x
 _restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
@@ -319,6 +320,21 @@ end
   collect(z), collect_zip_pullback
 end
 
+takefunc(itr, dy) = _restore(dy, _tryaxes(itr))
+
+@adjoint function Iterators.take(itr, n)
+  take_pullback(::AbstractArray{Nothing}) = nothing
+  take_pullback(dy::NamedTuple{(:xs,:n)}) = (dy.xs, dy.n)
+  take_pullback(dy::NamedTuple{(:n,:xs)}) = (dy.xs, dy.n)
+  take_pullback(dy::AbstractArray) = (takefunc(itr, dy), nothing)
+  Iterators.take(itr, n), take_pullback
+end
+
+@adjoint function Base.collect(t::Iterators.Take)
+  collect_take_pullback(dy) = ((xs=takefunc(t.xs, dy), n=nothing),)
+  collect(t), collect_take_pullback
+end
+
 # Reductions
 @adjoint function sum(xs::AbstractArray; dims = :)
   if dims === (:)
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index 83fceb713..0666172b4 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -7,15 +7,16 @@ grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S<:Numb
 @non_differentiable Buffer(::Any...)
 
 @adjoint function getindex(b::Buffer, i...)
-  b[i...], function (Δ)
+  function getindex_buffer_pullback(Δ)
     grad = grad_mut(__context__, b, eltype(Δ))
     grad[i...] = accum(grad[i...], Δ)
     return
   end
+  b[i...], getindex_buffer_pullback
 end
 
 @adjoint! function setindex!(b::Buffer, v, i...)
-  setindex!(b, v, i...), function (_)
+  function setindex!_buffer_pullback(_)
     grad = grad_mut(__context__, b)
     v̄ = grad[i...]
     zero = eltype(grad) <: Number ? 0 : nothing
@@ -26,26 +27,49 @@ end
     end
     (nothing, v̄, map(_->nothing, i)...)
   end
+  setindex!(b, v, i...), setindex!_buffer_pullback
 end
 
-@adjoint! function copyto!(b::Buffer, xs)
-  copyto!(b, xs), function (_)
+@adjoint! function copyto!(b::Buffer, src::AbstractArray)
+  function copyto!_buffer_array_pullback(_)
     grad = grad_mut(__context__, b)
-    x̄s = copy(grad)
-    grad .= eltype(grad) <: Number ? 0 : nothing
-    return (nothing, x̄s)
+    xs = copy(grad)
+    grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
+    return (nothing, xs)
   end
+  copyto!(b, src), copyto!_buffer_array_pullback
 end
 
+@adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted)
+  xs, map_pullback = ∇map(__context__, i -> bc[i], eachindex(bc))
+  function copyto!_buffer_broadcast_pullback(_)
+    grad = grad_mut(__context__, b)
+    d, = map_pullback(reshape(first(grad, length(xs)), size(xs)))
+    grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
+    return (nothing, d.bc)
+  end
+  copyto!(b, xs), copyto!_buffer_broadcast_pullback
+end
+
+function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator)
+  xs, collect_pullback = _pullback(cx, collect, g)
+  function copyto!_buffer_generator_pullback(_)
+    grad = grad_mut(cx, b)
+    _, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs)))
+    grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
+    return (nothing, nothing, dg)
+  end
+  copyto!(b, xs), copyto!_buffer_generator_pullback
+end
+
 @adjoint! function push!(b::Buffer, x)
-  push!(b, x), function (y)
+  function push!_buffer_pullback(_)
     grad = grad_mut(__context__, b)
     return (nothing, pop!(grad))
   end
+  push!(b, x), push!_buffer_pullback
 end
 
-_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::AbstractArray) =
-  _pullback(cx, copyto!, b, x)
 
 @adjoint function copy(b::Buffer)
   res = copy(b)
diff --git a/src/tools/buffer.jl b/src/tools/buffer.jl
index 9409a74bc..d8c9a82d3 100644
--- a/src/tools/buffer.jl
+++ b/src/tools/buffer.jl
@@ -72,8 +72,8 @@ function Base.deleteat!(b::Buffer, i)
   return b
 end
 
-@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes, 
-  Base.eachindex, Base.stride, Base.strides, Base.findfirst, 
+@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes,
+  Base.eachindex, Base.stride, Base.strides, Base.findfirst,
   Base.keys
 
 Base.IteratorSize(::Type{<:Buffer{<:Any, A}}) where {A} = Base.IteratorSize(A)
@@ -84,3 +84,5 @@ function Base.iterate(b::Buffer, state=(eachindex(b),))
   y === nothing && return nothing
   b[y[1]], (state[1], tail(y)...)
 end
+
+Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 8cb7e6e1a..133383048 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -214,6 +214,23 @@ end
   @test gradient(x -> sum(inv, collect(view(x', 1,:))), ones(2,2)) == ([-1 0; -1 0],)
 
   @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)
+
+  # adjoint of generators is available and should support generic arrays and iterators
+  # generator of array
+  @test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,)
+  # generator of iterator with HasShape
+  @test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,)
+  # generator of iterator with HasLength
+  @test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,)
+  @test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,)
+  # generator 0-d behavior handled incorrectly
+  @test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0)
+  @test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0)
+
+  # adjoints for iterators
+  @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,)
+  @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,)
+  @test_broken gradient(sum∘collect, 1.0) == (1.0,) # broken since no generic adjoint
 end
 
 @test gradtest(x -> reverse(x), rand(17))
@@ -523,7 +540,7 @@ end
   @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
 
   @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]
-  
+
   # issue 1224, second order
   f1244(w, x) = sum(maximum((w * x).^2, dims=1))
   g1244(w, x) = sum(gradient(f1244, w, x)[2].^2)
@@ -1538,6 +1555,36 @@ using Zygote: Buffer
     return sum(copy(b))
   end == ([2,2,2],)
 
+  @test gradient([1, 2, 3]) do xs
+    b = Zygote.Buffer(xs)
+    b .= 2
+    return sum(copy(b))
+  end == (nothing,)
+
+  @test gradient(1.1) do p
+    b = Zygote.Buffer(zeros(3))
+    b .= (p*i for i in eachindex(b))
+    return sum(copy(b) .* (2:4))
+  end[1] ≈ 1*2 + 2*3 + 3*4
+
+  @test gradient(1.1) do p
+    b = Zygote.Buffer(zeros(3))
+    copyto!(b, [p*i for i in eachindex(b)])
+    return sum(copy(b) .* (2:4))
+  end[1] ≈ 1*2 + 2*3 + 3*4
+
+  @test gradient(1.1) do p
+    b = Zygote.Buffer(zeros(3))
+    copyto!(b, (p*i for i in eachindex(b)))
+    return sum(copy(b) .* (2:4))
+  end[1] ≈ 1*2 + 2*3 + 3*4
+
+  @test_broken gradient(1.1) do p
+    b = Zygote.Buffer(zeros(3))
+    copyto!(b, p)
+    return sum(copy(b) .* (2:4))
+  end[1] ≈ 1*2
+
   @test gradient(2) do x
     b = Zygote.Buffer([])
     push!(b, x)
@@ -1701,7 +1748,7 @@ end
 end
 
 @testset "FillArrays" begin
-  
+
   @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
   @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
   @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
diff --git a/test/lib/array.jl b/test/lib/array.jl
index 8016c9541..64ecfa9b5 100644
--- a/test/lib/array.jl
+++ b/test/lib/array.jl
@@ -50,6 +50,23 @@ end
   @test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
 end
 
+@testset "adjoints of Iterators.take" begin
+  y, back = _pullback(Iterators.take, 1:5, 3)
+  @test back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 0.0, 0.0], nothing)
+  @test back([nothing for i in 1:3]) === nothing
+
+  @test gradient(x -> sum([2y for y in Iterators.take(x, 4)]), [1,2,3,4])[1] ≈ [2, 2, 2, 2]
+  @test gradient(x -> sum(2y for y in Iterators.take(x, 4)), [1,2,3,4])[1] ≈ [2, 2, 2, 2]
+
+  for p in (1.0, fill(1.0), [1.0])
+    @test gradient(p_ -> sum(map(prod, Iterators.take(p_, 1))), p) == (p,)
+    @test gradient(p_ -> sum(x for x in Iterators.take(p_, 1)), p) == (p,)
+  end
+
+  y, back = _pullback(Iterators.take, ones(2, 2), 3)
+  @test @inferred back(collect(y)) == (nothing, [1.0 1.0; 1.0 0.0], nothing)
+end
+
 @testset "collect" begin
   @testset "Dict" begin
     d = Dict(1 => 5, 2 => 6)
@@ -97,6 +114,16 @@ end
     @test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
     @test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
   end
+
+
+  @testset "Iterators.Take" begin
+    z = Iterators.take(1:3, 2)
+    g = gradient(z -> sum(collect(z)), z)[1]
+    @test g == (xs=[1.0, 1.0, 0.0], n=nothing)
+
+    @test gradient(x -> sum(broadcast(prod, Iterators.take(x,2))), ones(4)) == ([1.0,1.0,0.0,0.0],)
+    @test gradient(x -> sum(broadcast(prod, Iterators.take(x.^2,2))), ones(4)) == (2*[1.0,1.0,0.0,0.0],)
+  end
 end
 
 @testset "dictionary comprehension" begin
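
A minimal usage sketch of the new Iterators.take adjoint, assuming the patch above is applied; the inputs and expected values mirror the tests added in test/lib/array.jl. The pullback restores the gradient to the full shape of the parent container and pads the dropped positions with zeros:

using Zygote

# Only the first four elements reach the sum, so the fifth gradient entry is zero.
Zygote.gradient(x -> sum(collect(Iterators.take(x, 4))), [1.0, 2.0, 3.0, 4.0, 5.0])
# expected: ([1.0, 1.0, 1.0, 1.0, 0.0],)

# Differentiating with respect to the Take object itself yields a structural gradient.
z = Iterators.take(1:3, 2)
Zygote.gradient(z -> sum(collect(z)), z)[1]
# expected: (xs = [1.0, 1.0, 0.0], n = nothing)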
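
The _tryaxes change dispatches on Base.IteratorSize, so gradients of collect and map can be restored for general iterators, not only AbstractArrays: shaped iterators keep their axes, length-only iterators get a flat Base.OneTo(length(x)) axis, and iterators without a finite size raise an ArgumentError. A sketch of what this enables, assuming the patch above is applied; both calls and their expected results come from the new tests in test/gradcheck.jl:

using Zygote

# Generator over a plain array (a Base.HasShape iterator).
Zygote.gradient(p -> sum(collect(p * i for i in [1.0, 2.0, 3.0])), 2.0)
# expected: (6.0,)

# Generator over an iterator that only reports a length (Base.HasLength),
# which the old _tryaxes(x) = axes(x) definition could not handle.
Zygote.gradient(p -> sum(collect(p * i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0)
# expected: (6.0,)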
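
With Base.BroadcastStyle forwarded for Buffer and the new copyto! adjoints, a Zygote.Buffer can be filled from a broadcast, an array comprehension, or a generator inside a differentiated function. A sketch mirroring the new Buffer tests in test/gradcheck.jl, assuming the patch above is applied; the function name loss is illustrative only:

using Zygote

function loss(p)
  b = Zygote.Buffer(zeros(3))
  b .= (p * i for i in eachindex(b))            # broadcast a generator into the Buffer
  # copyto!(b, [p * i for i in eachindex(b)])   # copying from a comprehension also works
  return sum(copy(b) .* (2:4))
end

Zygote.gradient(loss, 1.1)
# expected: (20.0,), i.e. 1*2 + 2*3 + 3*4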