
Reverse mode AD with forward mode backend for conditions (and vice versa) #69

Closed
thorek1 opened this issue Jul 27, 2023 · 27 comments · Fixed by #87
Labels: discussion (The right solution is not yet clear), feature (New feature or request)
Milestone: v0.5

Comments

@thorek1 (Contributor) commented Jul 27, 2023

I have a case where I need to autodiff through scalar functions and matrix operations. I understand that forward mode is better for the former and reverse mode more efficient for the latter.

What might solve this "problem" is using reverse mode on the overall chain and forward mode as a backend on the conditions involving the scalar functions.

Looking at the current implementation, I would suggest making backend an argument, similar to return_byproduct.

@mohamed82008 (Collaborator) commented

Can you give an example of what you are trying to do?

@thorek1 (Contributor, Author) commented Jul 28, 2023

I am looking for an implementation of this (from example 4):

implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim, backend = :Forward)

which in the background would lead to the following chain rule being implemented:

function ChainRulesCore.rrule(
    rc::RuleConfig,
    implicit::ImplicitFunction,
    x::AbstractArray{R},
    ::Val{return_byproduct};
    backend::Symbol = :ReverseDiff,  # proposed argument selecting how the conditions are differentiated
    kwargs...,
) where {R,return_byproduct}
    @unpack conditions, linear_solver = implicit

    y, z = implicit(x, Val(true); kwargs...)
    n, m = length(x), length(y)

    # pick the AD backend used to differentiate the conditions
    if backend == :Forward
        backend = ForwardDiffBackend()
    elseif backend == :Reverse
        backend = ReverseRuleConfigBackend(rc)
    end

    # partial pullbacks of the conditions with respect to y and x
    pbA = pullback_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
    pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x)
    # wrap them as lazy linear operators Aᵀ (m × m) and Bᵀ (n × m) for the linear solve
    pbmA = PullbackMul!(pbA, size(y))
    pbmB = PullbackMul!(pbB, size(y))
    Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
    Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB)
    implicit_pullback = ImplicitPullback(
        Aᵀ_op, Bᵀ_op, linear_solver, x, Val(return_byproduct)
    )

    if return_byproduct
        return (y, z), implicit_pullback
    else
        return y, implicit_pullback
    end
end

This is a crude demonstration, but I hope you get the point.

@gdalle (Member) commented Jul 28, 2023

In your specific case, this does the job: https://fluxml.ai/Zygote.jl/stable/utils/#Zygote.forwarddiff
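
For reference, a minimal sketch of that pattern (a hedged illustration; the function names and sizes are made up, not taken from thorek1's code): the scalar-heavy part is wrapped in Zygote.forwarddiff so it gets differentiated with ForwardDiff, while Zygote handles the rest in reverse mode.

using Zygote

# Scalar-heavy inner function: cheap to differentiate in forward mode.
scalar_part(x) = sum(sin, x) * exp(x[1])

# Outer function differentiated in reverse mode by Zygote; the inner call is
# routed through ForwardDiff via Zygote.forwarddiff.
function objective(A, x)
    s = Zygote.forwarddiff(scalar_part, x)
    return sum(A * x) * s
end

A, x = rand(4, 4), rand(4)
∇A, ∇x = Zygote.gradient(objective, A, x)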

@gdalle (Member) commented Jul 28, 2023

In the general case, it makes sense mathematically to differentiate the conditions "in the same way" as the implicit function itself, but I agree that there is no reason not to be more generic. We can probably do this within the upcoming Conditions structure.

@gdalle changed the title from "reverse mode autodiff with forward mode backend for conditions (and vice versa)" to "Reverse mode AD with forward mode backend for conditions (and vice versa)" Jul 30, 2023
@gdalle added the "feature (New feature or request)" and "discussion (The right solution is not yet clear)" labels Jul 30, 2023
@thorek1 (Contributor, Author) commented Jul 30, 2023

Indeed, Zygote.forwarddiff is helpful, but a Conditions structure containing a backend field would simplify things a lot in my case.

@gdalle (Member) commented Jul 30, 2023

@mohamed82008 this would be breaking. Do we try to put it in 0.5?

@gdalle (Member) commented Jul 30, 2023

I also think we would lose efficiency that way, because if we differentiate the implicit function in reverse mode but the conditions in forward mode, we waste a lot of passes. Basically, for every partial pullback on the conditions we would do n pushforwards.

@thorek1 (Contributor, Author) commented Jul 30, 2023

Why breaking? You could keep the current behaviour as the default but have the option to change the backend.

Regarding wasted passes: how so? The problems are separate and come together as a product in either case (JVP or VJP), or am I missing something? Given that the AD of the implicit function and the AD of the conditions are separate, it makes sense to choose the most efficient AD backend for the conditions and use the result in your JVP/VJP.

@gdalle (Member) commented Jul 30, 2023

given that the AD of the implicit function and the AD of the conditions are separate

They are not completely separate: for efficiency they should use the same mode, because one JVP for the implicit function means two partial JVPs for the conditions, and one VJP for the implicit function means two partial VJPs for the conditions.

On the other hand, if we want to mix modes, doing one JVP for the implicit function using VJPs for the conditions, we are going to need $2n$ (or $2m$, not sure) VJPs to get the same result.
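
For reference, a sketch of the counting argument, writing the conditions as $c(x, y) = 0$ with partial Jacobians $A = \partial c / \partial y$ and $B = \partial c / \partial x$ (the same $A$ and $B$ that appear transposed as operators in the rrule above):

$$\frac{\partial y}{\partial x} = -A^{-1} B, \qquad \text{JVP: } \dot y = -A^{-1} (B \dot x), \qquad \text{VJP: } \bar x = -B^{\top} (A^{-\top} \bar y)$$

A reverse-mode rule for the implicit function only needs products with $A^{\top}$ and $B^{\top}$, which are exactly the partial VJPs of the conditions; if the conditions only expose JVPs, those transposed products have to be rebuilt column by column, one pushforward per column, which is where the extra passes come from.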

@mohamed82008 (Collaborator) commented

I don't think we should support this feature. This is outside of the scope of the package. An API like this https://julianonconvex.github.io/Nonconvex.jl/stable/gradients/other_ad/ makes more sense to me. If you want to mix modes, define an rrule for the condition function to use forward-mode AD to compute the Jacobian and then do the pullback.
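
For concreteness, a minimal sketch of that approach (a hedged illustration with a toy conditions function; my_conditions and its signature are made up): the rrule computes the Jacobians with ForwardDiff and expresses the pullback as transposed-Jacobian products.

using ChainRulesCore, ForwardDiff

# Toy conditions function standing in for the real one.
my_conditions(x, y) = y .^ 2 .- x

function ChainRulesCore.rrule(::typeof(my_conditions), x, y)
    w = my_conditions(x, y)
    # Forward-mode Jacobians of the conditions with respect to x and y.
    Jx = ForwardDiff.jacobian(_x -> my_conditions(_x, y), x)
    Jy = ForwardDiff.jacobian(_y -> my_conditions(x, _y), y)
    # Pullback: transposed-Jacobian-vector products; NoTangent for the function itself.
    my_conditions_pullback(Δw) = (NoTangent(), Jx' * Δw, Jy' * Δw)
    return w, my_conditions_pullback
end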

@gdalle (Member) commented Jul 31, 2023

Makes sense to me too. Basically, it's up to the user how the conditions are differentiated. The only thing is, this API comes with a lot of Nonconvex-related deps. Do you plan to move it to AbstractDifferentiation?

@mohamed82008 (Collaborator) commented

I would be happy to review and approve a PR if someone moves it from NonconvexUtils.

@gdalle (Member) commented Jul 31, 2023

Oh crap, it still has that flattening part I hate so much, because it's not restricted to arrays. I think we'd be better off putting the array-only version of abstractdiffy in ImplicitDifferentiation.jl.

@gdalle (Member) commented Aug 1, 2023

On second thought, I think my implementation in #80 is much simpler and more natural. But I'd be curious to see if it actually brings a speedup in @thorek1's case.

@thorek1 (Contributor, Author) commented Aug 1, 2023

Nice, I'll have a go at it later today or tomorrow.

@thorek1 (Contributor, Author) commented Aug 3, 2023

The forward mode backend gave me a significant boost on my type of problem. Well done!

FYI: the byproduct implementation in this branch is a bit confusing.

@gdalle (Member) commented Aug 3, 2023

Happy to hear it! I guess a merge is back on the table if @mohamed82008 approves of the implementation.

What do you mean by confusing?

@thorek1 (Contributor, Author) commented Aug 3, 2023

I managed with 0.4.4 but not with this branch. The additional arguments HandleByproduct and ReturnByproduct were not working as expected, so in the end I opted for no byproduct.

@gdalle (Member) commented Aug 3, 2023

That is important for us to know before we release it ^^ Could you maybe provide an MWE? What didn't work as expected?
Theoretically the tests are much more thorough for this version, so I'm reasonably confident that if something is wrong, it's in the docs.

@gdalle added this to the v0.5 milestone Aug 4, 2023
@thorek1 (Contributor, Author) commented Aug 4, 2023

I got it to run after looking at the examples. One thing I need to take some time to understand is why I need to write the function like this:

riccati_(∇₁; T, explosive) = ImplicitFunction(
    ∇₁ -> riccati_forward(∇₁, T = T, explosive = explosive),
    (x, y, z) -> riccati_conditions(x, y, T = T, explosive = explosive),
    HandleByproduct(),
)

in order to have additional inputs (not to be differentiated) of various types (Bool, Int, Float, ...) passed to the ImplicitFunction. Any hints are most welcome.

It does not work if I write: riccati_ = ImplicitFunction(∇₁ -> riccati_forward(∇₁, T=T, explosive=explosive)[1], (x,y)->riccati_conditions(x,y,T=T,explosive=explosive))

Later in the code I call the ImplicitFunction like this:

riccati = riccati_(∇₁, T = T, explosive = explosive)
A, solved = riccati(∇₁, ReturnByproduct())

While this works, I would advocate for a solution without the HandleByproduct and ReturnByproduct arguments. I understand handling of byproducts can be automated (recognise outputs and pass them on), and it would make the function more user friendly. From a user perspective, I don't see why you would not have the byproduct of a function you wrote returned. If you don't want it, you can remove it at the level of the function itself.

@gdalle (Member) commented Aug 4, 2023

in order to have additional inputs (not to be differentiated) of various types (Bool, Int, Float, ...) passed to the ImplicitFunction.

Do you mean additional outputs of the forward mapping? Because in the general case, keyword arguments are always supported by implicit(x; kwargs...) and never differentiated.

it does not work if I write:

I'm not sure what it is that doesn't work. Does your riccati_forward return the byproduct solved? If so, you need to tell the ImplicitFunction structure, so that it expects

f(x) -> (y, z) and c(x, y, z) -> w

instead of

f(x) -> y and c(x, y) -> w

It's useful to ensure type stability, but it leads to a lot of complications I agree.
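
For illustration, a minimal sketch of that byproduct convention using the HandleByproduct / ReturnByproduct API quoted in this thread (an assumption about the unreleased branch; the solver, conditions and flag are toy stand-ins):

using ImplicitDifferentiation

# Forward mapping returns the solution y and a byproduct z (here a convergence flag).
function forward_with_byproduct(x)
    y = sqrt.(x)   # stand-in for an actual solver
    z = true       # byproduct, never differentiated
    return y, z
end

# The conditions must then accept the byproduct as a third argument.
conditions_with_byproduct(x, y, z) = y .^ 2 .- x

implicit = ImplicitFunction(forward_with_byproduct, conditions_with_byproduct, HandleByproduct())

x = rand(3)
y = implicit(x)                        # solution only
y, z = implicit(x, ReturnByproduct())  # solution plus byproduct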

I would advocate for a solution without the HandleByproduct and ReturnByproduct arguments

If it is so complicated to get right from the user perspective (and you are our only user ^^), I have half a mind to walk back the changes and always keep the byproduct. @mohamed82008?

@gdalle (Member) commented Aug 4, 2023

Basically, we wanted to simplify it because byproduct uses are not the norm.

@thorek1 (Contributor, Author) commented Aug 4, 2023

Here is a nastier example:

block_solver_AD(parameters_and_solved_vars::Vector{<: Real}, 
    n_block::Int, 
    ss_solve_blocks::Function, 
    # ss_solve_blocks_no_transform::Function, 
    # f::OptimizationFunction, 
    guess::Vector{Float64}, 
    lbs::Vector{Float64}, 
    ubs::Vector{Float64};
    tol::AbstractFloat = eps(Float64),
    # timeout = 120,
    starting_points::Vector{Float64} = [0.897, 1.2, .9, .75, 1.5, -.5, 2.0, .25],
    # fail_fast_solvers_only = true,
    verbose::Bool = false) = ImplicitFunction(x -> block_solver(x,
                                                            n_block, 
                                                            ss_solve_blocks,
                                                            # f,
                                                            guess,
                                                            lbs,
                                                            ubs;
                                                            tol = tol,
                                                            # timeout = timeout,
                                                            starting_points = starting_points,
                                                            # fail_fast_solvers_only = fail_fast_solvers_only,
                                                            verbose = verbose)[1],  
                                        (x,y) -> ss_solve_blocks(x,y))

In the end it's not only keyword arguments, and that's what my question is about: how to handle those? I mean, I have a way of handling them, but is that the best/only way?

Thumbs up for always having the byproduct. Leaving it for the package to figure out makes for a more user-friendly experience.

@gdalle (Member) commented Aug 4, 2023

In the end it's not only keyword arguments, and that's what my question is about: how to handle those? I mean, I have a way of handling them, but is that the best/only way?

You can also write

f(x; kwargs...) -> (y, z)
c(x, y, z; kwargs...) -> w

as long as f and c accept the same kwargs. So you don't have to close over the keyword arguments as in your example. Does that answer your question?
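
For illustration, a minimal sketch of that convention (a hedged example; the square-root solver and the tol keyword are made up): x is the only differentiated argument, and the keyword arguments are accepted by both the forward mapping and the conditions.

using ImplicitDifferentiation

# Forward mapping: x is differentiated, tol is a non-differentiated keyword argument.
forward(x; tol) = sqrt.(x)   # stand-in for an iterative solver that uses tol

# The conditions accept the same keyword arguments, even if they ignore some of them.
conditions(x, y; tol) = y .^ 2 .- x

implicit = ImplicitFunction(forward, conditions)

x = rand(3)
y = implicit(x; tol = 1e-8)  # kwargs are forwarded to both functions and never differentiated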

Thumbs up for always having the byproduct. Leaving it for the package to figure out makes for a more user-friendly experience.

The question is whether we can figure it out in a type-stable way. But I think we can, let's see.

@thorek1 (Contributor, Author) commented Aug 4, 2023

OK, but does that cover the case where I have arguments which are to be diffed and passed to the function f, arguments which are not to be diffed and passed only to f, arguments which are not to be diffed and passed to both f and c, and then the same for keyword arguments (kwargs), minus the diffable arguments?

@gdalle (Member) commented Aug 4, 2023

The only argument which is diffed is x and it has to be an array. Every argument which is not diffed must be a kwarg, and the conditions must accept it too. The only exception is the byproduct.

Neglecting the fact that we only take one positional argument, this is the convention adopted by ChainRules.jl (diff <=> arg, non-diff <=> kwarg).

@gdalle (Member) commented Aug 4, 2023

See #86, @thorek1 your prayers have been answered

@gdalle linked a pull request Aug 4, 2023 that will close this issue
@gdalle closed this as completed in #87 Aug 5, 2023