diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b6dc8c75d6..ea33959b23 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -636,6 +636,111 @@ function EnzymeRules.reverse( return (nothing,) end +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if RT <: Const + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + elseif RT <: DuplicatedNoNeed + return kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + else + if kv isa Integer + return Duplicated(xs.val[kv], xs.dval[kv]) + else + return Duplicated(view(xs.val, kv), view(xs.dval, kv)) + end + end +end + +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + for i in 1:N + xs.dval[i] .= xs.dval[i][inds] + end + if RT <: Const + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + elseif RT <: BatchDuplicatedNoNeed + if kv isa Integer + return ntuple(i -> xs.dval[i][kv], N) + else + return ntuple(i -> view(xs.dval[i], kv), N) + end + else + if kv isa Integer + return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) + else + return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if EnzymeRules.needs_primal(config) + primal = kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + primal = nothing + end + if RT <: Const || RT <: Active + shadow = nothing + else + shadow = kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + end + return EnzymeRules.AugmentedReturn(primal, shadow, inds) +end + +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} + inds = tape + kv = k.val + if dret isa Active + if kv isa Integer + xs.dval[kv] += dret.val + else + xs.dval[kv] .+= dret.val + end + end + back_inds = sortperm(inds) + xs.dval .= xs.dval[back_inds] + return (nothing, nothing) +end + function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) fact = cholesky(A.val; kwargs...) if RT <: Const diff --git a/test/internal_rules.jl b/test/internal_rules.jl index e325189dc1..b076a51b3e 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -45,6 +45,30 @@ end @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 + function f3(x) + a = [2.0, 2.5, x, 1.0] + return partialsort(a, 2) + end + + @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 + @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 + @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 + + function f4(x) + a = [2.0, 2.5, x, x / 2] + y = partialsort(a, 1:2) + return sum(y) + end + + @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 + @static if VERSION < v"1.7-" || VERSION >= v"1.8-" + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) + end + @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 + @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 + @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) diff --git a/test/runtests.jl b/test/runtests.jl index f6100bab81..0e9c455e99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2883,11 +2883,10 @@ end @test autodiff(Forward, f6, Duplicated(4.0, 1.0))[1] ≈ 5/3 f7(x) = median([2.0, 1.0, x]) - # Fails on Julia 1.9 due to #880 - #=@test autodiff(Reverse, f7, Active, Active(1.5))[1][1] == 1 + @test autodiff(Reverse, f7, Active, Active(1.5))[1][1] == 1 @test autodiff(Forward, f7, Duplicated(1.5, 1.0))[1] == 1 @test autodiff(Reverse, f7, Active, Active(2.5))[1][1] == 0 - @test autodiff(Forward, f7, Duplicated(2.5, 1.0))[1] == 0=# + @test autodiff(Forward, f7, Duplicated(2.5, 1.0))[1] == 0 f8(x) = middle([2.0, x, 1.0]) @test autodiff(Reverse, f8, Active, Active(2.5))[1][1] == 0.5