Skip to content

Commit

Permalink
Merge pull request #56 from wsmoses/vc/enz_0.0.9
Browse files Browse the repository at this point in the history
adapt to C-API changes
  • Loading branch information
vchuravy authored Mar 25, 2021
2 parents ecd8b49 + a1e67d3 commit 8fef794
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[compat]
CEnum = "0.4"
Cassette = "0.3.4"
Enzyme_jll = "0.0.8"
Enzyme_jll = "0.0.9"
GPUCompiler = "0.9, 0.10"
LLVM = "3.2"
julia = "1.6"
1 change: 1 addition & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Base.eltype(::Type{<:Annotation{T}}) where T = T
import LLVM

include("api.jl")
include("logic.jl")
include("typeanalysis.jl")
include("typetree.jl")
include("utils.jl")
Expand Down
48 changes: 26 additions & 22 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ using Enzyme_jll
using Libdl
using CEnum

struct EnzymeAAResultsRef
a::Ptr{Cvoid}
b::Ptr{Cvoid}
c::Ptr{Cvoid}
end
const EnzymeLogicRef = Ptr{Cvoid}
const EnzymeTypeAnalysisRef = Ptr{Cvoid}
const EnzymeAugmentedReturnPtr = Ptr{Cvoid}

Expand Down Expand Up @@ -61,14 +57,6 @@ end
)


function EnzymeGetGlobalAA(mod)
ccall((:EnzymeGetGlobalAA, libEnzyme), EnzymeAAResultsRef, (LLVMModuleRef,), mod)
end

function EnzymeFreeGlobalAA(aa)
ccall((:EnzymeFreeGlobalAA, libEnzyme), Cvoid, (EnzymeAAResultsRef,), aa)
end

# Create the derivative function itself.
# \p todiff is the function to differentiate
# \p retType is the activity info of the return
Expand All @@ -84,14 +72,14 @@ end
# pass
# \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way
# \p PostOpt is whether to perform basic optimization of the function after synthesis
function EnzymeCreatePrimalAndGradient(todiff, retType, constant_args, TA, global_AA,
function EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA,
returnValue, dretUsed, topLevel, additionalArg, typeInfo,
uncacheable_args, augmented, atomicAdd, postOpt)
ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef,
(LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, EnzymeTypeAnalysisRef,
EnzymeAAResultsRef, UInt8, UInt8, UInt8, LLVMTypeRef, CFnTypeInfo,
(EnzymeLogicRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, UInt8, UInt8, UInt8, LLVMTypeRef, CFnTypeInfo,
Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8, UInt8),
todiff, retType, constant_args, length(constant_args), TA, global_AA, returnValue,
logic, todiff, retType, constant_args, length(constant_args), TA, returnValue,
dretUsed, topLevel, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args),
augmented, atomicAdd, postOpt)
end
Expand All @@ -107,13 +95,13 @@ end
# \p forceAnonymousTape forces the tape to be an i8* rather than the true tape structure
# \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way
# \p PostOpt is whether to perform basic optimization of the function after synthesis
function EnzymeCreateAugmentedPrimal(todiff, retType, constant_args, TA, global_AA, returnUsed,
function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA, returnUsed,
typeInfo, uncacheable_args, forceAnonymousTape, atomicAdd, postOpt)
ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr,
(LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, EnzymeAAResultsRef, UInt8,
(EnzymeLogicRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, UInt8,
CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, UInt8, UInt8),
todiff, retType, constant_args, length(constant_args), TA, global_AA, returnUsed,
logic, todiff, retType, constant_args, length(constant_args), TA, returnUsed,
typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, atomicAdd, postOpt)
end

Expand All @@ -127,8 +115,24 @@ function CreateTypeAnalysis(triple, rulenames, rules)
ccall((:CreateTypeAnalysis, libEnzyme), EnzymeTypeAnalysisRef, (Cstring, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), triple, rulenames, rules, length(rules))
end

function ClearTypeAnalysis(ta)
ccall((:ClearTypeAnalysis, libEnzyme), Cvoid, (EnzymeTypeAnalysisRef,), ta)
end

function FreeTypeAnalysis(ta)
ccall((:FreeTypeAnalysis, libEnzyme), Cvoid, (EnzymeAAResultsRef,), ta)
ccall((:FreeTypeAnalysis, libEnzyme), Cvoid, (EnzymeTypeAnalysisRef,), ta)
end

function CreateLogic()
ccall((:CreateEnzymeLogic, libEnzyme), EnzymeLogicRef, ())
end

function ClearLogic(logic)
ccall((:ClearEnzymeLogic, libEnzyme), Cvoid, (EnzymeLogicRef,), logic)
end

function FreeLogic(logic)
ccall((:FreeEnzymeLogic, libEnzyme), Cvoid, (EnzymeLogicRef,), logic)
end

function EnzymeExtractReturnInfo(ret, data, existed)
Expand Down
11 changes: 5 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Compiler

import ..Enzyme: Const, Active, Duplicated, DuplicatedNoNeed
import ..Enzyme: API, TypeTree, typetree, TypeAnalysis, FnTypeInfo
import ..Enzyme: API, TypeTree, typetree, TypeAnalysis, FnTypeInfo, Logic

using LLVM, GPUCompiler, Libdl
import Enzyme_jll
Expand Down Expand Up @@ -160,14 +160,14 @@ function enzyme!(mod, primalf, adjoint, rt, split)
end

TA = TypeAnalysis(triple(mod))
global_AA = API.EnzymeGetGlobalAA(mod)
logic = Logic()
retTT = typetree(rt, ctx, dl)

typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values)

if split
augmented = API.EnzymeCreateAugmentedPrimal(
primalf, retType, args_activity, TA, global_AA, #=returnUsed=# true,
logic, primalf, retType, args_activity, TA, #=returnUsed=# true,
typeInfo, uncacheable_args, #=forceAnonymousTape=# false, #=atomicAdd=# false, #=postOpt=# false)

# 2. get new_primalf
Expand All @@ -186,20 +186,19 @@ function enzyme!(mod, primalf, adjoint, rt, split)
API.EnzymeExtractReturnInfo(augmented, data, existed)

adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient(
primalf, retType, args_activity, TA, global_AA,
logic, primalf, retType, args_activity, TA,
#=returnValue=#false, #=dretUsed=#false, #=topLevel=#false,
#=additionalArg=#tape, typeInfo,
uncacheable_args, augmented, #=atomicAdd=#false, #=postOpt=#false))
else
adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient(
primalf, retType, args_activity, TA, global_AA,
logic, primalf, retType, args_activity, TA,
#=returnValue=#false, #=dretUsed=#false, #=topLevel=#true,
#=additionalArg=#C_NULL, typeInfo,
uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=#false, #=postOpt=#false))
augmented_primalf = nothing
end

API.EnzymeFreeGlobalAA(global_AA)
return adjointf, augmented_primalf
end

Expand Down
15 changes: 15 additions & 0 deletions src/logic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import LLVM: refcheck

LLVM.@checked struct Logic
ref::API.EnzymeLogicRef
function Logic()
ref = API.CreateLogic()
new(ref)
end
end
Base.unsafe_convert(::Type{API.EnzymeLogicRef}, logic::Logic) = logic.ref
LLVM.dispose(logic::Logic) = API.FreeLogic(logic)

# typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/,
# CTypeTree * /*args*/, size_t /*numArgs*/,
# LLVMValueRef)=T

0 comments on commit 8fef794

Please sign in to comment.