Skip to content

Commit

Permalink
Implement caching
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 30, 2023
1 parent cbb3f3b commit 60af10b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
1 change: 1 addition & 0 deletions stdlib/TapirOffloading/examples/saxpy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ function saxpy(Z, X, Y, a)
Tapir.foreach(eachindex(Z, Y, X)) do I
@inbounds Z[I] = a*X[I] + Y[I]
end
# FIXME: This should not be needed, bug in SyncInst detection.
TapirOffloading.sync()
Z
end
Expand Down
25 changes: 16 additions & 9 deletions stdlib/TapirOffloading/ext/TapirOffloadingCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,24 @@ module TapirOffloadingCUDA
end

# Runtime associated with the CUDA backend: always `OffloadingRuntime`.
function TapirOffloading.runtime(::CUDABackend)
    return OffloadingRuntime
end
"""
    TapirOffloading.compiler_config(::CUDABackend)

Return the `CUDA.CUDACompilerConfig` obtained from `CUDA.compiler_config` for the device
of the currently active CUDA state.
"""
function TapirOffloading.compiler_config(::CUDABackend)
    state = CUDA.active_state()
    config = CUDA.compiler_config(state.device)
    return config::CUDA.CUDACompilerConfig
end

const ROOTED_FUNCTIONS = Vector{CuFunction}()
function TapirOffloading.link_kernel(::CUDABackend, image, name)
function link_kernel(image, name)
cu_mod = CuModule(image)
cu_func = CuFunction(cu_mod, name)
push!(ROOTED_FUNCTIONS, cu_func)
return Base.unsafe_convert(Ptr{Cvoid}, cu_func.handle)
return CuFunction(cu_mod, name)
end

const _compiler_caches = Dict{CuContext, Dict{Any, CuFunction}}()
"""
    TapirOffloading.lookup_or_compile(backend::CUDABackend, mod, nbytes, name)

Return, as a `Ptr{Cvoid}`, the handle of the `CuFunction` for the kernel described by
`mod`/`nbytes`/`name`, running code generation and linking only on a cache miss.

# Arguments
- `backend`: the CUDA backend; forwarded to `TapirOffloading.codegen`.
- `mod`, `nbytes`: module data and its size; forwarded unchanged to
  `TapirOffloading.codegen`.
- `name`: kernel entry name as received from the caller.

The cache is the one returned by `CUDA.compiler_cache` for the active context and is keyed
on `(mod, name, cfg)`, where `cfg` is the compiler configuration of the active device.
"""
function TapirOffloading.lookup_or_compile(backend::CUDABackend, mod, nbytes, name)
    cuda = CUDA.active_state()
    cfg = CUDA.compiler_config(cuda.device)

    # The key uses the arguments exactly as passed in. NOTE(review): when called from
    # `__chi_lookup_or_compile` these are raw pointers, so equality is by address, not
    # by module contents -- confirm that is the intended cache identity.
    key = (mod, name, cfg)
    cache = CUDA.compiler_cache(cuda.context)
    cuF = get!(cache, key) do
        # Bind the codegen results to fresh locals: assigning to `name` in here would
        # rebind the captured outer argument and force Julia to box it.
        image, entry = TapirOffloading.codegen(backend, mod, nbytes, name, cfg)
        return link_kernel(image, entry)
    end::CuFunction
    return Base.unsafe_convert(Ptr{Cvoid}, cuF.handle)
end

function TapirOffloading.launch(::CUDABackend, func, args, args_sz, N)
Expand Down
33 changes: 13 additions & 20 deletions stdlib/TapirOffloading/src/TapirOffloading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ function register(name, backend)
BACKENDS[name] = backend
end

# Generic functions declared without methods; methods are added elsewhere (the CUDA
# extension extends several of them, dispatching on its backend type).

# Called as `lookup_or_compile(backend, mod, nbytes, name)`; the C entry point asserts
# that the result is a `Ptr{Cvoid}`.
function lookup_or_compile end
# Called as `runtime(backend)`.
function runtime end
# Called as `compiler_config(backend)`.
function compiler_config end
# Called as `link_kernel(backend, image, name)`.
function link_kernel end
# The CUDA extension defines `launch(backend, func, args, args_sz, N)`.
function launch end
# Called with no arguments in the saxpy example (`TapirOffloading.sync()`).
function sync end
# NOTE(review): no caller or method of `pin` is visible here -- signature to be confirmed.
function pin end
Expand All @@ -30,7 +29,6 @@ using GPUCompiler

function initialize()
if !GPUCompiler.__llvm_initialized[]
@info "Initializing LLVM" libLLVMExtra = LLVM.API.libLLVMExtra
LLVM.InitializeAllTargets()
LLVM.InitializeAllTargetInfos()
LLVM.InitializeAllAsmPrinters()
Expand Down Expand Up @@ -73,15 +71,21 @@ function build_runtime(backend, config)
return mod
end

const Key = Tuple{Ptr{Int8}, Ptr{Int8}}
# TODO: memoize on device
# See CUDA.jl -- src/compiler/compilation.jl
"""
    codegen(backend, mod, nbytes, name, cfg)

Decode a raw kernel description and compile it via the `LLVM.Module` method of `codegen`.

# Arguments
- `backend`, `cfg`: forwarded unchanged to `codegen(backend, ir, entry, cfg)`.
- `mod`, `nbytes`: pointer to the module data and its length; wrapped (not copied) into a
  `Vector{Int8}` with `unsafe_wrap`, then parsed into an `LLVM.Module`.
- `name`: pointer to the entry-point name; read with `Base.unsafe_string`.

Parsing and compilation run inside a fresh `ThreadSafeContext` whose context is made
current via `LLVM.context!`. Returns whatever the `LLVM.Module` method returns.
"""
function codegen(backend, mod, nbytes, name, cfg)
    # Use single-assignment locals instead of overwriting `name`/`mod`: variables that
    # are reassigned and then captured by the closures below would be boxed.
    entry = Base.unsafe_string(name)
    buffer = unsafe_wrap(Vector{Int8}, mod, nbytes)
    return ThreadSafeContext() do ts_ctx
        LLVM.context!(LLVM.context(ts_ctx)) do
            ir = parse(LLVM.Module, buffer)
            codegen(backend, ir, entry, cfg)
        end
    end
end

# Empty function: takes no arguments and returns `nothing`.
# NOTE(review): no use is visible in this excerpt; presumably referenced elsewhere -- confirm.
placeholder() = nothing

function codegen(backend, ir::LLVM.Module, entry_fn)
function codegen(backend, ir::LLVM.Module, entry_fn, cfg)
initialize()
cfg = compiler_config(backend)

@info "Compiling for" target=cfg.target

Expand Down Expand Up @@ -124,8 +128,6 @@ function codegen(backend, ir::LLVM.Module, entry_fn)
GPUCompiler.check_ir(job, ir)
verify(ir)

@info "Backend" ir

# 2. Emission
asm, asm_meta = GPUCompiler.emit_asm(job, ir; strip=true, validate=true, format=LLVM.API.LLVMAssemblyFile)
return asm, entry_fn
Expand All @@ -140,16 +142,7 @@ import Base: @ccallable

# C-callable entry point: resolve the kernel named by `name` in the module data at `mod`
# (`nbytes` long) to a backend function handle, returned as `Ptr{Cvoid}`.
#
# The work is delegated to the backend's `lookup_or_compile`, so backends that cache can
# reuse an earlier result. Previously the parse/codegen/link pipeline was inlined here
# and returned early, which left the `lookup_or_compile` call unreachable.
@ccallable function __chi_lookup_or_compile(mod::Ptr{Int8}, nbytes::Csize_t, name::Ptr{Int8})::Ptr{Cvoid}
    backend = current_backend()
    # Keep the existing log record; decode the name only for logging so the raw pointer
    # is what reaches the backend.
    @info "Lookup or Compile called" name = Base.unsafe_string(name) nbytes
    return lookup_or_compile(backend, mod, nbytes, name)::Ptr{Cvoid}
end

@ccallable function __chi_launch(func::Ptr{Int8}, args::Ptr{Int8}, args_sz::Csize_t, N::Csize_t)::Cvoid
Expand Down

0 comments on commit 60af10b

Please sign in to comment.