Skip to content

Commit

Permalink
fix iter
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 13, 2024
1 parent 09fc316 commit 6924d59
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween
nothing
end

function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs)
function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs)
outs = []
for i in 1:N
for w in 1:Width
Expand Down Expand Up @@ -997,6 +997,7 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado
push!(shadowsplat, :(($(s...),)))
end
quote
$(active_refs...)
args = ($(wrappedexexpand...),)
tt′ = Enzyme.vaTypeof(args...)
FT = Core.Typeof(f)
Expand All @@ -1006,8 +1007,8 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado
end

function func_runtime_iterate_rev(N, Width)
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true)
body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs)
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true)
body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs)

quote
function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)}
Expand All @@ -1019,7 +1020,7 @@ end
@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF}
N = div(length(allargs)+2, Width+1)-1
primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true)
return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs)
return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs)
end

# Create specializations
Expand Down

0 comments on commit 6924d59

Please sign in to comment.