From d6c10fe44df62fc7f738599ceb879de32529185f Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:50:21 +0100 Subject: [PATCH] Improve catch block identification --- src/compiler/reverse.jl | 34 +++++++++++++++++++++------------- test/compiler.jl | 12 ++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index bf2783028..2b961ae4c 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -299,6 +299,7 @@ end function adjoint(pr::Primal) ir, sigs = adjointcfg(pr) + catch_blocks = falses(length(blocks(pr.ir))) for b in reverse(blocks(pr.ir)) rb = block(ir, b.id) grads = Dict() @@ -309,12 +310,13 @@ function adjoint(pr::Primal) grad(sigs[b.id][i], arguments(rb)[i]) end - has_leave = false - # Backprop through statements for v in reverse(keys(b)) ex = b[v].expr - has_leave |= isexpr(ex, :leave) + + if isexpr(ex, :catch) + catch_blocks[first(ex.args)] = true + end if haskey(pr.pullbacks, v) g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), @@ -338,16 +340,6 @@ function adjoint(pr::Primal) end end - # This is corresponds to a catch blocks which technically - # has predecessors but they are not modelled in the IRTools CFG. - # We put an error message at the beginning of said block. - if has_leave && isempty(predecessors(b)) && b.id != 1 - _, f_stmt = first(b) - li = pr.ir.lines[f_stmt.line] - pushfirst!(rb, stmt(xcall(Base, :error, - "Can't differentiate function execution in catch block at $(li.file):$(li.line)."))) - end - if b.id > 1 # Backprop through (predecessor) branch arguments gs = grad.(arguments(b)) for br in branches(rb) @@ -368,6 +360,22 @@ function adjoint(pr::Primal) branches(rb)[1].args[1] = Δ end end + + for (id, is_catch) in enumerate(catch_blocks) + is_catch || continue + + b = block(pr.ir, id) + rb = block(ir, id) + err_message = if isempty(b) + "Can't differentiate function execution in catch block" + else + _, f_stmt = first(b) + li = pr.ir.lines[f_stmt.line] + "Can't differentiate function execution in catch block at $(li.file):$(li.line)." + end + pushfirst!(rb, stmt(xcall(Base, :error, err_message))) + end + return ir end diff --git a/test/compiler.jl b/test/compiler.jl index 07d498ecb..3b5b0018a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -319,11 +319,7 @@ end @test res == 12. @test_throws ErrorException pull(1.) err = try pull(1.) catch ex; ex end - if VERSION >= v"1.11" - @test_broken occursin("Can't differentiate function execution in catch block", string(err)) - else - @test occursin("Can't differentiate function execution in catch block", string(err)) - end + @test occursin("Can't differentiate function execution in catch block", string(err)) end if VERSION >= v"1.8" @@ -351,9 +347,5 @@ end @test_throws ErrorException pull(1.) err = try pull(1.) catch ex; ex end - if VERSION >= v"1.11" - @test_broken occursin("Can't differentiate function execution in catch block", string(err)) - else - @test occursin("Can't differentiate function execution in catch block", string(err)) - end + @test occursin("Can't differentiate function execution in catch block", string(err)) end