Skip to content

Commit

Permalink
don't copy non-diff arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Mar 23, 2024
1 parent 443820b commit 61f47d5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,9 @@ end
if haskey(seen, prev)
return seen[prev]
end
if guaranteed_const_nongen(RT, nothing)
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
newa = RT(undef, size(prev))
seen[prev] = newa
for I in eachindex(prev)
Expand All @@ -1216,6 +1219,7 @@ end
end



@inline function EnzymeCore.make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT}
return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive)))
end
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2350,6 +2350,11 @@ end
@test grad.x == [3.0, 12.0]
@test grad.p 5.545177444479562

xy = (x = [1.0, 2.0], y = [3, 4]) # y is non-diff
grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy)
@test grad.x == [3.0, 4.0]
@test grad.y === xy.y # make_zero did not copy this

grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0))
@test grad == (x = 6.0, y = 5.0)

Expand Down

0 comments on commit 61f47d5

Please sign in to comment.