Skip to content

Commit

Permalink
Fix const-only apply iterate (#1526)
Browse files Browse the repository at this point in the history
* Fix const-only apply iterate

* fix ct

* Fix mixed activity for type unstable

* Update jitrules.jl

* Update jitrules.jl

* wip tuple

* fix batch tuple generation

* Ensure runtime store error

* fix

* cleanup

* ignore 1.8

* newstructv

* ignore test
  • Loading branch information
wsmoses authored Jun 10, 2024
1 parent 86da3cd commit ffcc20c
Show file tree
Hide file tree
Showing 7 changed files with 808 additions and 159 deletions.
43 changes: 33 additions & 10 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ end
end

@assert !Base.isabstracttype(T)
if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll)
if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll)
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
end

Expand Down Expand Up @@ -515,7 +515,7 @@ end
return active_reg_inner(T, (), world)
end

@inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T}
Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T}
seen = ()

# check if it could contain an active
Expand Down Expand Up @@ -3342,6 +3342,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
world = job.world
interp = GPUCompiler.get_interpreter(job)
rt = job.config.params.rt
@assert eltype(rt) != Union{}

shadow_init = job.config.params.shadowInit
ctx = context(mod)
dl = string(LLVM.datalayout(mod))
Expand Down Expand Up @@ -3546,6 +3548,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
pactualRetType = actualRetType
sret_union = is_sret_union(actualRetType)
literal_rt = eltype(rettype)
@assert literal_rt != Union{}
sret_union_rt = is_sret_union(literal_rt)
@assert sret_union == sret_union_rt
if sret_union
Expand Down Expand Up @@ -3684,9 +3687,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end
end

combinedReturn = Tuple{sret_types...}
if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
combinedReturn = AnonymousStruct(combinedReturn)
combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
AnonymousStruct(Tuple{sret_types...})
else
Tuple{sret_types...}
end

uses_sret = is_sret(combinedReturn)
Expand Down Expand Up @@ -4794,14 +4798,19 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true,
strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing)
params = job.config.params
if params.run_enzyme
@assert eltype(params.rt) != Union{}
end
expectedTapeType = params.expectedTapeType
mode = params.mode
TT = params.TT
width = params.width
abiwrap = params.abiwrap
primal = job.source
modifiedBetween = params.modifiedBetween
@assert length(modifiedBetween) == length(TT.parameters)
if length(modifiedBetween) != length(TT.parameters)
throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT"))
end
returnPrimal = params.returnPrimal

if !(params.rt <: Const)
Expand Down Expand Up @@ -5297,6 +5306,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end

@assert actualRetType !== nothing
if params.run_enzyme
@assert actualRetType != Union{}
end

if must_wrap
llvmfn = primalf
Expand Down Expand Up @@ -5838,7 +5850,11 @@ end
end

push!(ccexprs, argexpr)
if !(FA <: Const)
if (FA <: Active)
return quote
error("Cannot have function with Active annotation, $FA")
end
elseif !(FA <: Const)
argexpr = :(fn.dval)
if isboxed
push!(types, Any)
Expand Down Expand Up @@ -6274,9 +6290,16 @@ end
compile_result = cached_compilation(job)
if !run_enzyme
ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World}
return quote
Base.@_inline_meta
$ErrT($(compile_result.adjoint))
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
return quote
Base.@_inline_meta
($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint)))
end
else
return quote
Base.@_inline_meta
$ErrT($(compile_result.adjoint))
end
end
elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
TapeType = compile_result.TapeType
Expand Down
Loading

0 comments on commit ffcc20c

Please sign in to comment.