Reverse mode AD with forward mode backend for conditions (and vice versa) #69
Comments
Can you give an example of what you are trying to do?
I am looking for an implementation of this (from example 4):
implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim, backend = :Forward)
which, under the hood, would lead to the following chain rule being implemented:
function ChainRulesCore.rrule(
    rc::RuleConfig,
    implicit::ImplicitFunction,
    x::AbstractArray{R},
    ::Val{return_byproduct};
    backend::Symbol = :Reverse,  # default must match one of the branches below
    kwargs...,
) where {R,return_byproduct}
    @unpack conditions, linear_solver = implicit
    y, z = implicit(x, Val(true); kwargs...)
    n, m = length(x), length(y)
    # Pick the AD backend used to differentiate the conditions
    # (a new variable, since `backend` is declared as a Symbol)
    if backend == :Forward
        ad_backend = ForwardDiffBackend()
    elseif backend == :Reverse
        ad_backend = ReverseRuleConfigBackend(rc)
    else
        throw(ArgumentError("unsupported backend: $backend"))
    end
    # Partial pullbacks of the conditions with respect to y and x
    pbA = pullback_function(ad_backend, _y -> conditions(x, _y, z; kwargs...), y)
    pbB = pullback_function(ad_backend, _x -> conditions(_x, y, z; kwargs...), x)
    pbmA = PullbackMul!(pbA, size(y))
    pbmB = PullbackMul!(pbB, size(y))
    # Lazy transposed Jacobians: Aᵀ is m×m, Bᵀ is n×m
    Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
    Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB)
    implicit_pullback = ImplicitPullback(
        Aᵀ_op, Bᵀ_op, linear_solver, x, Val(return_byproduct)
    )
    if return_byproduct
        return (y, z), implicit_pullback
    else
        return y, implicit_pullback
    end
end
This is a crude demonstration, but I hope you get the point.
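To make concrete what the pullback in that rrule computes, here is a minimal dense-matrix sketch using only the standard library. It assumes toy conditions c(x, y) = y.^2 .- x (so y = sqrt.(x)) and writes the partial Jacobians A = ∂c/∂y and B = ∂c/∂x by hand; in the rrule these are the lazy operators Aᵀ_op and Bᵀ_op built from partial pullbacks. All names here are hypothetical.

```julia
using LinearAlgebra

# Hypothetical forward mapping and conditions: y = sqrt.(x) solves y.^2 .- x = 0
sqrt_forward(x) = sqrt.(x)
sqrt_conditions(x, y) = y .^ 2 .- x

# Hand-written partial Jacobians (an AD backend would supply these products instead)
jac_y(x, y) = Diagonal(2 .* y)                          # A = ∂c/∂y (m×m)
jac_x(x, y) = -Matrix{Float64}(I, length(x), length(x))  # B = ∂c/∂x (m×n)

# Pullback of the implicit function: solve Aᵀ u = w, then x̄ = -Bᵀ u
function implicit_vjp(x, w)
    y = sqrt_forward(x)
    A, B = jac_y(x, y), jac_x(x, y)
    u = A' \ w
    return -(B' * u)
end

# Pushforward of the implicit function: solve A u = -B v
function implicit_jvp(x, v)
    y = sqrt_forward(x)
    A, B = jac_y(x, y), jac_x(x, y)
    return A \ (-(B * v))
end
```

Both results agree with the analytic derivative 1 / (2 sqrt(x)), even though sqrt itself is never differentiated; only the conditions are.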
In your specific case, this does the job: https://fluxml.ai/Zygote.jl/stable/utils/#Zygote.forwarddiff
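A hedged sketch of that suggestion, assuming the conditions end up being differentiated by Zygote (the reverse-mode path): wrapping the scalar-heavy part in `Zygote.forwarddiff` tells Zygote to use ForwardDiff for it. `scalar_conditions` and `conditions_fwd` are hypothetical names.

```julia
using Zygote

# Hypothetical scalar-heavy conditions: y = exp.(x) solves log.(y) .- x = 0
scalar_conditions(x, y) = log.(y) .- x

# Zygote.forwarddiff drops gradients of closed-over values,
# so x and y are packed into one vector instead of being captured
function conditions_fwd(x, y)
    n = length(x)
    return Zygote.forwarddiff(v -> scalar_conditions(v[1:n], v[(n + 1):end]), vcat(x, y))
end
```

The packing via `vcat` follows the Zygote documentation's advice to pass closed-over variables through explicitly.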
In the general case, it makes sense mathematically to differentiate the conditions "in the same way" as the implicit function itself, but I agree that there is no reason not to be more generic. We can probably do this within the upcoming
Indeed.
@mohamed82008 this would be breaking. Should we try to put it in 0.5?
I also think we would lose efficiency that way, because if we differentiate the implicit function in reverse mode but the conditions in forward mode, we are wasting a lot of passes. Basically, every partial pullback on the conditions would cost n pushforwards.
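A stdlib-only sketch of that cost argument: if only pushforwards (JVPs) are available, one pullback wᵀJ of a map from Rⁿ to Rᵐ has to be assembled from n JVPs, one per input basis vector. The matrix `M` is a hypothetical stand-in for a partial Jacobian of the conditions.

```julia
using LinearAlgebra

# Hypothetical linear conditions map R³ → R², so its Jacobian is simply M
const M = [1.0 2.0 0.0; 0.0 1.0 3.0]

# One forward-mode pass: the Jacobian-vector product J * v
jvp(v) = M * v

# A VJP wᵀJ assembled from forward passes needs one JVP per basis vector e_i
function vjp_from_jvps(jvp, n, w)
    calls = 0
    x̄ = zeros(n)
    for i in 1:n
        e_i = zeros(n)
        e_i[i] = 1.0
        x̄[i] = dot(w, jvp(e_i))  # (wᵀJ)ᵢ = wᵀ (J eᵢ)
        calls += 1
    end
    return x̄, calls
end
```

A reverse-mode backend would produce the same vector in a single pullback, which is why mixing modes costs extra passes.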
Why breaking? You could keep the current behavior as the default but add the option to change the backend. Regarding wasted passes: how so? The two problems are separate and come together as a product in either case (JVP or VJP), or am I missing something? Given that the AD of the implicit function and the AD of the conditions are separate, it makes sense to choose the most efficient AD backend for the conditions and use the result in your JVP/VJP.
They are not completely separate: for efficiency, they should use the same mode. One JVP for the implicit function requires 2 partial JVPs of the conditions, and one VJP for the implicit function requires 2 partial VJPs of the conditions. On the other hand, if we want to mix modes, doing one JVP for the implicit function using VJPs of the conditions, we're going to need
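The point about modes can be read off the implicit function theorem. With A and B the partial Jacobians of the conditions (whose transposes appear as Aᵀ_op and Bᵀ_op in the rrule sketched earlier in this thread), a JVP only involves products with A and B (pushforwards), while a VJP only involves products with Aᵀ and Bᵀ (pullbacks):

```latex
% Conditions c(x, y) = 0 define y(x), with n = length(x) and m = length(y)
A = \partial_y c(x, y) \in \mathbb{R}^{m \times m}, \qquad B = \partial_x c(x, y) \in \mathbb{R}^{m \times n}
% Differentiating c(x, y(x)) = 0 with respect to x
A \, \frac{\partial y}{\partial x} + B = 0 \quad\Longrightarrow\quad \frac{\partial y}{\partial x} = -A^{-1} B
% JVP: one product B v, then products A u inside the linear solver
\frac{\partial y}{\partial x} \, v = -A^{-1} (B v)
% VJP: products A^\top u inside the linear solver, then one product B^\top u
w^\top \frac{\partial y}{\partial x} = -\left(B^\top A^{-\top} w\right)^\top
```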
I don't think we should support this feature. This is outside the scope of the package. An API like this https://julianonconvex.github.io/Nonconvex.jl/stable/gradients/other_ad/ makes more sense to me. If you want to mix modes, define an rrule for the conditions function that uses forward-mode AD to compute the Jacobian and then does the pullback.
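A minimal sketch of that suggestion, assuming ChainRulesCore.jl and ForwardDiff.jl are available; `my_conditions` is a hypothetical conditions function. The rrule computes both partial Jacobians in forward mode and uses their transposes in the pullback, so a ChainRules-aware reverse-mode AD such as Zygote would pick it up.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical conditions: c(x, y) = y .^ 2 .- x vanishes at y = sqrt.(x)
my_conditions(x, y) = y .^ 2 .- x

# Reverse rule whose pullback is built from forward-mode Jacobians
function ChainRulesCore.rrule(::typeof(my_conditions), x::AbstractVector, y::AbstractVector)
    c = my_conditions(x, y)
    Jx = ForwardDiff.jacobian(_x -> my_conditions(_x, y), x)  # ∂c/∂x in forward mode
    Jy = ForwardDiff.jacobian(_y -> my_conditions(x, _y), y)  # ∂c/∂y in forward mode
    function my_conditions_pullback(c̄)
        c̄ = unthunk(c̄)
        return NoTangent(), Jx' * c̄, Jy' * c̄
    end
    return c, my_conditions_pullback
end
```

Materializing the full Jacobians is only sensible when the conditions are small or cheap in forward mode, which is the situation described in this issue.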
Makes sense to me too. Basically, it's up to the user how the conditions are differentiated. The only catch is that this API comes with a lot of Nonconvex-related dependencies. Do you plan to move it to AbstractDifferentiation?
I would be happy to review and approve a PR if someone moves it from NonconvexUtils.
Oh crap, it still has that flattening part I hate so much, because it's not restricted to arrays. I think we'd be better off putting the array-only version of
Nice, I'll have a go at it later today or tomorrow.
The forward-mode backend gave me a significant boost on my type of problem. Well done! FYI: the byproduct implementation in this branch is a bit confusing.
Happy to hear it! I guess a merge is back on the table if @mohamed82008 approves of the implementation. What do you mean by confusing?
I managed with 0.4.4 but not with this branch. The additional arguments HandleByproduct and ReturnByproduct were not working as expected, so in the end I opted for no byproduct.
That is important for us to know before we release it ^^ Could you maybe provide an MWE? What didn't work as expected?
I got it to run after looking at the examples. One thing I need to take some time to understand is why I need to write the function like this:
riccati_(∇₁; T, explosive) = ImplicitFunction(∇₁ -> riccati_forward(∇₁, T = T, explosive = explosive), (x, y, z) -> riccati_conditions(x, y, T = T, explosive = explosive), HandleByproduct())
in order to pass additional inputs (not to be differentiated) of various types (Bool, Int, Float...) to the ImplicitFunction. Any hints are most welcome. It does not work if I write:
Later in the code, I call it like this:
riccati = riccati_(∇₁, T = T, explosive = explosive)
A, solved = riccati(∇₁, ReturnByproduct())
While this works, I would advocate for a solution without the HandleByproduct and ReturnByproduct arguments. As I understand it, the handling of byproducts can be automated (recognize the outputs and pass them on), and that would make the function more user-friendly. From a user perspective, I don't see why you would not want the byproduct of a function you wrote to be returned. If you don't want it, you can remove it at the level of the function itself.
Do you mean additional outputs of the forward mapping? Because in the general case, keyword arguments are always supported by
I'm not sure what it is that doesn't work. Does your
instead of
It's useful for ensuring type stability, but I agree it leads to a lot of complications.
If it is so complicated to get right from the user perspective (and you are our only user ^^), I have half a mind to walk back the changes and always keep the byproduct. @mohamed82008?
Basically, we wanted to simplify it because uses of the byproduct are not the norm.
Here is a nastier example:
block_solver_AD(parameters_and_solved_vars::Vector{<: Real},
    n_block::Int,
    ss_solve_blocks::Function,
    # ss_solve_blocks_no_transform::Function,
    # f::OptimizationFunction,
    guess::Vector{Float64},
    lbs::Vector{Float64},
    ubs::Vector{Float64};
    tol::AbstractFloat = eps(Float64),
    # timeout = 120,
    starting_points::Vector{Float64} = [0.897, 1.2, .9, .75, 1.5, -.5, 2.0, .25],
    # fail_fast_solvers_only = true,
    verbose::Bool = false) = ImplicitFunction(
        x -> block_solver(x,
            n_block,
            ss_solve_blocks,
            # f,
            guess,
            lbs,
            ubs;
            tol = tol,
            # timeout = timeout,
            starting_points = starting_points,
            # fail_fast_solvers_only = fail_fast_solvers_only,
            verbose = verbose)[1],
        (x, y) -> ss_solve_blocks(x, y))
In the end, it's not only known arguments, and that's what my question is about: how should those be handled? I mean, I have a way of handling them, but is that the best/only way? Thumbs up for always having the byproduct. Leaving it to the package to figure out makes for a more user-friendly experience.
You can also write
f(x, y; kwargs...) = y, z
c(x, y, z; kwargs...) = w
as long as
The question is whether we can figure it out in a type-stable way. But I think we can; let's see.
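One way to detect a byproduct without extra arguments, sketched under the assumption that the forward mapping returns either an array `y` or a tuple `(y, z)`: dispatch on the return type. Because the choice is made by dispatch on an inferred type rather than by a runtime flag, it stays type-stable whenever the forward mapping itself is. `split_output` and `call_forward` are hypothetical helpers, not package functions.

```julia
# Hypothetical helpers: normalize any forward output to a (y, z) pair
split_output(out::Tuple{AbstractArray,Any}) = out    # forward returned (y, z)
split_output(out::AbstractArray) = (out, nothing)    # no byproduct: z = nothing

# Keyword arguments are forwarded untouched to the forward mapping
call_forward(forward, x; kwargs...) = split_output(forward(x; kwargs...))

# Two hypothetical forward mappings, without and with a byproduct
f_plain(x; scale = 1.0) = scale .* x
f_byproduct(x; scale = 1.0) = (scale .* x, length(x))
```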
OK, but does that cover the case where I have arguments that are to be differentiated and passed to the function?
The only argument which is diffed is Setting aside the fact that we only take one positional argument, this is the convention adopted by ChainRules.jl (diff <=> arg, non-diff <=> kwarg).
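A stdlib-only sketch of that convention, using a hypothetical `MyImplicit` wrapper (not the package's `ImplicitFunction`): the single positional argument `x` is the one to differentiate, and every keyword argument is forwarded unchanged to both the forward mapping and the conditions, so non-differentiated settings such as `T` and `explosive` need no closures.

```julia
# Hypothetical wrapper following "diff <=> positional arg, non-diff <=> kwarg"
struct MyImplicit{F,C}
    forward::F
    conditions::C
end

# Keyword arguments pass straight through and are never differentiated
(imp::MyImplicit)(x; kwargs...) = imp.forward(x; kwargs...)
residual(imp::MyImplicit, x, y; kwargs...) = imp.conditions(x, y; kwargs...)

# Hypothetical forward mapping and conditions with non-differentiated settings
toy_forward(x; T = 1, explosive = false) = explosive ? T .* x : x ./ T
toy_conditions(x, y; T = 1, explosive = false) = explosive ? y .- T .* x : T .* y .- x
```

Compared with the `riccati_` closure above, the settings are supplied at call time, e.g. `imp(x; T = 2, explosive = true)`.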
I have a case where I need to autodiff through scalar functions and matrix operations. I understand forward mode is better for the former and reverse mode is more efficient for the latter.
What might solve this "problem" is using reverse mode on the overall chain and forward mode as the backend for the conditions involving the scalar functions.
Looking at the current implementation, I would suggest making backend an argument similar to return_byproduct.
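A stdlib-only sketch of what passing the backend like `return_byproduct` could look like: wrapping the choice in a `Val` lets dispatch select the method at compile time, instead of branching on a `Symbol` at runtime. `ForwardMode`, `ReverseMode`, and `select_backend` are hypothetical placeholders for the AbstractDifferentiation backends, not existing package API.

```julia
# Hypothetical stand-ins for the AD backends used on the conditions
struct ForwardMode end
struct ReverseMode{R}
    rule_config::R
end

# Backend selection by dispatch on Val, resolved at compile time
select_backend(::Val{:Forward}, rc) = ForwardMode()
select_backend(::Val{:Reverse}, rc) = ReverseMode(rc)
select_backend(::Val{B}, rc) where {B} = throw(ArgumentError("unsupported backend: $B"))
```

Inside an rrule, the backend would then arrive as `Val(:Forward)` or `Val(:Reverse)` and be passed straight to `select_backend`.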