Skip to content

Commit

Permalink
Remove julia level type rules (#2130)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Nov 28, 2024
1 parent 45f01bd commit e9d303b
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 264 deletions.
6 changes: 3 additions & 3 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ function CreateTypeAnalysis(logic, rulenames, rules)
EnzymeTypeAnalysisRef,
(EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t),
logic,
rulenames,
rules,
length(rules),
rulenames isa Tuple{} ? C_NULL : rulenames,
rules isa Tuple{} ? C_NULL : rules,
rulenames isa Tuple{} ? 0 : length(rules),
)
end

Expand Down
102 changes: 1 addition & 101 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3940,7 +3940,6 @@ function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = tru
return ty, byref
end

include("rules/typerules.jl")
include("rules/activityrules.jl")

@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Const} = API.DFT_CONSTANT
Expand Down Expand Up @@ -4073,107 +4072,8 @@ function enzyme!(
convert(API.CDIFFE_TYPE, rt)
end

rules = Dict{String,API.CustomRuleType}(
"jl_array_copy" => @cfunction(
inout_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"ijl_array_copy" => @cfunction(
inout_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"jl_genericmemory_copy_slice" => @cfunction(
inoutcopyslice_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"ijl_genericmemory_copy_slice" => @cfunction(
inoutcopyslice_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"jl_inactive_inout" => @cfunction(
inout_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"jl_excstack_state" => @cfunction(
int_return_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"ijl_excstack_state" => @cfunction(
int_return_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"julia.except_enter" => @cfunction(
int_return_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
)

logic = Logic()
TA = TypeAnalysis(logic, rules)
TA = TypeAnalysis(logic)

retTT = if !isa(actualRetType, Union) &&
actualRetType <: Tuple &&
Expand Down
14 changes: 1 addition & 13 deletions src/rules/allocrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,6 @@ function array_shadow_handler(
return ref
end

function null_free_handler(
B::LLVM.API.LLVMBuilderRef,
ToFree::LLVM.API.LLVMValueRef,
Fn::LLVM.API.LLVMValueRef,
)::LLVM.API.LLVMValueRef
return C_NULL
end

function register_alloc_handler!(variants, alloc_handler, free_handler)
for variant in variants
API.EnzymeRegisterAllocationHandler(variant, alloc_handler, free_handler)
Expand All @@ -120,10 +112,6 @@ end
API.EnzymeGradientUtilsRef,
)
),
@cfunction(
null_free_handler,
LLVM.API.LLVMValueRef,
(LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)
)
C_NULL
)
end
140 changes: 0 additions & 140 deletions src/rules/typerules.jl

This file was deleted.

18 changes: 11 additions & 7 deletions src/typeanalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ LLVM.dispose(ta::TypeAnalysis) = API.FreeTypeAnalysis(ta)

function TypeAnalysis(
logic,
typerules::Dict{String,CustomRuleType} = Dict{String,CustomRuleType}(),
typerules::Union{Dict{String,CustomRuleType}, Nothing} = nothing,
)
rulenames = String[]
rules = CustomRuleType[]
for (rulename, rule) in typerules
push!(rulenames, rulename)
push!(rules, rule)
if typerules isa Nothing
ref = API.CreateTypeAnalysis(logic, (), ())
else
rulenames = String[]
rules = CustomRuleType[]
for (rulename, rule) in typerules
push!(rulenames, rulename)
push!(rules, rule)
end
ref = API.CreateTypeAnalysis(logic, rulenames, rules)
end
ref = API.CreateTypeAnalysis(logic, rulenames, rules)
TypeAnalysis(ref)
end

Expand Down

0 comments on commit e9d303b

Please sign in to comment.