diff --git a/docs/src/linear_systems/gmres.md b/docs/src/linear_systems/gmres.md index 5f5574d2..aa1c2e47 100644 --- a/docs/src/linear_systems/gmres.md +++ b/docs/src/linear_systems/gmres.md @@ -13,9 +13,9 @@ gmres! The implementation pre-allocates a matrix $V$ of size `n` by `restart` whose columns form an orthonormal basis for the Krylov subspace. This allows BLAS2 operations when updating the solution vector $x$. The Hessenberg matrix is also pre-allocated. -Modified Gram-Schmidt is used to orthogonalize the columns of $V$. +By default, modified Gram-Schmidt is used to orthogonalize the columns of $V$, since it is numerically more stable than classical Gram-Schmidt. Modified Gram-Schmidt is, however, inherently sequential; if stability is not a concern, classical Gram-Schmidt can be used instead, which is implemented using BLAS2 operations. As a compromise, the "DGKS criterion" can be used: it conditionally applies classical Gram-Schmidt repeatedly to stabilize it, and typically takes between one and two times as long as classical Gram-Schmidt. The computation of the residual norm is implemented in a non-standard way, namely keeping track of a vector $\gamma$ in the null-space of $H_k^*$, which is the adjoint of the $(k + 1) \times k$ Hessenberg matrix $H_k$ at the $k$th iteration. Only when $x$ needs to be updated is the Hessenberg matrix mutated with Givens rotations. !!! tip - GMRES can be used as an [iterator](@ref Iterators). This makes it possible to access the Hessenberg matrix and Krylov basis vectors during the iterations. \ No newline at end of file + GMRES can be used as an [iterator](@ref Iterators). This makes it possible to access the Hessenberg matrix and Krylov basis vectors during the iterations. 
diff --git a/src/gmres.jl b/src/gmres.jl index 56df221d..b737b6ed 100644 --- a/src/gmres.jl +++ b/src/gmres.jl @@ -28,7 +28,7 @@ Residual(order, T::Type) = Residual{T, real(T)}( one(real(T)) ) -mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real} +mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real, orthmethT} Pl::preclT Pr::precrT x::solT @@ -44,6 +44,8 @@ mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: Arnol maxiter::Int tol::resT β::resT + + orth_meth::orthmethT end converged(g::GMRESIterable) = g.residual.current ≤ g.tol @@ -66,7 +68,8 @@ function iterate(g::GMRESIterable, iteration::Int=start(g)) g.arnoldi.H[g.k + 1, g.k] = orthogonalize_and_normalize!( view(g.arnoldi.V, :, 1 : g.k), view(g.arnoldi.V, :, g.k + 1), - view(g.arnoldi.H, 1 : g.k, g.k) + view(g.arnoldi.H, 1 : g.k, g.k), + g.orth_meth ) # Implicitly computes the residual @@ -109,7 +112,8 @@ function gmres_iterable!(x, A, b; reltol::Real = sqrt(eps(real(eltype(b)))), restart::Int = min(20, size(A, 2)), maxiter::Int = size(A, 2), - initially_zero::Bool = false) + initially_zero::Bool = false, + orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt()) T = eltype(x) # Approximate solution @@ -126,7 +130,8 @@ function gmres_iterable!(x, A, b; GMRESIterable(Pl, Pr, x, b, Ax, arnoldi, residual, - mv_products, restart, 1, maxiter, tolerance, residual.current + mv_products, restart, 1, maxiter, tolerance, residual.current, + orth_meth ) end @@ -163,6 +168,7 @@ Solves the problem ``Ax = b`` with restarted GMRES. - `Pr`: right preconditioner; - `log::Bool`: keep track of the residual norm in each iteration; - `verbose::Bool`: print convergence information during the iterations. 
+- `orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt()`: orthogonalization method used to build the Krylov basis; one of `ModifiedGramSchmidt()`, `ClassicalGramSchmidt()` or `DGKS()`. # Return values @@ -184,7 +190,8 @@ function gmres!(x, A, b; maxiter::Int = size(A, 2), log::Bool = false, initially_zero::Bool = false, - verbose::Bool = false) + verbose::Bool = false, + orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt()) history = ConvergenceHistory(partial = !log, restart = restart) history[:abstol] = abstol history[:reltol] = reltol @@ -192,7 +199,8 @@ function gmres!(x, A, b; iterable = gmres_iterable!(x, A, b; Pl = Pl, Pr = Pr, abstol = abstol, reltol = reltol, maxiter = maxiter, - restart = restart, initially_zero = initially_zero) + restart = restart, initially_zero = initially_zero, + orth_meth = orth_meth) verbose && @printf("=== gmres ===\n%4s\t%4s\t%7s\n","rest","iter","resnorm") diff --git a/src/orthogonalize.jl b/src/orthogonalize.jl index 28620d85..ade2d413 100644 --- a/src/orthogonalize.jl +++ b/src/orthogonalize.jl @@ -7,9 +7,10 @@ struct ClassicalGramSchmidt <: OrthogonalizationMethod end struct ModifiedGramSchmidt <: OrthogonalizationMethod end # Default to MGS, good enough for solving linear systems. 
-@inline orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}) where {T} = orthogonalize_and_normalize!(V, w, h, ModifiedGramSchmidt) +@inline orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}) where {T} = + orthogonalize_and_normalize!(V, w, h, ModifiedGramSchmidt()) -function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{DGKS}) where {T} +function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::DGKS) where {T} # Orthogonalize using BLAS-2 ops mul!(h, adjoint(V), w) mul!(w, V, h, -one(T), one(T)) @@ -37,7 +38,7 @@ function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, nrm end -function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ClassicalGramSchmidt}) where {T} +function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::ClassicalGramSchmidt) where {T} # Orthogonalize using BLAS-2 ops mul!(h, adjoint(V), w) mul!(w, V, h, -one(T), one(T)) @@ -49,7 +50,7 @@ function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, nrm end -function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ModifiedGramSchmidt}) where {T} +function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVector{T}, h::StridedVector{T}, ::ModifiedGramSchmidt) where {T} # Orthogonalize using BLAS-1 ops for i = 1 : length(V) h[i] = dot(V[i], w) @@ -63,7 +64,7 @@ function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVec nrm end -function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ModifiedGramSchmidt}) where {T} +function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::ModifiedGramSchmidt) 
where {T} # Orthogonalize using BLAS-1 ops and column views. for i = 1 : size(V, 2) column = view(V, :, i) diff --git a/test/orthogonalize.jl b/test/orthogonalize.jl index 66cb88eb..9471201e 100644 --- a/test/orthogonalize.jl +++ b/test/orthogonalize.jl @@ -34,7 +34,7 @@ m = 3 end # Assuming V is a matrix - @testset "Using $method" for method = (DGKS, ClassicalGramSchmidt, ModifiedGramSchmidt) + @testset "Using $method" for method = (DGKS(), ClassicalGramSchmidt(), ModifiedGramSchmidt()) # Projection size h = zeros(T, m) @@ -55,7 +55,7 @@ m = 3 # Orthogonalize w in-place w = copy(w_original) - nrm = orthogonalize_and_normalize!(V_vec, w, h, ModifiedGramSchmidt) + nrm = orthogonalize_and_normalize!(V_vec, w, h, ModifiedGramSchmidt()) is_orthonormalized(w, h, nrm) end