From c8df9b799f397228c95a1f7c0082d48d30bc69b8 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Tue, 1 Oct 2024 19:28:15 -0400 Subject: [PATCH 1/3] inelegant solution to #1821 --- src/Enzyme.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 17a7c6ff5d..92a93c96e8 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1566,6 +1566,26 @@ end return (one(x),) end +@inline function _herm_sym_onehot(type, x::AbstractMatrix, start=1, endl=length(x)) + idxs = CartesianIndices(x) + ntuple(Val(endl - start + 1)) do i0 + Base.@_inline_meta + i = start + i0 - 1 + idx = idxs[i] + res = zeros(eltype(x), size(x)) + res[idx] = 1 + res[idx[2],idx[1]] = 1 + type(res) + end +end + +@inline onehot(x::Hermitian) = _herm_sym_onehot(Hermitian, x) +@inline onehot(x::Symmetric) = _herm_sym_onehot(Symmetric, x) + +@inline onehot(x::Hermitian, start, endl) = _herm_sym_onehot(Hermitian, x, start, endl) +@inline onehot(x::Symmetric, start, endl) = _herm_sym_onehot(Symmetric, x, start, endl) + + """ gradient(::ReverseMode, f, args...) From e69d1e95ce0e03d32e2d66f423866dcee611c978 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Tue, 1 Oct 2024 19:39:56 -0400 Subject: [PATCH 2/3] be sure to still use similar --- src/Enzyme.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 92a93c96e8..c82196faac 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1566,15 +1566,20 @@ end return (one(x),) end +function _is_symmed_index(a::CartesianIndex, b::CartesianIndex) + (a[1] == b[1] && a[2] == b[2]) || (a[1] == b[2] && a[2] == b[1]) +end + @inline function _herm_sym_onehot(type, x::AbstractMatrix, start=1, endl=length(x)) idxs = CartesianIndices(x) ntuple(Val(endl - start + 1)) do i0 Base.@_inline_meta i = start + i0 - 1 idx = idxs[i] - res = zeros(eltype(x), size(x)) - res[idx] = 1 - res[idx[2],idx[1]] = 1 + res = similar(parent(x)) + for idx2 ∈ CartesianIndices(x) + @inbounds res[idx2] = _is_symmed_index(idx, idx2) ? 1 : 0 + end type(res) end end From 2606cd172543a8f6509f23f7cffde2154dd08351 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Wed, 2 Oct 2024 18:49:12 -0400 Subject: [PATCH 3/3] add unit tests --- test/runtests.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 902b9e4f65..daf8555bd2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3191,6 +3191,23 @@ mkarray(sz, args...) = reshape(vcat(args...), sz) @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + + f0 = x -> x[1,1]^2 + 2*x[1,2]^2 + 3*x[2,1]^2 - x[2,2]^2 + f1 = x -> [x[1,1]^2, 2*x[1,2]^2 + 3*x[2,1]^2, x[1,2]^2 + x[2,1]^2, x[2,2]^2] + + for x ∈ (Float64[1 2; 0 3], Float64[1 2; 2 3]) # both are [1 2; 2 3] + for T ∈ (Hermitian, Symmetric) + x = T(x) + df = Enzyme.gradient(Enzyme.Forward, f0, x)[1] + @test df ≈ Float64[2 20; 20 -6] + + df = Enzyme.gradient(Enzyme.Forward, f1, x)[1] + @test df[:,1,1] ≈ [2,0,0,0] + @test df[:,1,2] ≈ [0,20,8,0] + @test df[:,2,1] ≈ [0,20,8,0] + @test df[:,2,2] ≈ [0,0,0,6] + end + end end @testset "Simple Jacobian" begin