Skip to content

Commit

Permalink
Merge pull request #368 from SciML/complex-u0-for-OptimizationOptimisers
Browse files Browse the repository at this point in the history
error with complex number u0 in OptimizationOptimisers
  • Loading branch information
Vaibhavdixit02 authored Sep 2, 2022
2 parents becceb8 + d52ae70 commit 9b0ab2c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function SciMLBase.__solve(prob::OptimizationProblem, opt::OptimisersOptimizers,
G = copy(θ)

local x, min_err, min_θ
min_err = typemax(eltype(prob.u0)) #dummy variables
min_err = typemax(eltype(real(prob.u0))) #dummy variables
min_opt = 1
min_θ = prob.u0

Expand Down
13 changes: 13 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using OptimizationOptimisers, Optimization, ForwardDiff
using Test
using Zygote

@testset "OptimizationOptimisers.jl" begin
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
Expand All @@ -17,4 +18,16 @@ using Test
prob = OptimizationProblem(optprob, x0, _p)
sol = solve(prob, Optimisers.ADAM(), maxiters = 1000, progress = false)
@test 10 * sol.minimum < l1

x0 = 2 * ones(ComplexF64, 2)
_p = ones(2)
sumfunc(x0, _p) = sum(abs2, (x0 - _p))
l1 = sumfunc(x0, _p)

optprob = OptimizationFunction(sumfunc, Optimization.AutoZygote())

prob = OptimizationProblem(optprob, x0, _p)

sol = solve(prob, Optimisers.ADAM(), maxiters = 1000)
@test 10 * sol.minimum < l1
end

0 comments on commit 9b0ab2c

Please sign in to comment.