diff --git a/src/absint.jl b/src/absint.jl index 50282e745c..b8ccf86050 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,7 +1,7 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any} +Base.@nospecializeinfer function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end @@ -165,7 +165,7 @@ function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bo return (false, nothing) end -function actual_size(@nospecialize(typ2))::Int +Base.@nospecializeinfer function actual_size(@nospecialize(typ2))::Int @static if VERSION < v"1.11-" if typ2 <: Array return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t) @@ -184,7 +184,7 @@ function actual_size(@nospecialize(typ2))::Int end end -@inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int} +Base.@nospecializeinfer @inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int} @static if VERSION < v"1.11-" if typ2 <: Array return (1, 0) @@ -204,7 +204,7 @@ end return (-1, 0) end -function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool +Base.@nospecializeinfer function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool sz = if arg_t == LLVM.IntType(1) 1 else @@ -232,7 +232,7 @@ function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType) end end -function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int} +Base.@nospecializeinfer function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int} offset = 0 while true if isa(larg, LLVM.ConstantExpr) @@ -280,7 +280,7 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo return larg, offset end -function abs_typeof( +Base.@nospecializeinfer function abs_typeof( @nospecialize(arg::LLVM.Value), partial::Bool = false, seenphis=Set{LLVM.PHIInst}() )::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} @@ -758,7 +758,7 @@ end return false end -function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} +Base.@nospecializeinfer function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) diff --git a/src/compiler.jl b/src/compiler.jl index ddccde1a24..ee9c0230ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -106,7 +106,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), typeof(Base.fma_emulated) => (:fma, 3, nothing), ) -@inline function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector) +@inline Base.@nospecializeinfer function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector) if func ∈ keys(known_ops) name, arity, toinject = known_ops[func] Tys = (Float32, Float64) @@ -317,7 +317,8 @@ include("llvm/transforms.jl") include("llvm/passes.jl") include("typeutils/make_zero.jl") -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) + +Base.@nospecializeinfer function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) end @@ -1345,7 +1346,7 @@ include("rules/activityrules.jl") const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) -function enzyme!( +Base.@nospecializeinfer function enzyme!( job::CompilerJob, mod::LLVM.Module, primalf::LLVM.Function, @@ -1685,7 +1686,7 @@ function set_subprogram!(f::LLVM.Function, sp) end end -function create_abi_wrapper( +Base.@nospecializeinfer function create_abi_wrapper( enzymefn::LLVM.Function, @nospecialize(TT::Type), @nospecialize(rettype::Type), @@ -2167,7 +2168,7 @@ function create_abi_wrapper( metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end - @inline function fixup_abi(index::Int, @nospecialize(value::LLVM.Value)) + @inline Base.@nospecializeinfer function fixup_abi(index::Int, @nospecialize(value::LLVM.Value)) valty = sret_types[index] # Union becoming part of a tuple needs to be adjusted # See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121 @@ -2505,7 +2506,7 @@ function fixup_metadata!(f::LLVM.Function) end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention( +Base.@nospecializeinfer function lower_convention( @nospecialize(functy::Type), mod::LLVM.Module, entry_f::LLVM.Function, @@ -3206,7 +3207,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) +Base.@nospecializeinfer function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -5226,7 +5227,7 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) +Base.@nospecializeinfer function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), @@ -5337,7 +5338,7 @@ const cache_lock = ReentrantLock() end end -@inline function thunkbase( +Base.@nospecializeinfer @inline function thunkbase( mi::Core.MethodInstance, World::Union{UInt, Nothing}, @nospecialize(FA::Type{<:Annotation}), diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index bd57ec92dd..ff18f54aec 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -228,35 +228,36 @@ EnzymeInterpreter( handler = nothing ) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler) -Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params -Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params -get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world -Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache + +Base.@nospecializeinfer Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params +Base.@nospecializeinfer Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params +Base.@nospecializeinfer get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world +Base.@nospecializeinfer Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache @static if HAS_INTEGRATED_CACHE - Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token + Base.@nospecializeinfer Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token else - Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) = + Base.@nospecializeinfer Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) = WorldView(interp.code_cache, interp.world) end # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing +Base.@nospecializeinfer Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing +Base.@nospecializeinfer Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing -Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true -Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true +Base.@nospecializeinfer Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true +Base.@nospecializeinfer Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? -Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false +Base.@nospecializeinfer Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false +Base.@nospecializeinfer Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = +Base.@nospecializeinfer Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -function is_alwaysinline_func(@nospecialize(TT))::Bool +Base.@nospecializeinfer function is_alwaysinline_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false @static if VERSION ≥ v"1.11-" if TT.parameters[1] == typeof(Core.memoryref) @@ -266,7 +267,7 @@ function is_alwaysinline_func(@nospecialize(TT))::Bool return false end -function is_primitive_func(@nospecialize(TT))::Bool +Base.@nospecializeinfer function is_primitive_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false ft = TT.parameters[1] if ft == typeof(Enzyme.pmap) @@ -289,7 +290,7 @@ function is_primitive_func(@nospecialize(TT))::Bool return false end -function isKWCallSignature(@nospecialize(TT))::Bool +Base.@nospecializeinfer function isKWCallSignature(@nospecialize(TT))::Bool return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg} end @@ -329,7 +330,7 @@ Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState -function Core.Compiler.abstract_call_gf_by_type( +Base.@nospecializeinfer function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @nospecialize(f), arginfo::ArgInfo, @@ -424,7 +425,7 @@ let # overload `inlining_policy` ) end @static if isdefined(Core.Compiler, :inlining_policy) - @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) + @eval Base.@nospecializeinfer function Core.Compiler.inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo if info.kind === :primitive @safe_debug "Blocking inlining for primitive func" info.tt @@ -444,7 +445,7 @@ let # overload `inlining_policy` return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end else - @eval function Core.Compiler.src_inlining_policy($(sigs_ex.args...)) + @eval Base.@nospecializeinfer function Core.Compiler.src_inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo if info.kind === :primitive @safe_debug "Blocking inlining for primitive func" info.tt @@ -903,7 +904,7 @@ end end end -function abstract_call_known( +Base.@nospecializeinfer function abstract_call_known( interp::EnzymeInterpreter{Handler}, @nospecialize(f), arginfo::ArgInfo, diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 7f004ea379..9bac2cfaaf 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -400,7 +400,7 @@ Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM return false end -function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) +Base.@nospecializeinfer function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) @static if isdefined(LLVM, Symbol("erase!")) LLVM.erase!(inst) else @@ -446,7 +446,7 @@ end NamedTuple{ntuple(Symbol, Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... -Base.@assume_effects :removable :foldable :nothrow function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint})::LLVM.LLVMType +Base.@nospecializeinfer Base.@assume_effects :removable :foldable :nothrow function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint})::LLVM.LLVMType ty = LLVM.value_type(val)::LLVM.LLVMType for i in idxs if isa(ty, LLVM.ArrayType) @@ -461,7 +461,7 @@ end # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} # and that Bool -> i8, not i1 -function calling_conv_fixup( +Base.@nospecializeinfer function calling_conv_fixup( builder::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(tape::LLVM.LLVMType), diff --git a/src/jlrt.jl b/src/jlrt.jl index 59acd1d231..acda0bbd90 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -1,6 +1,6 @@ # For julia runtime function emission -function emit_allocobj!( +Base.@nospecializeinfer function emit_allocobj!( B::LLVM.IRBuilder, @nospecialize(tag::LLVM.Value), @nospecialize(Size::LLVM.Value), @@ -58,7 +58,7 @@ function emit_allocobj!( return call!(B, alty, alloc_obj, LLVM.Value[ct, Size, tag], name) end -function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") +Base.@nospecializeinfer function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -83,7 +83,7 @@ declare_pointerfromobjref!(mod::LLVM.Module) = LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) end -function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value)) +Base.@nospecializeinfer function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value)) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -113,7 +113,7 @@ declare_juliacall!(mod::LLVM.Module) = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) end -function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -124,7 +124,7 @@ function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value call!(B, FT, fn, LLVM.Value[val]) end -function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -136,11 +136,11 @@ function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospec call!(B, FT, fn, LLVM.Value[val, ty]) end -function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value +Base.@nospecializeinfer function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value emit_jl_isa!(B, val, unsafe_to_llvm(B, ty)) end -function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -170,7 +170,7 @@ function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nosp end -function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -186,11 +186,11 @@ function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nosp call!(B, gen_FT, inv, args) end -function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Int)::LLVM.Value +Base.@nospecializeinfer function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Int)::LLVM.Value emit_nthfield!(B, val, LLVM.ConstantInt(fld)) end -function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -202,7 +202,7 @@ function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM call!(B, FT, fn, LLVM.Value[val]) end -function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -216,7 +216,7 @@ function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLV call!(B, FT, box_int32, LLVM.Value[val]) end -function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -326,7 +326,7 @@ function load_if_mixed(oval::OT, val::VT) where {OT, VT} end end -function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(oval) if !legal @@ -374,7 +374,7 @@ function ref_if_mixed(val::VT) where {VT} end end -function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(val) if !legal @@ -401,7 +401,7 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu end end -function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vector{LLVM.Value})::LLVM.Value +Base.@nospecializeinfer function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -510,7 +510,7 @@ function emit_tuple!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value return tag end -function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM.Value +Base.@nospecializeinfer function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -527,7 +527,7 @@ function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM call!(B, FT, fn, [arg]) end -function emit_methodinstance!(B::LLVM.IRBuilder, @nospecialize(func), args::Vector{LLVM.Value})::LLVM.Value +Base.@nospecializeinfer function emit_methodinstance!(B::LLVM.IRBuilder, @nospecialize(func), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -714,7 +714,7 @@ function get_memory_struct() return LLVM.StructType([sizeT, ptrty]; packed = true) end -function get_memory_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) mty = get_memory_struct() array = LLVM.pointercast!( B, @@ -795,7 +795,7 @@ function get_datatype_struct() return LLVM.StructType([jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, i32, i16]; packed = true) end -function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) i8 = LLVM.IntType(8) ptrty = LLVM.PointerType(i8, 13) array = LLVM.pointercast!( @@ -806,7 +806,7 @@ function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, ptrty, array) end -function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() elsz = LLVM.IntType(16) array = LLVM.pointercast!( @@ -823,7 +823,7 @@ function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, elsz, v) end -function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) +Base.@nospecializeinfer function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) ls = get_layout_struct() lptr = LLVM.PointerType(ls, 10) @@ -843,7 +843,7 @@ function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) return layout end -function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) +Base.@nospecializeinfer function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) if legal res = unsafe_load(reinterpret(Ptr{UInt32}, JTy.layout)) @@ -857,12 +857,12 @@ function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) return load!(B, i32, lty) end -function get_memory_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ty = emit_jltypeof!(B, array) return emit_memorytype_elsz!(B, ty) end -function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -903,7 +903,7 @@ function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, sizeT, v) end -function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -940,7 +940,7 @@ function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) return LLVM.load!(B, sizeT, v) end -function get_array_nrows(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) +Base.@nospecializeinfer function get_array_nrows(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() array = LLVM.pointercast!( B, @@ -971,7 +971,7 @@ function emit_gc_preserve_begin(B::LLVM.IRBuilder, args::Vector{LLVM.Value} = LL return token end -function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value)) +Base.@nospecializeinfer function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value)) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -986,20 +986,20 @@ function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value return end -function allocate_sret!(B::LLVM.IRBuilder, @nospecialize(N::LLVM.LLVMType)) +Base.@nospecializeinfer function allocate_sret!(B::LLVM.IRBuilder, @nospecialize(N::LLVM.LLVMType)) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) al = LLVM.alloca!(B, LLVM.ArrayType(T_prjlvalue, N)) return al end -function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLVM.LLVMType)) +Base.@nospecializeinfer function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLVM.LLVMType)) B = LLVM.IRBuilder() position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) allocate_sret!(B, N) end -function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::String, @nospecialize(errty::Type) = EnzymeRuntimeException) +Base.@nospecializeinfer function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::String, @nospecialize(errty::Type) = EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) diff --git a/src/utils.jl b/src/utils.jl index 0f92fc1f5d..db5bbac42b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. """ -@inline unsafe_to_pointer(@nospecialize(val::Type)) = @static if sizeof(Int) == sizeof(Int64) +@inline Base.@nospecializeinfer unsafe_to_pointer(@nospecialize(val::Type)) = @static if sizeof(Int) == sizeof(Int64) Base.llvmcall(( """ declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) @@ -65,7 +65,7 @@ function unsafe_nothing_to_llvm(mod::LLVM.Module) return gv end -function unsafe_to_ptr(@nospecialize(val)) +Base.@nospecializeinfer function unsafe_to_ptr(@nospecialize(val)) if !Base.ismutable(val) val = Core.Box(val) # FIXME many objects could be leaked here @assert Base.ismutable(val) @@ -81,7 +81,7 @@ end export unsafe_to_ptr # This mimicks literal_pointer_val / literal_pointer_val_slot -function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val))::LLVM.Value +Base.@nospecializeinfer function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val))::LLVM.Value T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -141,7 +141,7 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val))::LLVM.Value end export unsafe_to_llvm, unsafe_nothing_to_llvm -function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T::Type)) +Base.@nospecializeinfer function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T::Type)) if !Core.Compiler.isconstType(T) throw(AssertionError("Tried to make instance of non constant type $T")) end @@ -151,7 +151,7 @@ end export makeInstanceOf -function hasfieldcount(@nospecialize(dt))::Bool +Base.@nospecializeinfer function hasfieldcount(@nospecialize(dt))::Bool try fieldcount(dt) catch