From a72efb1ba163158c199a4811d4839b2640e4dd91 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 22:26:13 -0500 Subject: [PATCH] Concrete type assertion (#1890) * Concrete type assertion * fix * fix * fix --- src/absint.jl | 72 ++++++++++++++++++++++++++++++------------------- src/compiler.jl | 4 ++- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 27f66d2c06..9eec24bcf3 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -167,6 +167,43 @@ function actual_size(@nospecialize(typ2)) end end +@inline function first_non_ghost(@nospecialize(typ2)) + fc = fieldcount(typ2) + for i in 1:fc + if i == fc + return (i, sizeof(typ2)) + else + fo = fieldoffset(typ2, i+1) + if fo != 0 + return (i, fo) + end + end + end + return (-1, 0) +end + +function should_recurse(@nospecialize(typ2), arg_t, byref, dl) + sz = sizeof(dl, arg_t) + if byref != GPUCompiler.BITS_VALUE + @assert sz == sizeof(Int) + return false + else + if actual_size(typ2) != sz + return true + else + if Base.isconcretetype(typ2) + idx, sz2 = first_non_ghost(typ2) + if idx != -1 + if sz2 == sz + return true + end + end + end + return false + end + end +end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -346,7 +383,7 @@ function abs_typeof( if !error legal, typ, byref = abs_typeof(larg) - if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) @@ -359,31 +396,11 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - function should_recurse(typ2, arg_t, byref) - sz = sizeof(dl, arg_t) - if byref != GPUCompiler.BITS_VALUE - @assert sz == sizeof(Int) - return false - else - if actual_size(typ2) != sz - return true - else - if Base.isconcretetype(typ2) - if fieldcount(typ2) > 0 - if actual_size(fieldtype(typ2,1)) == sz - return true - end - end - end - return false - end - end - end byref = GPUCompiler.BITS_VALUE legal = true - while offset !== nothing && legal + while (offset !== nothing && offset != 0) && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 @@ -403,6 +420,7 @@ function abs_typeof( elseif fieldoffset(typ, i) > offset offset = offset - fieldoffset(typ, lasti) typ = fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) if !Base.allocatedinline(typ) legal = false end @@ -420,9 +438,10 @@ function abs_typeof( end typ2 = typ - while should_recurse(typ2, value_type(arg), byref) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ2, 1) + while should_recurse(typ2, value_type(arg), byref, dl) + idx, _ = first_non_ghost(typ2) + if idx != -1 + typ2 = fieldtype(typ2, idx) if !Base.allocatedinline(typ2) if byref != GPUCompiler.BITS_VALUE legal = false @@ -439,10 +458,9 @@ function abs_typeof( return (true, typ2, byref) end end - elseif legal && if typ <: Ptr && Base.isconcretetype(typ) + elseif legal && typ <: Ptr && Base.isconcretetype(typ) return (true, eltype(typ), GPUCompiler.BITS_VALUE) end - end end end diff --git a/src/compiler.jl b/src/compiler.jl index f3680cc4c0..32dd293ece 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7936,8 +7936,10 @@ function GPUCompiler.codegen( elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} else - println(string(mod)) + # println(string(mod)) + println(string(f)) @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) + @show enzyme_custom_extract_mi(f) @assert false end else