diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl
index e2da730..5bc6c2e 100644
--- a/ext/OptimizationEnzymeExt.jl
+++ b/ext/OptimizationEnzymeExt.jl
@@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
     end
 
     if cons !== nothing && cons_j == true && f.cons_j === nothing
-        if num_cons > length(x)
-            seeds = Enzyme.onehot(x)
-            Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
-        else
-            seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
-            Jaccache = Tuple(zero(x) for i in 1:num_cons)
-        end
+        # if num_cons > length(x)
+        seeds = Enzyme.onehot(x)
+        Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
+        # else
+        #     seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
+        #     Jaccache = Tuple(zero(x) for i in 1:num_cons)
+        # end
 
         y = zeros(eltype(x), num_cons)
@@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
                 Enzyme.make_zero!(Jaccache[i])
             end
             Enzyme.make_zero!(y)
-            if num_cons > length(θ)
-                Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
-                    BatchDuplicated(θ, seeds), Const(p))
-                for i in eachindex(θ)
-                    if J isa Vector
-                        J[i] = Jaccache[i][1]
-                    else
-                        copyto!(@view(J[:, i]), Jaccache[i])
-                    end
-                end
-            else
-                Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
-                    BatchDuplicated(θ, Jaccache), Const(p))
-                for i in 1:num_cons
-                    if J isa Vector
-                        J .= Jaccache[1]
-                    else
-                        copyto!(@view(J[i, :]), Jaccache[i])
-                    end
+            Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
+                BatchDuplicated(θ, seeds), Const(p))
+            for i in eachindex(θ)
+                if J isa Vector
+                    J[i] = Jaccache[i][1]
+                else
+                    copyto!(@view(J[:, i]), Jaccache[i])
                 end
             end
+            # else
+            #     Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
+            #         BatchDuplicated(θ, Jaccache), Const(p))
+            #     for i in 1:num_cons
+            #         if J isa Vector
+            #             J .= Jaccache[1]
+            #         else
+            #             J[i, :] = Jaccache[i]
+            #         end
+            #     end
+            # end
         end
     elseif cons_j == true && cons !== nothing
         cons_j! = (J, θ) -> f.cons_j(J, θ, p)
@@ -397,11 +396,11 @@ end
 function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
         cache::OptimizationBase.ReInitCache,
         adtype::AutoEnzyme,
-        num_cons = 0)
+        num_cons = 0; kwargs...)
     p = cache.p
     x = cache.u0
 
-    return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
+    return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
 
 function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
@@ -676,11 +675,11 @@ end
 function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
         cache::OptimizationBase.ReInitCache,
         adtype::AutoEnzyme,
-        num_cons = 0)
+        num_cons = 0; kwargs...)
     p = cache.p
     x = cache.u0
 
-    return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
+    return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
 end
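
Note on the Enzyme change above: the deleted branch picked reverse mode whenever there were fewer constraints than inputs; the patched code always does one batched forward sweep, seeding every input direction at once with Enzyme.onehot so that shadow i comes back holding column i of the Jacobian. A minimal standalone sketch of that call shape (hypothetical cons!, mirroring the autodiff/BatchDuplicated usage in the hunk; the trade-off is that the cost now scales with length(x) even when num_cons is smaller):

    using Enzyme

    # Hypothetical in-place constraints: 2 outputs, 3 inputs.
    function cons!(y, x)
        y[1] = x[1]^2 + x[2]
        y[2] = x[2] * x[3]
        return nothing
    end

    x = [1.0, 2.0, 3.0]
    seeds = Enzyme.onehot(x)                        # one basis-vector seed per input
    shadows = Tuple(zeros(2) for _ in 1:length(x))  # one shadow output per seed
    y = zeros(2)

    # One batched forward sweep: shadow i receives J * e_i, column i of J.
    Enzyme.autodiff(Enzyme.Forward, cons!, BatchDuplicated(y, shadows),
        BatchDuplicated(x, seeds))

    J = hcat(shadows...)                            # 2×3 constraint Jacobian
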
diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl
index ff1dce2..dbc08c4 100644
--- a/ext/OptimizationMTKExt.jl
+++ b/ext/OptimizationMTKExt.jl
@@ -57,7 +57,7 @@ end
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0,
+        adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0;
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
@@ -107,7 +107,7 @@ end
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
-        num_cons = 0, g = false, h = false, hv = false, fg = false, fgh = false,
+        num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
     p = isnothing(p) ? SciMLBase.NullParameters() : p
@@ -155,7 +155,7 @@ end
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::AutoSymbolics, num_cons = 0,
+        adtype::AutoSymbolics, num_cons = 0;
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl
index 4916321..d830d3a 100644
--- a/ext/OptimizationZygoteExt.jl
+++ b/ext/OptimizationZygoteExt.jl
@@ -220,7 +220,7 @@ function OptimizationBase.instantiate_function(
     if f.lag_h === nothing && cons !== nothing && lag_h == true
         lag_extras = prepare_hessian(
             lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
-        lag_hess_prototype = zeros(Bool, length(x), length(x))
+        lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
 
         function lag_h!(H::AbstractMatrix, θ, σ, λ)
             if σ == zero(eltype(θ))
@@ -232,13 +232,11 @@ function OptimizationBase.instantiate_function(
             end
         end
 
-        function lag_h!(h, θ, σ, λ)
-            H = eltype(θ).(lag_hess_prototype)
-            hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
+        function lag_h!(h::AbstractVector, θ, σ, λ)
+            H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
             k = 0
-            rows, cols, _ = findnz(H)
-            for (i, j) in zip(rows, cols)
-                if i <= j
+            for i in 1:length(θ)
+                for j in 1:i
                     k += 1
                     h[k] = H[i, j]
                 end
@@ -256,7 +254,7 @@ function OptimizationBase.instantiate_function(
                         1:length(θ), 1:length(θ)])
             end
         end
-
+
         function lag_h!(h::AbstractVector, θ, σ, λ, p)
             global _p = p
             H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -294,21 +292,20 @@ end
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AutoZygote, num_cons = 0;
-        g = false, h = false, hv = false, fg = false, fgh = false,
-        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
+        adtype::ADTypes.AutoZygote, num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
     return OptimizationBase.instantiate_function(
-        f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
+        f, x, adtype, p, num_cons; kwargs...)
 end
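
The rewritten vector form of lag_h! above no longer walks findnz of a sparse prototype; it evaluates the dense Hessian at the augmented point vcat(θ, [σ], λ) and packs the θ-block's lower triangle row by row (the indices never exceed length(θ), so the σ/λ rows are ignored). The packing order in isolation (illustrative only):

    # Row-major lower-triangle packing: (1,1), (2,1), (2,2), (3,1), ...
    # n*(n+1)/2 entries for n parameters.
    function pack_lower!(h::AbstractVector, H::AbstractMatrix, n::Int)
        k = 0
        for i in 1:n, j in 1:i
            k += 1
            h[k] = H[i, j]
        end
        return h
    end

    H = [1.0 2.0; 2.0 5.0]
    pack_lower!(zeros(3), H, 2)   # [1.0, 2.0, 5.0]
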
 
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, x,
         adtype::ADTypes.AutoSparse{<:AutoZygote}, p = SciMLBase.NullParameters(),
         num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
-        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
+        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
+        lag_h = false)
     function _f(θ)
         return f.f(θ, p)[1]
     end
@@ -335,7 +332,7 @@ function OptimizationBase.instantiate_function(
         grad = nothing
     end
 
-    if fg == true && f.fg !== nothing
+    if fg == true && f.fg === nothing
         if g == false
             extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
         end
@@ -361,7 +358,7 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
 
-    if f.hess === nothing
+    if h == true && f.hess === nothing
         extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
         function hess(res, θ)
             hessian!(_f, res, soadtype, θ, extras_hess)
@@ -384,7 +381,7 @@ function OptimizationBase.instantiate_function(
         hess = nothing
     end
 
-    if fgh == true && f.fgh !== nothing
+    if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
             return y
@@ -406,7 +403,7 @@ function OptimizationBase.instantiate_function(
         fgh! = nothing
     end
 
-    if hv == true && f.hv !== nothing
+    if hv == true && f.hv === nothing
         extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
         function hv!(H, θ, v)
             hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
@@ -443,7 +440,7 @@ function OptimizationBase.instantiate_function(
             θ = augvars[1:length(x)]
             σ = augvars[length(x) + 1]
             λ = augvars[(length(x) + 2):end]
-            return σ * _f(θ) + dot(λ, cons(θ))
+            return σ * _f(θ) + dot(λ, cons_oop(θ))
         end
     end
@@ -466,7 +463,8 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
-        extras_pullback = prepare_pullback(cons_oop, adtype, x)
+        extras_pullback = prepare_pullback(
+            cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
         function cons_vjp!(J, θ, v)
             pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
         end
@@ -477,7 +475,8 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
-        extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
+        extras_pushforward = prepare_pushforward(
+            cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
         function cons_jvp!(J, θ, v)
             pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
         end
@@ -510,10 +509,11 @@ function OptimizationBase.instantiate_function(
     end
 
     lag_hess_prototype = f.lag_hess_prototype
-    if cons !== nothing && cons_h == true && f.lag_h === nothing
+    lag_hess_colors = f.lag_hess_colorvec
+    if cons !== nothing && f.lag_h === nothing && lag_h == true
         lag_extras = prepare_hessian(
             lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
-        lag_hess_prototype = lag_extras.coloring_result.S[1:length(θ), 1:length(θ)]
+        lag_hess_prototype = lag_extras.coloring_result.S[1:length(x), 1:length(x)]
         lag_hess_colors = lag_extras.coloring_result.color
 
         function lag_h!(H::AbstractMatrix, θ, σ, λ)
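
Both the dense and sparse Zygote paths differentiate one scalar lagrangian written over the concatenated vector vcat(θ, [σ], λ) — hence the cons(θ) → cons_oop(θ) fix above (the closure needs the out-of-place constraints) and the prototype indexing switching to length(x) (θ is not in scope at that point). A self-contained sketch of the trick, with ForwardDiff standing in for the prepared second-order backend:

    using ForwardDiff, LinearAlgebra

    obj(θ) = θ[1]^2 + 2θ[2]^2          # toy objective
    cons_oop(θ) = [θ[1] * θ[2]]        # toy out-of-place constraint, m = 1

    nθ = 2
    function lagrangian(z)             # z = vcat(θ, σ, λ)
        θ, σ, λ = z[1:nθ], z[nθ + 1], z[(nθ + 2):end]
        return σ * obj(θ) + dot(λ, cons_oop(θ))
    end

    z = vcat([1.0, 2.0], 1.0, [3.0])
    Haug = ForwardDiff.hessian(lagrangian, z)   # (nθ + 1 + m) × (nθ + 1 + m)
    Hθθ = Haug[1:nθ, 1:nθ]                      # the block the solver consumes
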
@@ -587,14 +587,11 @@ end
 
 function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0;
-        g = false, h = false, hv = false, fg = false, fgh = false,
-        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
+        adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
-    return OptimizationBase.instantiate_function(
-        f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
+    return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
 end
diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl
index 2f346b5..82826aa 100644
--- a/src/OptimizationDIExt.jl
+++ b/src/OptimizationDIExt.jl
@@ -104,7 +104,7 @@ function instantiate_function(
         hess = nothing
     end
 
-    if fgh == true && f.fgh !== nothing
+    if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(
                 _f, G, H, soadtype, θ, extras_hess)
@@ -229,7 +229,7 @@ function instantiate_function(
     if cons !== nothing && lag_h == true && f.lag_h === nothing
         lag_extras = prepare_hessian(
             lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
-        lag_hess_prototype = zeros(Bool, length(x), length(x))
+        lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
 
         function lag_h!(H::AbstractMatrix, θ, σ, λ)
             if σ == zero(eltype(θ))
@@ -263,7 +263,7 @@ function instantiate_function(
                         1:length(θ), 1:length(θ)])
             end
         end
-
+
         function lag_h!(h::AbstractVector, θ, σ, λ, p)
             global _p = p
             H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -301,16 +301,12 @@ end
 
 function instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AbstractADType, num_cons = 0,
-        g = false, h = false, hv = false, fg = false, fgh = false,
-        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
-        lag_h = false)
+        adtype::ADTypes.AbstractADType, num_cons = 0;
+        kwargs...)
     x = cache.u0
     p = cache.p
 
-    return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv,
-        fg = fg, fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp,
-        cons_h = cons_h, lag_h = lag_h)
+    return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
 
 function instantiate_function(
@@ -392,7 +388,7 @@ function instantiate_function(
         hess = nothing
     end
 
-    if fgh == true && f.fgh !== nothing
+    if fgh == true && f.fgh === nothing
         function fgh!(θ)
             (y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess)
             return y, G, H
@@ -511,7 +507,7 @@ function instantiate_function(
     if cons !== nothing && lag_h == true && f.lag_h === nothing
         lag_extras = prepare_hessian(
             lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
-        lag_hess_prototype = zeros(Bool, length(x), length(x))
+        lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
 
         function lag_h!(θ, σ, λ)
             if σ == zero(eltype(θ))
@@ -558,9 +554,9 @@ end
 
 function instantiate_function(
         f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AbstractADType, num_cons = 0)
+        adtype::ADTypes.AbstractADType, num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
-    return instantiate_function(f, x, adtype, p, num_cons)
+    return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
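
Two things recur in the file above. First, the dense lag_hess_prototype is resized to the augmented dimension length(x) + num_cons + 1, matching the point the Hessian is prepared at. Second, the cache-based methods stop re-listing every derivative flag and splat kwargs... into the positional method, so a newly added flag such as lag_h no longer has to be threaded through each wrapper by hand. A reduced illustration of the forwarding pattern (hypothetical names):

    struct CacheSketch        # stand-in for OptimizationBase.ReInitCache
        u0::Vector{Float64}
        p::Any
    end

    # The positional method owns the full flag list...
    build(f, x, p, num_cons = 0; g = false, h = false, lag_h = false) = (; g, h, lag_h)

    # ...and the cache method forwards whatever it receives, present or future.
    build(f, cache::CacheSketch, num_cons = 0; kwargs...) =
        build(f, cache.u0, cache.p, num_cons; kwargs...)

    build(sum, CacheSketch([1.0], nothing); lag_h = true)   # (g = false, h = false, lag_h = true)
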
diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl
index ccd8836..b0ec48b 100644
--- a/src/OptimizationDISparseExt.jl
+++ b/src/OptimizationDISparseExt.jl
@@ -135,7 +135,7 @@ function instantiate_function(
         grad = nothing
     end
 
-    if fg == true && f.fg !== nothing
+    if fg == true && f.fg === nothing
         if g == false
             extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
         end
@@ -184,7 +184,7 @@ function instantiate_function(
         hess = nothing
     end
 
-    if fgh == true && f.fgh !== nothing
+    if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
             return y
@@ -192,7 +192,8 @@ function instantiate_function(
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
                 global _p = p
-                (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
+                (y, _, _) = value_derivative_and_second_derivative!(
+                    _f, G, H, θ, extras_hess)
                 return y
             end
         end
@@ -264,8 +265,9 @@ function instantiate_function(
         cons_j! = nothing
     end
 
-    if f.cons_vjp === nothing && cons_vjp == true
-        extras_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
+    if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
+        extras_pullback = prepare_pullback(
+            cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
         function cons_vjp!(J, θ, v)
             pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
         end
@@ -275,9 +277,9 @@ function instantiate_function(
-    if f.cons_jvp === nothing && cons_jvp == true
+    if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
         extras_pushforward = prepare_pushforward(
-            cons_oop, adtype, x, ones(eltype(x), length(x)))
+            cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
         function cons_jvp!(J, θ, v)
             pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
         end
@@ -351,7 +353,7 @@ function instantiate_function(
                     1:length(θ), 1:length(θ)]
             end
         end
-
+
         function lag_h!(h, θ, σ, λ, p)
             global _p = p
             H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
@@ -394,16 +396,11 @@ end
 
 function instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0;
-        g = false, h = false, hv = false, fg = false, fgh = false,
-        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
-        lag_h = false)
+        adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
-    return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv, fg = fg,
-        fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp, cons_h = cons_h,
-        lag_h = lag_h)
+    return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
 
 function instantiate_function(
@@ -436,7 +433,7 @@ function instantiate_function(
         grad = nothing
     end
 
-    if fg == true && f.fg !== nothing
+    if fg == true && f.fg === nothing
         if g == false
             extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
         end
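
The prepare_pullback/prepare_pushforward fixes in this file correct two things at once: preparation must run on the dense inner backend (adtype.dense_ad — the sparse wrapper is only meaningful for Jacobian/Hessian coloring), and the seed must live in the right space: a pullback seed has num_cons entries (output space) while a pushforward seed has length(x) entries (input space). The same contraction with plain Zygote, for intuition (toy constraints):

    using Zygote

    cons_oop(θ) = [θ[1]^2 + θ[2], θ[2] * θ[3]]   # m = 2 constraints

    θ = [1.0, 2.0, 3.0]
    v = ones(2)                     # pullback seed: length num_cons

    y, back = Zygote.pullback(cons_oop, θ)
    (vjp,) = back(v)                # J' * v == [2.0, 4.0, 2.0], length(θ) entries
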
@@ -457,7 +454,7 @@ function instantiate_function(
         fg! = nothing
     end
 
-    if fgh == true && f.fgh !== nothing
+    if fgh == true && f.fgh === nothing
         function fgh!(θ)
             (y, G, H) = value_derivative_and_second_derivative(_f, soadtype, θ, extras_hess)
             return y, G, H
@@ -466,7 +463,8 @@ function instantiate_function(
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(θ, p)
                 global _p = p
-                (y, G, H) = value_derivative_and_second_derivative(_f, soadtype, θ, extras_hess)
+                (y, G, H) = value_derivative_and_second_derivative(
+                    _f, soadtype, θ, extras_hess)
                 return y, G, H
             end
         end
@@ -559,10 +557,11 @@ function instantiate_function(
         cons_j! = nothing
     end
 
-    if f.cons_vjp === nothing && cons_vjp == true
-        extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
+    if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
+        extras_pullback = prepare_pullback(
+            cons, adtype.dense_ad, x, ones(eltype(x), num_cons))
         function cons_vjp!(θ, v)
-            pullback(cons, adtype, θ, v, extras_pullback)
+            pullback(cons, adtype.dense_ad, θ, v, extras_pullback)
         end
     elseif cons_vjp === true && cons !== nothing
         cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
@@ -570,11 +569,11 @@ function instantiate_function(
     end
 
-    if f.cons_jvp === nothing && cons_jvp == true
+    if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
         extras_pushforward = prepare_pushforward(
-            cons, adtype, x, ones(eltype(x), length(x)))
+            cons, adtype.dense_ad, x, ones(eltype(x), length(x)))
         function cons_jvp!(θ, v)
-            pushforward(cons, adtype, θ, v, extras_pushforward)
+            pushforward(cons, adtype.dense_ad, θ, v, extras_pushforward)
         end
     elseif cons_jvp === true && cons !== nothing
         cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)
@@ -660,9 +659,9 @@ end
 
 function instantiate_function(
         f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
+        adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
-    return instantiate_function(f, x, adtype, p, num_cons)
+    return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
 end
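
From here the patch reaches src/cache.jl: the data positional argument (and DEFAULT_DATA) disappears, and the constructor derives every derivative request from solver traits — SciMLBase.requiresgradient(opt), SciMLBase.allowsfg(opt), and so on — feeding them straight into the keyword interface refactored above. A toy mirror of that trait dispatch (hypothetical solver types; the real traits live in SciMLBase):

    abstract type OptSketch end
    struct FirstOrderOpt <: OptSketch end
    struct SecondOrderOpt <: OptSketch end

    requiresgradient(::OptSketch) = false
    requiresgradient(::Union{FirstOrderOpt, SecondOrderOpt}) = true
    requireshessian(::OptSketch) = false
    requireshessian(::SecondOrderOpt) = true

    # What the cache constructor assembles before calling instantiate_function:
    flags(opt) = (; g = requiresgradient(opt), h = requireshessian(opt))

    flags(FirstOrderOpt())    # (g = true, h = false)
    flags(SecondOrderOpt())   # (g = true, h = true)
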
diff --git a/src/cache.jl b/src/cache.jl
index 6dd196a..e7722f5 100644
--- a/src/cache.jl
+++ b/src/cache.jl
@@ -5,7 +5,7 @@ struct AnalysisResults
     constraints::Union{Nothing, Vector{AnalysisResult}}
 end
 
-struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
+struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, P, C, M} <:
        SciMLBase.AbstractOptimizationCache
     f::F
     reinit_cache::RC
@@ -15,7 +15,6 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
     ucons::UC
     sense::S
     opt::O
-    data::D
     progress::P
     callback::C
     manifold::M
@@ -23,7 +22,7 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
     solver_args::NamedTuple
 end
 
-function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFAULT_DATA;
+function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
         callback = DEFAULT_CALLBACK,
         maxiters::Union{Number, Nothing} = nothing,
         maxtime::Union{Number, Nothing} = nothing,
@@ -36,8 +35,9 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
     reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
     num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
     f = OptimizationBase.instantiate_function(
-        prob.f, reinit_cache, prob.f.adtype, num_cons,
-        g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
+        prob.f, reinit_cache, prob.f.adtype, num_cons;
+        g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),
+        hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
         fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt),
         cons_h = SciMLBase.requiresconshess(opt),
         cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt),
         lag_h = SciMLBase.requireslagh(opt))
@@ -149,13 +149,12 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
     return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
         prob.ucons, prob.sense,
-        opt, data, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
+        opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
         merge((; maxiters, maxtime, abstol, reltol),
             NamedTuple(kwargs)))
 end
 
-function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt,
-        data = DEFAULT_DATA;
+function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt;
         callback = DEFAULT_CALLBACK,
         maxiters::Union{Number, Nothing} = nothing,
         maxtime::Union{Number, Nothing} = nothing,
@@ -163,7 +162,7 @@ function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt,
         reltol::Union{Number, Nothing} = nothing,
         progress = false,
         kwargs...)
-    return OptimizationCache(prob, opt, data; maxiters, maxtime, abstol, callback,
+    return OptimizationCache(prob, opt; maxiters, maxtime, abstol, callback,
         reltol, progress, kwargs...)
 end
diff --git a/src/function.jl b/src/function.jl
index 1343900..c5d3e94 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -43,7 +43,8 @@ function that is not defined, an error is thrown.
 
 For more information on the use of automatic differentiation, see the
 documentation of the `AbstractADType` types.
 """
-function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
+function OptimizationBase.instantiate_function(
+        f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
         p, num_cons = 0)
     jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...)
     hess = f.hess === nothing ? nothing :
@@ -76,7 +77,7 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB
         observed = f.observed)
 end
 
-function instantiate_function(
+function OptimizationBase.instantiate_function(
         f::MultiObjectiveOptimizationFunction, cache::ReInitCache,
         ::SciMLBase.NoAD, num_cons = 0)
     jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...)
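
The NoAD rewrite below replaces each one-line ternary wrapper with a pair of methods on one local closure: a short form that uses the problem's stored p, plus — guarded by p != SciMLBase.NullParameters() — a longer form that accepts p from the caller. The pattern in isolation (illustrative; unconditional here):

    # A local closure can carry several methods; both capture fgrad and p.
    function make_grad(fgrad, p)
        grad(G, x) = fgrad(G, x, p)      # use the stored parameters
        grad(G, x, q) = fgrad(G, x, q)   # caller supplies parameters
        return grad
    end

    g! = make_grad((G, x, p) -> (G .= 2 .* p .* x), [1.0, 1.0])
    G = zeros(2)
    g!(G, [3.0, 4.0])               # G == [6.0, 8.0]
    g!(G, [3.0, 4.0], [2.0, 2.0])   # G == [12.0, 16.0]
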
@@ -110,19 +111,90 @@ function instantiate_function(
         observed = f.observed)
 end
 
-function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
-        p, num_cons = 0, kwargs...)
-    grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
-    fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...)
-    hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
-    fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...)
-    hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
+function OptimizationBase.instantiate_function(
+        f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
+        p, num_cons = 0; kwargs...)
+    if f.grad === nothing
+        grad = nothing
+    else
+        function grad(G, x)
+            return f.grad(G, x, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function grad(G, x, p)
+                return f.grad(G, x, p)
+            end
+        end
+    end
+    if f.fg === nothing
+        fg = nothing
+    else
+        function fg(G, x)
+            return f.fg(G, x, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function fg(G, x, p)
+                return f.fg(G, x, p)
+            end
+        end
+    end
+    if f.hess === nothing
+        hess = nothing
+    else
+        function hess(H, x)
+            return f.hess(H, x, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function hess(H, x, p)
+                return f.hess(H, x, p)
+            end
+        end
+    end
+
+    if f.fgh === nothing
+        fgh = nothing
+    else
+        function fgh(G, H, x)
+            return f.fgh(G, H, x, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function fgh(G, H, x, p)
+                return f.fgh(G, H, x, p)
+            end
+        end
+    end
+
+    if f.hv === nothing
+        hv = nothing
+    else
+        function hv(H, x, v)
+            return f.hv(H, x, v, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function hv(H, x, v, p)
+                return f.hv(H, x, v, p)
+            end
+        end
+    end
+
     cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
     cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
     cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
     cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
     cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
-    lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p)
+
+    if f.lag_h === nothing
+        lag_h = nothing
+    else
+        function lag_h(res, x)
+            return f.lag_h(res, x, p)
+        end
+        if p != SciMLBase.NullParameters()
+            function lag_h(res, x, p)
+                return f.lag_h(res, x, p)
+            end
+        end
+    end
     hess_prototype = f.hess_prototype === nothing ? nothing :
                      convert.(eltype(x), f.hess_prototype)
     cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
@@ -146,17 +218,17 @@ function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD
         observed = f.observed)
 end
 
-function instantiate_function(
+function OptimizationBase.instantiate_function(
         f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD,
-        num_cons = 0, kwargs...)
+        num_cons = 0; kwargs...)
     x = cache.u0
     p = cache.p
 
-    return instantiate_function(f, x, SciMLBase.NoAD(), p, num_cons, kwargs...)
+    return instantiate_function(f, x, SciMLBase.NoAD(), p, num_cons; kwargs...)
 end
 
 function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType,
-        p, num_cons = 0, kwargs...)
+        p, num_cons = 0; kwargs...)
     adtypestr = string(adtype)
     _strtind = findfirst('.', adtypestr)
     strtind = isnothing(_strtind) ? 5 : _strtind + 5
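
The diff ends mid-function, at the start of the error-reporting path: those last three lines recover the backend's short name from the printed adtype by skipping an optional module prefix (everything through '.') plus the four characters of "Auto". What they compute, assuming a type that prints as "ADTypes.AutoZygote()":

    adtypestr = "ADTypes.AutoZygote()"               # string(adtype)
    _strtind = findfirst('.', adtypestr)             # 8, the index of '.'
    strtind = isnothing(_strtind) ? 5 : _strtind + 5 # 13
    adtypestr[strtind:end]                           # "Zygote()" — "ADTypes.Auto" skipped
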