Skip to content

Commit

Permalink
require GPUCompiler 0.14 (#256)
Browse files Browse the repository at this point in the history
* require GPUCompiler 0.14

* add ctx

* adjust to context changes

* fixup! adjust to context changes

* fix pfor

* take gpucompiler

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
vchuravy and wsmoses authored Mar 15, 2022
1 parent 7bd2ab8 commit 0742289
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Adapt = "3.3"
CEnum = "0.4"
Enzyme_jll = "0.0.28"
GPUCompiler = "0.13.14"
GPUCompiler = "0.14"
LLVM = "4.1"
ObjectFile = "0.3"
julia = "1.6"
22 changes: 15 additions & 7 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1366,11 +1366,12 @@ end
# - GPU support
# - When OrcV2 only use a MaterializationUnit to avoid mutation of the module here


target = GPUCompiler.NativeCompilerTarget()
params = Compiler.PrimalCompilerParams()
job = CompilerJob(target, funcspec, params)

otherMod, meta = GPUCompiler.codegen(:llvm, job, optimize=false, validate=false)
otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, validate=false, ctx)
entry = name(meta.entry)

# 4) Link the corresponding module
Expand Down Expand Up @@ -3065,7 +3066,7 @@ function adim(::Array{T, N}) where {T, N}
end

function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true,
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, ctx = nothing,
strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing)
params = job.params
mode = params.mode
Expand All @@ -3081,11 +3082,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
else
primal_job = similar(parent_job, job.source)
end
mod, meta = GPUCompiler.codegen(:llvm, primal_job, optimize=false, validate=false, parent_job=parent_job)
mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, validate=false, parent_job=parent_job, ctx)
primalf = meta.entry
check_ir(job, mod)

ctx = context(mod)
@assert ctx == context(mod)
custom = Dict{String, LLVM.API.LLVMLinkage}()
must_wrap = false

Expand Down Expand Up @@ -3611,7 +3612,7 @@ end
# JIT
##

function _link(job, (mod, adjoint_name, primal_name))
function _link(job, (mod, adjoint_name, primal_name, ctx))
params = job.params
adjoint = params.adjoint

Expand All @@ -3620,6 +3621,11 @@ function _link(job, (mod, adjoint_name, primal_name))

# Now invoke the JIT
jitted_mod = JIT.add!(mod)
if VERSION >= v"1.9.0-DEV.115"
LLVM.dispose(ctx)
else
# we cannot dispose of the global unique context
end
adjoint_addr = JIT.lookup(jitted_mod, adjoint_name)

adjoint_ptr = pointer(adjoint_addr)
Expand All @@ -3643,7 +3649,9 @@ end
function _thunk(job)
params = job.params

mod, meta = codegen(:llvm, job, optimize=false)
# TODO: on 1.9, this actually creates a context. cache those.
ctx = JuliaContext()
mod, meta = codegen(:llvm, job; optimize=false, ctx)

adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf

Expand All @@ -3662,7 +3670,7 @@ function _thunk(job)

# Run post optimization pipeline
post_optimze!(mod, JIT.get_tm())
return (mod, adjoint_name, primal_name)
return (mod, adjoint_name, primal_name, ctx)
end

const cache = Dict{UInt, Dict{UInt, Any}}()
Expand Down
23 changes: 23 additions & 0 deletions src/compiler/orcv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ const tm = Ref{TargetMachine}() # for opt pipeline

get_tm() = tm[]

function absolute_symbol_materialization(name, ptr)
address = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr))
flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
gv = LLVM.API.LLVMJITCSymbolMapPair(name, symbol)

return LLVM.absolute_symbols(Ref(gv))
end

function define_absolute_symbol(jd, name)
ptr = LLVM.find_symbol(name)
if ptr !== C_NULL
LLVM.define(jd, absolute_symbol_materialization(name, ptr))
return true
end
return false
end

function __init__()
opt_level = Base.JLOptions().opt_level
if opt_level < 2
Expand Down Expand Up @@ -71,6 +89,11 @@ function __init__()
dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix)
LLVM.add!(jd_main, dg)

if Sys.iswindows() && Int === Int64
# TODO can we check isGNU?
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
end

es = ExecutionSession(lljit)
try
lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es)
Expand Down
24 changes: 14 additions & 10 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ end


function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types);
optimize::Bool=true, second_stage::Bool=true, kwargs...)
optimize::Bool=true, second_stage::Bool=true, ctx=nothing, kwargs...)

job = get_job(func, A, types; kwargs...)
# Codegen the primal function and all its dependency in one module
mod, meta = Compiler.codegen(:llvm, job, optimize=optimize, #= validate=false =#)
mod, meta = Compiler.codegen(:llvm, job; optimize, ctx #= validate=false =#)

if second_stage
post_optimze!(mod, JIT.get_tm())
Expand All @@ -31,19 +31,23 @@ end
function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types);
optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true,
raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false)
llvmf, _ = reflect(func, A, types; optimize,run_enzyme, second_stage)
JuliaContext() do ctx
llvmf, _ = reflect(func, A, types; optimize, run_enzyme, second_stage, ctx)

str = ccall(:jl_dump_function_ir, Ref{String},
(LLVM.API.LLVMValueRef, Bool, Bool, Ptr{UInt8}),
llvmf, !raw, dump_module, debuginfo)
print(io, str)
str = ccall(:jl_dump_function_ir, Ref{String},
(LLVM.API.LLVMValueRef, Bool, Bool, Ptr{UInt8}),
llvmf, !raw, dump_module, debuginfo)
print(io, str)
end
end
enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_llvm(stdout, func, A, types; kwargs...)

function enzyme_code_native(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types))
_, mod = reflect(func, A, types)
str = String(LLVM.emit(JIT.get_tm(), mod, LLVM.API.LLVMAssemblyFile))
print(io, str)
JuliaContext() do ctx
_, mod = reflect(func, A, types; ctx)
str = String(LLVM.emit(JIT.get_tm(), mod, LLVM.API.LLVMAssemblyFile))
print(io, str)
end
end
enzyme_code_native(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_native(stdout, func, A, types; kwargs...)

Expand Down
2 changes: 1 addition & 1 deletion src/pmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ function commonInnerCompile(runtime_fn, B, orig, gutils, tape)
params = Compiler.PrimalCompilerParams()
job = CompilerJob(target, funcspec, params)

otherMod, meta = GPUCompiler.codegen(:llvm, job, optimize=false, validate=false)
otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, validate=false, ctx)
entry = name(meta.entry)
optimize!(otherMod, JIT.get_tm())

Expand Down

0 comments on commit 0742289

Please sign in to comment.