Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reverse mode apply iterate #1485

Merged
merged 25 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.12.11"
version = "0.12.12"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7.4"
Enzyme_jll = "0.0.119"
Enzyme_jll = "0.0.121"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7"
ObjectFile = "0.4"
Expand Down
7 changes: 7 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ end
end)...}
end

@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple}
return Tuple{(ntuple(Val(length(Ty.parameters))) do i
Base.@_inline_meta
eltype(Ty.parameters[i])
end)...}
end

@inline function same_or_one_helper(current, next)
if current == -1
return next
Expand Down
11 changes: 7 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ end
end

@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}

if T === Any
return DupState
end
Expand Down Expand Up @@ -422,7 +421,9 @@ end
else
inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world)
args = Any[EnzymeCore.EnzymeRules.inactive_type, T];
ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi)
GC.@preserve T begin
ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi)
end
end

if inactivety
Expand Down Expand Up @@ -480,11 +481,13 @@ end
@static if VERSION < v"1.7.0"
nT = T
else
nT = if is_concrete_tuple(T)
nT = if T <: Tuple && T != Tuple && !(T isa UnionAll)
Tuple{(ntuple(length(T.parameters)) do i
Base.@_inline_meta
sT = T.parameters[i]
if sT isa Core.TypeofVararg
if sT isa TypeVar
Any
elseif sT isa Core.TypeofVararg
Any
else
sT
Expand Down
16 changes: 10 additions & 6 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -743,19 +743,18 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
end
end

seen = Dict{LLVM.Value,Tuple}()
seen = Set{Tuple{LLVM.Value,Tuple}}()
while length(todo) != 0
cur, off = pop!(todo)

while isa(cur, LLVM.AddrSpaceCastInst) # || isa(cur, LLVM.BitCastInst)
cur = operands(cur)[1]
end

if cur in keys(seen)
@assert seen[cur] == off
if cur in seen
continue
end
seen[cur] = off
push!(seen, (cur, off))

if isa(cur, LLVM.PHIInst)
for (v, _) in LLVM.incoming(cur)
Expand All @@ -781,7 +780,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width

# if inserting at the current desired offset, we have found the value we need
if ind == off[1]
push!(todo, (operands(cur)[2], -1))
push!(todo, (operands(cur)[2], off[2:end]))
# otherwise it must be inserted at a different point
else
push!(todo, (operands(cur)[1], off))
Expand Down Expand Up @@ -880,10 +879,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
end
end

if isa(cur, LLVM.ConstantArray)
push!(todo, (cur[off[1]], off[2:end]))
continue
end

msg = sprint() do io::IO
println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])")
println(io, string(enzymefn))
println(io, "cur=", cur)
println(io, "cur=", string(cur))
println(io, "off=", off)
end
throw(AssertionError(msg))
Expand Down
Loading
Loading