Problem statement
While reading the docs I came to wonder about the following MWE in the nested-AD section:
```julia
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

function loss_function_batched(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `AutoZygote()` as well, but `AutoForwardDiff()` tends to be more efficient here
    J = batched_jacobian(smodel, AutoForwardDiff(), x)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

loss_function_batched(model, x, ps, st, y)
```
Notice that the forward pass must be run at least twice here (once at `ŷ = smodel(x)` and at least once inside `J = batched_jacobian(smodel, AutoForwardDiff(), x)`) to obtain both $\hat{y}$ and $J$, which seems inefficient (or is it?). This may waste significant resources when the number of chunks is small. Since `batched_jacobian` computes $\hat{y}$ internally anyway, it would be useful to reuse it.
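To make the redundancy concrete, here is a toy, stdlib-only sketch (it does not use Lux; the hypothetical finite-difference `toy_jacobian` stands in for `batched_jacobian`) that counts how often the model function is evaluated when the primal and the Jacobian are computed by two separate calls:

```julia
# Count forward evaluations when the primal and the Jacobian are
# computed separately, as in the MWE above.
const NCALLS = Ref(0)

f(x) = (NCALLS[] += 1; x .^ 2)  # stand-in for `smodel`

# Finite-difference Jacobian: one forward pass at the base point plus one
# per input dimension -- none of them shared with the earlier `y = f(x)`.
function toy_jacobian(f, x::Vector{Float64}; h=1e-6)
    y0 = f(x)
    J = zeros(length(y0), length(x))
    for j in eachindex(x)
        xp = copy(x); xp[j] += h
        J[:, j] = (f(xp) .- y0) ./ h
    end
    return J
end

x = [1.0, 2.0]
y = f(x)                 # 1 forward pass for the primal
J = toy_jacobian(f, x)   # 1 + length(x) more forward passes
# NCALLS[] is now 4: the base-point evaluation inside `toy_jacobian`
# duplicates the one already done for `y`.
```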
Design ideas to fix it
Here are some ways that one may fix this by fetching the primal computation results in the internal API.
In all of them, the `batched_jacobian` function accepts a new keyword argument, defaulting to `false`, that controls whether the primal is returned (new API) or not (as before).
```julia
function batched_jacobian(f::F, backend::AbstractADType, x::AbstractArray; primal::Bool=false) where {F}
    return batched_jacobian_internal(f, backend, x, Val(primal))
end
```
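As a sketch of how such a `primal` keyword could dispatch internally, here is a stdlib-only illustration (the names `toy_jacobian_internal` and `toy_jac` are hypothetical, and a finite-difference helper stands in for the AD backend; the real internals live in `batched_autodiff.jl`):

```julia
# Dispatch on Val(primal): the `true` variant returns (y, J),
# the `false` variant keeps the current behaviour.
toy_f(x) = x .^ 2

function toy_jacobian_internal(f, x, ::Val{true})
    y = f(x)
    return y, toy_jac(f, x, y)   # new API: primal returned alongside J
end
toy_jacobian_internal(f, x, ::Val{false}) = toy_jac(f, x, f(x))  # as before

# Shared finite-difference helper (stand-in for the AD backend).
function toy_jac(f, x, y0; h=1e-6)
    J = zeros(length(y0), length(x))
    for j in eachindex(x)
        xp = copy(x); xp[j] += h
        J[:, j] = (f(xp) .- y0) ./ h
    end
    return J
end

toy_jacobian(f, x; primal::Bool=false) = toy_jacobian_internal(f, x, Val(primal))

J = toy_jacobian(toy_f, [1.0, 2.0])                  # old behaviour: J only
y, J2 = toy_jacobian(toy_f, [1.0, 2.0]; primal=true) # new behaviour: (y, J)
```

Note that `Val(primal)` on a runtime `Bool` introduces a dynamic dispatch; since the keyword would usually be a literal at the call site, constant propagation should typically resolve it.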
ForwardDiff possible solutions
Since both Zygote and ForwardDiff must compute $\hat{y}$ internally at some point, might it be possible to extract it?
For the ForwardDiff part, for example, the relevant code seems to be in `src/autodiff/batched_autodiff.jl`, lines 160-166 (main branch). One could do the same as `partials_wrap` (line 159) for the primal values while computing the first chunk.
```julia
# line 124
@views function batched_forwarddiff_jacobian_first_chunk(
        f::F, x::AbstractMatrix{T}, ::Type{Tag}, ::ForwardDiff.Chunk{CK},
        ::Type{Dual}, ::Type{Partials}) where {F, T, Tag, CK, Dual, Partials}
    N, B = size(x)
    n_idxs = min(CK, N)
    idxs = 1:n_idxs
    idxs_next = (1 + CK):N
    dev = get_device(x)
    partials = map(𝒾 -> Partials(ntuple(𝒿 -> ifelse(𝒾 == 𝒿, oneunit(T), zero(T)), CK)),
        dev(collect(1:n_idxs)))
    x_part_duals = Dual.(x[idxs, :], partials)
    if length(idxs_next) == 0
        x_part_next = similar(x_part_duals, 0, B)
    else
        x_part_next = Dual.(x[idxs_next, :],
            map(𝒾 -> Partials(ntuple(_ -> zero(T), CK)), dev(collect(1:length(idxs_next)))))
    end
    x_duals = vcat(x_part_duals, x_part_next)
    y_duals_ = f(x_duals)
    @argcheck ndims(y_duals_) > 1 && size(y_duals_, ndims(y_duals_)) == B
    y_duals = reshape(y_duals_, :, B)
    partials_wrap(y, i) = ForwardDiff.partials(Tag, y, i)
    return ForwardDiff.value.(Tag, y_duals_),
        stack(i -> partials_wrap.(y_duals, i), 1:CK; dims=2)
end
```
Then the first chunk would be computed with this method by changing line 100 as follows, and then returning (or not) $y$:
```julia
y, J_partial = batched_forwarddiff_jacobian_first_chunk(f, x, Tag, ck, dual_type, partials_type)
```
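The reason the primal is essentially free in forward mode can be seen with a minimal hand-rolled dual number (a stdlib-only sketch of what `ForwardDiff.Dual` does; `ForwardDiff.value` plays the role of the `value` field below):

```julia
# Minimal dual number: one forward pass through `g` propagates both the
# primal value and the partial derivative, so the primal is a byproduct.
struct MiniDual
    value::Float64     # primal
    partial::Float64   # derivative seed
end
Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.value + b.value, a.partial + b.partial)
Base.:*(a::MiniDual, b::MiniDual) =
    MiniDual(a.value * b.value, a.value * b.partial + a.partial * b.value)

g(x) = x * x + x           # toy scalar function

d = g(MiniDual(3.0, 1.0))  # single evaluation
d.value                    # 12.0 -- the primal, computed anyway
d.partial                  # 7.0  -- dg/dx at x = 3
```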
Zygote possible solutions
If I understood the code correctly, this one may be simpler, in `ext/LuxZygoteExt/batched_autodiff.jl`:
```julia
# line 31
if primal
    return y, J
end
return J
```
About Enzyme
Once ho/ho-enzyme (#954) is ready, using e.g. `AutoEnzyme(; mode=Enzyme.ForwardWithPrimal)` would make this issue almost trivial.
From personal experience with Reactant + Enzyme + nested AD, though, this currently looks particularly difficult.
vjp/jvp
All the ideas mentioned above also seem applicable to both `vector_jacobian_product` and `jacobian_vector_product`.
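For the JVP case in particular, the same "primal for free" argument applies: one dual-seeded forward pass yields both $y$ and $Jv$. A stdlib-only sketch (the name `toy_jvp_with_primal` is hypothetical, not the Lux API):

```julia
# One dual-seeded forward pass yields both the primal y and the product J*v.
struct JDual
    value::Float64
    partial::Float64
end
Base.:+(a::JDual, b::JDual) = JDual(a.value + b.value, a.partial + b.partial)
Base.:*(a::JDual, b::JDual) =
    JDual(a.value * b.value, a.value * b.partial + a.partial * b.value)

# Hypothetical API shape: return (y, Jv) instead of Jv alone.
function toy_jvp_with_primal(f, x::Vector{Float64}, v::Vector{Float64})
    duals = f(JDual.(x, v))                  # seed partials with v
    return getfield.(duals, :value), getfield.(duals, :partial)
end

h(x) = x .* x            # elementwise square; J = diag(2x)
y, Jv = toy_jvp_with_primal(h, [1.0, 2.0], [1.0, 0.5])
# y == [1.0, 4.0]; Jv == [2.0, 2.0]
```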
Conclusion
Are those ideas applicable in the way described, or some other way? Are they useful to the API?
I could implement the changes myself if needed, but I might need guidance with respect to the codebase, especially the tests targeting this part of the code. If this seems relevant, I'll open a PR.
Disclaimer
(I am not proficient in Julia coding, especially regarding the autodiff libraries and GPU technicalities. I may have misunderstood some mechanisms, e.g. an in-place operation that could make my proposed solutions inapplicable, or scalar-indexing issues that may arise, in which case I would not know how to address the issue above.)