Skip to content

Commit

Permalink
add pretty printing to (Verbose/Compact)SolutionResults
Browse files Browse the repository at this point in the history
  • Loading branch information
haakon-e committed Dec 19, 2024
1 parent 074ac0b commit 03af1d9
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ version = "0.4.2"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
ForwardDiff = "0.10"
Printf = "<0.0.1, 1"
julia = "1.6"
33 changes: 32 additions & 1 deletion src/RootSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ julia> sol = find_zero(x -> x^2 - 100^2,
CompactSolution());
julia> sol
RootSolvers.CompactSolutionResults{Float64}(99.99999999994358, true)
CompactSolutionResults{Float64}:
├── Status: converged
└── Root: 99.99999999994358
julia> sol.root
99.99999999994358
Expand All @@ -31,6 +33,7 @@ export AbstractTolerance, ResidualTolerance, SolutionTolerance, RelativeSolution
RelativeOrAbsoluteSolutionTolerance

import ForwardDiff
import Printf: @printf

base_type(::Type{FT}) where {FT} = FT
base_type(::Type{FT}) where {T, FT <: ForwardDiff.Dual{<:Any, T}} = base_type(T)
Expand Down Expand Up @@ -121,6 +124,27 @@ end
SolutionResults(soltype::VerboseSolution, args...) =
VerboseSolutionResults(args...)

function Base.show(io::IO, sol::VerboseSolutionResults{FT}) where {FT}
status = sol.converged ? "\e[32mconverged\e[0m" : "\e[31mfailed to converge\e[0m"
println(io, "VerboseSolutionResults{$FT}:")
println(io, "├── Status: ", status)
println(io, "├── Root: ", sol.root)
println(io, "├── Error: ", sol.err)
println(io, "├── Iterations: ", sol.iter_performed)
println(io, "└── History:")
n_iters = length(sol.root_history)
for i in 1:n_iters
if n_iters > 20 && 9 < i < n_iters - 9
i == 11 && println(io, " ⋮ ⋮ ⋮")
i == 12 && println(io, " ⋮ ⋮ ⋮")
continue
end
prefix = i == n_iters ? " └──" : " ├──"
@printf(io, "%s iter %2d: x = %8.5g, err = %.4g\n",
prefix, i, sol.root_history[i], sol.err_history[i])
end
end

"""
CompactSolution <: SolutionType
Expand Down Expand Up @@ -149,6 +173,13 @@ end
SolutionResults(soltype::CompactSolution, root, converged, args...) =
CompactSolutionResults(root, converged)

function Base.show(io::IO, sol::CompactSolutionResults{FT}) where {FT}
status = sol.converged ? "\e[32mconverged\e[0m" : "\e[31mfailed to converge\e[0m"
println(io, "CompactSolutionResults{$FT}:")
println(io, "├── Status: ", status)
println(io, "└── Root: ", sol.root)
end

init_history(::VerboseSolution, x::FT) where {FT <: Real} = FT[x]
init_history(::CompactSolution, x) = nothing
init_history(::VerboseSolution, ::Type{FT}) where {FT <: Real} = FT[]
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ end
end

include("runtests_kernel.jl")

include("test_printing.jl")
41 changes: 41 additions & 0 deletions test/test_printing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Printf

@testset "solution pretty printing" begin
@testset "CompactSolution" begin
sol = find_zero(x -> x^2 - 100^2,
SecantMethod{Float64}(0.0, 1000.0),
CompactSolution());
sol_str = sprint(show, sol)
@test startswith(sol_str, "CompactSolutionResults{Float64}")
@test contains(sol_str, "converged")
sol = find_zero(x -> x^2 - 100^2,
SecantMethod{Float64}(0.0, 1e3),
CompactSolution(),
RelativeSolutionTolerance(eps(10.0)),
2)
sol_str = sprint(show, sol)
@test startswith(sol_str, "CompactSolutionResults{Float64}")
@test contains(sol_str, "failed to converge")
end
@testset "VerboseSolution" begin
sol = find_zero(x -> x^2 - 100^2,
SecantMethod{Float64}(0.0, 1000.0),
VerboseSolution());
sol_str = sprint(show, sol)
@test startswith(sol_str, "VerboseSolutionResults{Float64}")
@test contains(sol_str, "converged")
@test contains(sol_str, "Root: $(sol.root)")
@test contains(sol_str, "Error: $(sol.err)")
@test contains(sol_str, "Iterations: $(length(sol.root_history)-1)")
@test contains(sol_str, "History")
sol = find_zero(x -> x^2 - 100^2,
SecantMethod{Float64}(0.0, 1e3),
VerboseSolution(),
RelativeSolutionTolerance(eps(10.0)),
2)
sol_str = sprint(show, sol)
@test startswith(sol_str, "VerboseSolutionResults{Float64}")
@test contains(sol_str, "failed to converge")

end
end

0 comments on commit 03af1d9

Please sign in to comment.