From 82352c278404606fd161a95f83ef8f8bc3ee6651 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Nov 2024 01:34:13 -0500 Subject: [PATCH 1/6] Fewer anonymous funcs --- src/compiler.jl | 55 ++++++++++++++++++++++------------------- src/rules/allocrules.jl | 7 +----- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ad9367fca3..b3faa47470 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -8012,9 +8012,6 @@ end ::Type{TapeType}, args::Vararg{Any,N}, ) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} - - JuliaContext() do ctx - Base.@_inline_meta F = eltype(FA) is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk @@ -8263,6 +8260,10 @@ end i += 1 end + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + (ir, fn, combinedReturn) = try if is_adjoint NT = Tuple{ActiveRetTypes...} @@ -8441,31 +8442,35 @@ end ir = string(mod) fn = LLVM.name(llvm_f) + (ir, fn, combinedReturn) + finally + deactivate(ctx) + dispose(ts_ctx) + end - @assert length(types) == length(ccexprs) + @assert length(types) == length(ccexprs) - if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - return quote - Base.@_inline_meta - Base.llvmcall( - ($ir, $fn), - $combinedReturn, - Tuple{$PT,$(types...)}, - fptr, - $(ccexprs...), - ) - end - else - return quote - Base.@_inline_meta - Base.llvmcall( - ($ir, $fn), - $combinedReturn, - Tuple{$(types...)}, - $(ccexprs...), - ) - end + if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$PT,$(types...)}, + fptr, + $(ccexprs...), + ) + end + else + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$(types...)}, + $(ccexprs...), + ) end end end diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 7c611b6c85..bc732809a3 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -1,7 +1,4 @@ - -function array_inner(::Type{<:Array{T}}) where {T} - return T -end +LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) function array_shadow_handler( B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, @@ -52,8 +49,6 @@ function array_shadow_handler( isunion = typ isa Union - LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) - if !isunboxed elsz = sizeof(Ptr{Cvoid}) al = elsz From 1ef3b5667397b4d49f9de2806ae58b2006943b9f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Nov 2024 01:36:17 -0500 Subject: [PATCH 2/6] minor cleanup --- src/rules/allocrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index bc732809a3..7cd3c59b4c 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -1,4 +1,4 @@ -LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) +@inline LLT_ALIGN(x::Int, sz::Int) = (((x) + (sz) - 1) & ~((sz) - 1)) function array_shadow_handler( B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, From 30401f7c311274ac1b9d7edff6a82ef7d9dcc8ef Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Nov 2024 02:26:26 -0500 Subject: [PATCH 3/6] cleanup --- src/absint.jl | 8 ++++---- src/utils.jl | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index c4866fe4da..6302c9cc2f 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,7 +1,7 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(arg::LLVM.Value, partial::Bool = false) +function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false) if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end @@ -228,7 +228,7 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end -function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int} +function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int} offset = 0 while true if isa(larg, LLVM.ConstantExpr) @@ -277,7 +277,7 @@ function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=fals end function abs_typeof( - arg::LLVM.Value, + @nospecialize(arg::LLVM.Value), partial::Bool = false, seenphis=Set{LLVM.PHIInst}() )::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) @@ -729,7 +729,7 @@ function abs_typeof( return (false, nothing, nothing) end -function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} +function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) diff --git a/src/utils.jl b/src/utils.jl index d5d0ed733a..d1f1fabb72 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,12 +5,40 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. """ -@inline unsafe_to_pointer(val::Type{T}) where {T} = ccall( - Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), +@inline unsafe_to_pointer(val::Type{T}) where {T} = @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall(( +""" +declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) + +define i64 @f({} addrspace(10)* %obj) readnone alwaysinline { + %c = addrspacecast {} addrspace(10)* %obj to {} addrspace(11)* + %r = call {}* @julia.pointer_from_objref({} addrspace(11)* %c) + %e = ptrtoint {}* %r to i64 + ret i64 %e +} +""", "f"), + Ptr{Cvoid}, + Tuple{Any}, + val, +) +else + Base.llvmcall(( +""" +declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) + +define i32 @f({} addrspace(10)* %obj) readnone alwaysinline { + %c = addrspacecast {} addrspace(10)* %obj to {} addrspace(11)* + %r = call {}* @julia.pointer_from_objref({} addrspace(11)* %c) + %e = ptrtoint {}* %r to i32 + ret i32 %e +} +""", "f"), Ptr{Cvoid}, - (Any,), + Tuple{Any}, val, ) +end + export unsafe_to_pointer @inline is_concrete_tuple(x::Type{T2}) where {T2} = From bdb69baf43eabb4756d1cdc4b172566db4e35bb9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Nov 2024 02:44:41 -0500 Subject: [PATCH 4/6] fix --- src/compiler.jl | 12 ++++++++++-- src/compiler/validation.jl | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index b3faa47470..6ef999862a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1125,7 +1125,11 @@ struct Return2 end function force_recompute!(mod::LLVM.Module) - for f in functions(mod), bb in blocks(f), inst in collect(instructions(bb)) + for f in functions(mod), bb in blocks(f) + iter = LLVM.API.LLVMGetFirstInstruction(bb) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.LoadInst) has_loaded = false for u in LLVM.uses(inst) @@ -1170,6 +1174,7 @@ function force_recompute!(mod::LLVM.Module) end end end + end end function permit_inlining!(f::LLVM.Function) @@ -9076,7 +9081,10 @@ include("compiler/reflection.jl") ) copysetfn = meta.entry blk = first(blocks(copysetfn)) - for inst in collect(instructions(blk)) + iter = LLVM.API.LLVMGetFirstInstruction(blk) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.FenceInst) eraseInst(blk, inst) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 2a5c860f64..d5bad35607 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -421,7 +421,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) - for bb in blocks(f), inst in collect(instructions(bb)) + for bb in blocks(f) + iter = LLVM.API.LLVMGetFirstInstruction(bb) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.CallInst) push!(calls, inst) # remove illegal invariant.load and jtbaa_const invariants @@ -489,7 +493,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp newf, _ = get_function!(mod, fname, FT) else found = nothing - for lbb in blocks(initfn), linst in collect(instructions(lbb)) + for lbb in blocks(initfn) + liter = LLVM.API.LLVMGetFirstInstruction(lbb) + while liter != C_NULL + linst = LLVM.Instruction(liter) + liter = LLVM.API.LLVMGetNextInstruction(liter) if !isa(linst, LLVM.CallInst) continue end @@ -502,6 +510,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp break end end + end if found == nothing msg = sprint() do io::IO println( @@ -630,6 +639,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp end end end + end while length(calls) > 0 inst = pop!(calls) From 083757b5b9470dca87ac0592dcf26ac4b928212d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Nov 2024 03:18:32 -0500 Subject: [PATCH 5/6] Aggressively noinfer --- src/absint.jl | 16 +++++----- src/compiler.jl | 60 ++++++++++++++++++------------------- src/compiler/interpreter.jl | 48 ++++++++++++++--------------- src/compiler/utils.jl | 6 ++-- src/compiler/validation.jl | 6 ++-- src/jlrt.jl | 60 ++++++++++++++++++------------------- src/utils.jl | 24 +++++++-------- 7 files changed, 110 insertions(+), 110 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 6302c9cc2f..6fbd0bcd7b 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,7 +1,7 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false) +Base.@nospecializeinfer function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end @@ -165,7 +165,7 @@ function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false) return (false, nothing) end -function actual_size(@nospecialize(typ2)) +Base.@nospecializeinfer function actual_size(@nospecialize(typ2))::Int @static if VERSION < v"1.11-" if typ2 <: Array return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t) @@ -184,10 +184,10 @@ function actual_size(@nospecialize(typ2)) end end -@inline function first_non_ghost(@nospecialize(typ2)) +Base.@nospecializeinfer @inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int} @static if VERSION < v"1.11-" if typ2 <: Array - return (1, typed_fieldtype(typ2, 1)) + return (1, 0) end end fc = fieldcount(typ2) @@ -204,7 +204,7 @@ end return (-1, 0) end -function should_recurse(@nospecialize(typ2), arg_t, byref, dl) +Base.@nospecializeinfer function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool sz = sizeof(dl, arg_t) if byref != GPUCompiler.BITS_VALUE if sz != sizeof(Int) @@ -228,7 +228,7 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end -function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int} +Base.@nospecializeinfer function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int} offset = 0 while true if isa(larg, LLVM.ConstantExpr) @@ -276,7 +276,7 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed=true return larg, offset end -function abs_typeof( +Base.@nospecializeinfer function abs_typeof( @nospecialize(arg::LLVM.Value), partial::Bool = false, seenphis=Set{LLVM.PHIInst}() )::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} @@ -729,7 +729,7 @@ function abs_typeof( return (false, nothing, nothing) end -function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} +Base.@nospecializeinfer function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) diff --git a/src/compiler.jl b/src/compiler.jl index 6ef999862a..d9823e5dd1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -106,7 +106,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), typeof(Base.fma_emulated) => (:fma, 3, nothing), ) -@inline function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector) +@inline Base.@nospecializeinfer function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector) if func ∈ keys(known_ops) name, arity, toinject = known_ops[func] Tys = (Float32, Float64) @@ -1207,7 +1207,7 @@ end include("make_zero.jl") -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) +Base.@nospecializeinfer function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) end @@ -1411,7 +1411,7 @@ end parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? LLVM.parent(val) : val parent_scope(val::LLVM.Module, depth = 0) = val -parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) +Base.@nospecializeinfer parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) parent_scope(val::LLVM.Argument, depth = 0) = parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) @@ -2209,7 +2209,7 @@ current_task_offset() = current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) -function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value)) +Base.@nospecializeinfer function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value)) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2253,7 +2253,7 @@ function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Valu return end -function get_julia_inner_types(B::LLVM.IRBuilder, @nospecialize(p::Union{Nothing, LLVM.Value}), @nospecialize(startvals::Vararg{LLVM.Value}); added = LLVM.API.LLVMValueRef[]) +Base.@nospecializeinfer function get_julia_inner_types(B::LLVM.IRBuilder, @nospecialize(p::Union{Nothing, LLVM.Value}), @nospecialize(startvals::Vararg{LLVM.Value}); added = LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2498,7 +2498,7 @@ function zero_allocation(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMType return nothing end -function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value)) +Base.@nospecializeinfer function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value)) T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -2573,7 +2573,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D end -function zero_allocation( +Base.@nospecializeinfer function zero_allocation( B::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @@ -2662,7 +2662,7 @@ function zero_allocation( ).ref end -function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(Count::LLVM.Value), @nospecialize(AlignedSize::LLVM.Value), IsDefault::UInt8, ZI::Ptr{LLVM.API.LLVMValueRef}) +Base.@nospecializeinfer function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(Count::LLVM.Value), @nospecialize(AlignedSize::LLVM.Value), IsDefault::UInt8, ZI::Ptr{LLVM.API.LLVMValueRef}) func = LLVM.parent(position(B)) mod = LLVM.parent(func) @@ -2787,7 +2787,7 @@ function julia_deallocator(B::LLVM.API.LLVMBuilderRef, Obj::LLVM.API.LLVMValueRe julia_deallocator(B, Obj) end -function julia_deallocator(B::LLVM.IRBuilder, @nospecialize(Obj::LLVM.Value)) +Base.@nospecializeinfer function julia_deallocator(B::LLVM.IRBuilder, @nospecialize(Obj::LLVM.Value)) mod = LLVM.parent(LLVM.parent(position(B))) T_void = LLVM.VoidType() @@ -3087,7 +3087,7 @@ import .Interpreter: isKWCallSignature """ Create the methodinstance pair, and lookup the primal return type. """ -@inline function fspec( +@inline Base.@nospecializeinfer function fspec( @nospecialize(F::Type), @nospecialize(TT::Type), world::Union{UInt,Nothing} = nothing, @@ -3179,7 +3179,7 @@ primal_return_type_world( @nospecialize(TT::Type), ) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...}) -function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) +Base.@nospecializeinfer function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) @nospecialize @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) @assert mode <: Mode @@ -3280,7 +3280,7 @@ end # Enzyme compiler step ## -function annotate!(mod, mode) +function annotate!(mod::LLVM.Module) inactive = LLVM.StringAttribute("enzyme_inactive", "") active = LLVM.StringAttribute("enzyme_active", "") no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") @@ -3896,7 +3896,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.Instruction, error::Bool = true) +function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -3968,7 +3968,7 @@ include("rules/activityrules.jl") const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) -function enzyme!( +Base.@nospecializeinfer function enzyme!( job::CompilerJob, mod::LLVM.Module, primalf::LLVM.Function, @@ -4315,7 +4315,7 @@ function set_subprogram!(f::LLVM.Function, sp) end end -function create_abi_wrapper( +Base.@nospecializeinfer function create_abi_wrapper( enzymefn::LLVM.Function, @nospecialize(TT::Type), @nospecialize(rettype::Type), @@ -4802,7 +4802,7 @@ function create_abi_wrapper( metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end - @inline function fixup_abi(index::Int, @nospecialize(value::LLVM.Value)) + @inline Base.@nospecializeinfer function fixup_abi(index::Int, @nospecialize(value::LLVM.Value)) valty = sret_types[index] # Union becoming part of a tuple needs to be adjusted # See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121 @@ -5142,7 +5142,7 @@ end struct RemovedParam end # Modified from GPUCompiler classify_arguments -function classify_arguments( +Base.@nospecializeinfer function classify_arguments( @nospecialize(source_sig::Type), codegen_ft::LLVM.FunctionType, has_sret::Bool, @@ -5245,7 +5245,7 @@ function classify_arguments( return args end -function isSpecialPtr(@nospecialize(Ty::LLVM.LLVMType)) +Base.@nospecializeinfer function isSpecialPtr(@nospecialize(Ty::LLVM.LLVMType)) if !isa(Ty, LLVM.PointerType) return false end @@ -5259,7 +5259,7 @@ mutable struct CountTrackedPointers derived::Bool end -function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) +Base.@nospecializeinfer function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) res = CountTrackedPointers(0, true, false) if isa(T, LLVM.PointerType) @@ -5296,7 +5296,7 @@ function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) end # must deserve sret -function deserves_rooting(@nospecialize(T::LLVM.LLVMType)) +Base.@nospecializeinfer function deserves_rooting(@nospecialize(T::LLVM.LLVMType)) tracked = CountTrackedPointers(T) @assert !tracked.derived if tracked.count != 0 && !tracked.all @@ -5307,7 +5307,7 @@ end # https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 # returns if all unboxed -function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), counter::Base.RefValue{Int} = Ref(0)) +Base.@nospecializeinfer function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), counter::Base.RefValue{Int} = Ref(0)) if counter[] > 127 return false end @@ -5326,7 +5326,7 @@ function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), cou end # From https://github.com/JuliaLang/julia/blob/038d31463f0ef744c8308bdbe87339b9c3f0b890/src/cgutils.cpp#L3108 -function union_alloca_type(@nospecialize(UT::Type)) +Base.@nospecializeinfer function union_alloca_type(@nospecialize(UT::Type)) nbytes = 0 function inner(@nospecialize(jlrettype::Type)) if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) @@ -5338,7 +5338,7 @@ function union_alloca_type(@nospecialize(UT::Type)) end # From https://github.com/JuliaLang/julia/blob/e6bf81f39a202eedc7bd4f310c1ab60b5b86c251/src/codegen.cpp#L6447 -function is_sret(@nospecialize(jlrettype::Type)) +Base.@nospecializeinfer function is_sret(@nospecialize(jlrettype::Type)) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false @@ -5361,7 +5361,7 @@ function is_sret(@nospecialize(jlrettype::Type)) end return false end -function is_sret_union(@nospecialize(jlrettype::Type)) +Base.@nospecializeinfer function is_sret_union(@nospecialize(jlrettype::Type)) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false @@ -5380,7 +5380,7 @@ function is_sret_union(@nospecialize(jlrettype::Type)) end # https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 -function get_return_info( +Base.@nospecializeinfer function get_return_info( @nospecialize(jlrettype::Type), )::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} sret = nothing @@ -5431,7 +5431,7 @@ function get_return_info( end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention( +Base.@nospecializeinfer function lower_convention( @nospecialize(functy::Type), mod::LLVM.Module, entry_f::LLVM.Function, @@ -6149,7 +6149,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes); world = nothing) +Base.@nospecializeinfer function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -7042,7 +7042,7 @@ end end # annotate - annotate!(mod, mode) + annotate!(mod) for name in ("gpu_report_exception", "report_exception") if haskey(functions(mod), name) exc = functions(mod)[name] @@ -8484,7 +8484,7 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) +Base.@nospecializeinfer function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), @@ -8588,7 +8588,7 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase( +Base.@nospecializeinfer @inline function thunkbase( mi::Core.MethodInstance, World::Union{UInt, Nothing}, @nospecialize(FA::Type{<:Annotation}), diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 7648236ffb..4df93528e2 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -96,34 +96,34 @@ EnzymeInterpreter( handler = nothing ) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) -Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params -Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params -get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world -Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache +Base.@nospecializeinfer Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params +Base.@nospecializeinfer Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params +Base.@nospecializeinfer get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world +Base.@nospecializeinfer Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache @static if HAS_INTEGRATED_CACHE - Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token + Base.@nospecializeinfer Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token else - Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) = + Base.@nospecializeinfer Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) = WorldView(interp.code_cache, interp.world) end # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing +Base.@nospecializeinfer Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing +Base.@nospecializeinfer Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing -Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true -Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true +Base.@nospecializeinfer Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true +Base.@nospecializeinfer Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? -Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false +Base.@nospecializeinfer Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false +Base.@nospecializeinfer Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = +Base.@nospecializeinfer Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -function is_alwaysinline_func(@nospecialize(TT)) +Base.@nospecializeinfer function is_alwaysinline_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false @static if VERSION ≥ v"1.11-" if TT.parameters[1] == typeof(Core.memoryref) @@ -133,7 +133,7 @@ function is_alwaysinline_func(@nospecialize(TT)) return false end -function is_primitive_func(@nospecialize(TT)) +Base.@nospecializeinfer function is_primitive_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false ft = TT.parameters[1] if ft == typeof(Enzyme.pmap) @@ -156,11 +156,11 @@ function is_primitive_func(@nospecialize(TT)) return false end -function isKWCallSignature(@nospecialize(TT)) +Base.@nospecializeinfer function isKWCallSignature(@nospecialize(TT))::Bool return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg} end -function simplify_kw(@nospecialize specTypes) +function simplify_kw(@nospecialize(specTypes)) if isKWCallSignature(specTypes) return Base.tuple_type_tail(Base.tuple_type_tail(specTypes)) else @@ -193,7 +193,7 @@ Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) using Core.Compiler: ArgInfo, StmtInfo, AbsIntState -function Core.Compiler.abstract_call_gf_by_type( +Base.@nospecializeinfer function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @nospecialize(f), arginfo::ArgInfo, @@ -279,7 +279,7 @@ let # overload `inlining_policy` ) end @static if isdefined(Core.Compiler, :inlining_policy) - @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) + @eval Base.@nospecializeinfer function Core.Compiler.inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo if info.kind === :primitive @safe_debug "Blocking inlining for primitive func" info.tt @@ -299,7 +299,7 @@ let # overload `inlining_policy` return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end else - @eval function Core.Compiler.src_inlining_policy($(sigs_ex.args...)) + @eval Base.@nospecializeinfer function Core.Compiler.src_inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo if info.kind === :primitive @safe_debug "Blocking inlining for primitive func" info.tt @@ -742,15 +742,15 @@ end end end -@inline function array_or_number(@nospecialize(Ty)) +@inline function array_or_number(@nospecialize(Ty))::Bool return Ty <: AbstractArray || Ty <: Number end -@inline function isa_array_or_number(@nospecialize(x)) +@inline function isa_array_or_number(@nospecialize(x))::Bool return x isa AbstractArray || x isa Number end -@inline function num_or_eltype(@nospecialize(Ty)) +@inline function num_or_eltype(@nospecialize(Ty))::Type if Ty <: AbstractArray eltype(Ty) else @@ -758,7 +758,7 @@ end end end -function abstract_call_known( +Base.@nospecializeinfer function abstract_call_known( @nospecialize(interp::EnzymeInterpreter), @nospecialize(f), arginfo::ArgInfo, diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e58e574dd2..73c68ea172 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -361,7 +361,7 @@ function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing end end -function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) +Base.@nospecializeinfer function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) @static if isdefined(LLVM, Symbol("erase!")) LLVM.erase!(inst) else @@ -407,7 +407,7 @@ end NamedTuple{ntuple(i -> Symbol(i), Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... -function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint}) +Base.@nospecializeinfer function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint}) ty = LLVM.value_type(val) for i in idxs if isa(ty, LLVM.ArrayType) @@ -422,7 +422,7 @@ end # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} # and that Bool -> i8, not i1 -function calling_conv_fixup( +Base.@nospecializeinfer function calling_conv_fixup( builder::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(tape::LLVM.LLVMType), diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index d5bad35607..3e833324b6 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -20,7 +20,7 @@ function get_blas_symbols() return symbols end -function lookup_blas_symbol(name) +function lookup_blas_symbol(name::String) Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) end end @@ -127,7 +127,7 @@ function __init__() end end -function memoize!(ptr, fn) +function memoize!(ptr::Ptr{Cvoid}, fn::String)::String fn = get(ptr_map, ptr, fn) if !haskey(ptr_map, ptr) ptr_map[ptr] = fn @@ -140,7 +140,7 @@ end import GPUCompiler: IRError, InvalidIRError -function restore_lookups(mod::LLVM.Module) +function restore_lookups(mod::LLVM.Module)::Nothing T_size_t = convert(LLVM.LLVMType, Int) for (v, k) in FFI.ptr_map if haskey(functions(mod), k) diff --git a/src/jlrt.jl b/src/jlrt.jl index 300fdc5515..acda0bbd90 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -1,6 +1,6 @@ # For julia runtime function emission -function emit_allocobj!( +Base.@nospecializeinfer function emit_allocobj!( B::LLVM.IRBuilder, @nospecialize(tag::LLVM.Value), @nospecialize(Size::LLVM.Value), @@ -58,7 +58,7 @@ function emit_allocobj!( return call!(B, alty, alloc_obj, LLVM.Value[ct, Size, tag], name) end -function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") +Base.@nospecializeinfer function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -83,7 +83,7 @@ declare_pointerfromobjref!(mod::LLVM.Module) = LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) end -function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value)) +Base.@nospecializeinfer function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value)) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -113,7 +113,7 @@ declare_juliacall!(mod::LLVM.Module) = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) end -function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -124,7 +124,7 @@ function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value call!(B, FT, fn, LLVM.Value[val]) end -function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -136,11 +136,11 @@ function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospec call!(B, FT, fn, LLVM.Value[val, ty]) end -function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value +Base.@nospecializeinfer function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value emit_jl_isa!(B, val, unsafe_to_llvm(B, ty)) end -function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -170,7 +170,7 @@ function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nosp end -function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -186,11 +186,11 @@ function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nosp call!(B, gen_FT, inv, args) end -function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Int)::LLVM.Value +Base.@nospecializeinfer function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Int)::LLVM.Value emit_nthfield!(B, val, LLVM.ConstantInt(fld)) end -function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -202,7 +202,7 @@ function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM call!(B, FT, fn, LLVM.Value[val]) end -function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -216,7 +216,7 @@ function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLV call!(B, FT, box_int32, LLVM.Value[val]) end -function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -326,7 +326,7 @@ function load_if_mixed(oval::OT, val::VT) where {OT, VT} end end -function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) +Base.@nospecializeinfer function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(oval) if !legal @@ -374,7 +374,7 @@ function ref_if_mixed(val::VT) where {VT} end end -function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value)) +Base.@nospecializeinfer function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(val) if !legal @@ -401,7 +401,7 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu end end -function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vector{LLVM.Value})::LLVM.Value +Base.@nospecializeinfer function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -510,7 +510,7 @@ function emit_tuple!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value return tag end -function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -527,7 +527,7 @@ function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM call!(B, FT, fn, [arg]) end -function emit_methodinstance!(B::LLVM.IRBuilder, @nospecialize(func), args::Vector{LLVM.Value})::LLVM.Value +Base.@nospecializeinfer function emit_methodinstance!(B::LLVM.IRBuilder, @nospecialize(func), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -714,7 +714,7 @@ function get_memory_struct() return LLVM.StructType([sizeT, ptrty]; packed = true) end -function get_memory_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) mty = get_memory_struct() array = LLVM.pointercast!( B, @@ -795,7 +795,7 @@ function get_datatype_struct() return LLVM.StructType([jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, i32, i16]; packed = true) end -function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) i8 = LLVM.IntType(8) ptrty = LLVM.PointerType(i8, 13) array = LLVM.pointercast!( @@ -806,7 +806,7 @@ function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, ptrty, array) end -function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() elsz = LLVM.IntType(16) array = LLVM.pointercast!( @@ -823,7 +823,7 @@ function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, elsz, v) end -function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) +Base.@nospecializeinfer function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) ls = get_layout_struct() lptr = LLVM.PointerType(ls, 10) @@ -843,7 +843,7 @@ function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) return layout end -function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) +Base.@nospecializeinfer function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) if legal res = unsafe_load(reinterpret(Ptr{UInt32}, JTy.layout)) @@ -857,12 +857,12 @@ function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) return load!(B, i32, lty) end -function get_memory_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ty = emit_jltypeof!(B, array) return emit_memorytype_elsz!(B, ty) end -function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -903,7 +903,7 @@ function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, sizeT, v) end -function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -940,7 +940,7 @@ function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, sizeT, v) end -function get_array_nrows(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_nrows(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() array = LLVM.pointercast!( B, @@ -971,7 +971,7 @@ function emit_gc_preserve_begin(B::LLVM.IRBuilder, args::Vector{LLVM.Value} = LL return token end -function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value)) +Base.@nospecializeinfer function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value)) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -986,20 +986,20 @@ function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value return end -function allocate_sret!(B::LLVM.IRBuilder, @nospecialize(N::LLVM.LLVMType)) +Base.@nospecializeinfer function allocate_sret!(B::LLVM.IRBuilder, @nospecialize(N::LLVM.LLVMType)) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) al = LLVM.alloca!(B, LLVM.ArrayType(T_prjlvalue, N)) return al end -function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLVM.LLVMType)) +Base.@nospecializeinfer function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLVM.LLVMType)) B = LLVM.IRBuilder() position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) allocate_sret!(B, N) end -function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::String, @nospecialize(errty::Type) = EnzymeRuntimeException) +Base.@nospecializeinfer function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::String, @nospecialize(errty::Type) = EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) diff --git a/src/utils.jl b/src/utils.jl index d1f1fabb72..678945d5b7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. """ -@inline unsafe_to_pointer(val::Type{T}) where {T} = @static if sizeof(Int) == sizeof(Int64) +@inline Base.@nospecializeinfer unsafe_to_pointer(@nospecialize(val::Type)) = @static if sizeof(Int) == sizeof(Int64) Base.llvmcall(( """ declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) @@ -65,7 +65,7 @@ function unsafe_nothing_to_llvm(mod::LLVM.Module) return gv end -function unsafe_to_ptr(@nospecialize(val)) +Base.@nospecializeinfer function unsafe_to_ptr(@nospecialize(val)) if !Base.ismutable(val) val = Core.Box(val) # FIXME many objects could be leaked here @assert Base.ismutable(val) @@ -81,7 +81,7 @@ end export unsafe_to_ptr # This mimicks literal_pointer_val / literal_pointer_val_slot -function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) +Base.@nospecializeinfer function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val))::LLVM.Value T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -141,7 +141,7 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) end export unsafe_to_llvm, unsafe_nothing_to_llvm -function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T)) +Base.@nospecializeinfer function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T::Type)) if !Core.Compiler.isconstType(T) throw(AssertionError("Tried to make instance of non constant type $T")) end @@ -151,7 +151,7 @@ end export makeInstanceOf -function hasfieldcount(@nospecialize(dt)) +Base.@nospecializeinfer function hasfieldcount(@nospecialize(dt))::Bool try fieldcount(dt) catch @@ -268,7 +268,7 @@ export my_methodinstance # # // followed by alignment padding and inline data, or owner pointer # } jl_array_t; -@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)::Type if T <: Array eT = eltype(T) PT = Ptr{eT} @@ -278,7 +278,7 @@ export my_methodinstance end end -@inline function typed_fieldcount(@nospecialize(T::Type)) +@inline function typed_fieldcount(@nospecialize(T::Type))::Int if T <: Array return 7 else @@ -286,7 +286,7 @@ end end end -@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int) +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int)::Int if T <: Array tys = (Ptr, Csize_t, UInt16, UInt16, UInt32, Csize_t, Csize_t) sum = 0 @@ -303,7 +303,7 @@ end else -@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)::Type if T <: GenericMemoryRef && i == 1 || T <: GenericMemory && i == 2 eT = eltype(T) Ptr{eT} @@ -312,11 +312,11 @@ else end end -@inline function typed_fieldcount(@nospecialize(T::Type)) +@inline function typed_fieldcount(@nospecialize(T::Type))::Int fieldcount(T) end -@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int) +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int)::Int fieldoffset(T, i) end @@ -327,7 +327,7 @@ export typed_fieldcount export typed_fieldoffset # returns the inner type of an sret/enzyme_sret/enzyme_sret_v -function sret_ty(fn::LLVM.Function, idx::Int) +function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType return eltype(LLVM.value_type(LLVM.parameters(fn)[idx])) end From d32a06f1ed818d9192b6ea88ff67a9e518423045 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 11:31:33 -0600 Subject: [PATCH 6/6] Update utils.jl --- src/compiler/utils.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 5fdc077b7b..9bac2cfaaf 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -360,6 +360,46 @@ function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing end end +@inline enum_attr_kind(kind::String) = LLVM.API.LLVMGetEnumAttributeKindForName(kind, Csize_t(length(kind))) + +const swiftself_kind = enum_attr_kind("swiftself") + +Base.@assume_effects :removable :foldable :nothrow function has_swiftself(fn::LLVM.Function)::Bool + for i in 1:length(LLVM.parameters(fn)) + for attr in collect(LLVM.parameter_attributes(fn, i)) + if attr isa LLVM.EnumAttribute + if kind(attr) == swiftself_kind + return true + end + end + end + end + return false +end +Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM.Function, attr::LLVM.EnumAttribute)::Bool + ekind = LLVM.kind(attr) + for attr in collect(function_attributes(fn)) + if attr isa LLVM.EnumAttribute + if kind(attr) == ekind + return true + end + end + end + return false +end + +Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM.Function, attr::LLVM.StringAttribute)::Bool + ekind = LLVM.kind(attr) + for attr in collect(function_attributes(fn)) + if attr isa LLVM.StringAttribute + if kind(attr) == ekind + return true + end + end + end + return false +end + Base.@nospecializeinfer function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) @static if isdefined(LLVM, Symbol("erase!")) LLVM.erase!(inst)