Skip to content

Commit

Permalink
Make orthogonalization method chooseable for gmres (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrch authored Mar 19, 2021
1 parent e8b795c commit ae01dfe
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/src/linear_systems/gmres.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ gmres!

The implementation pre-allocates a matrix $V$ of size `n` by `restart` whose columns form an orthonormal basis for the Krylov subspace. This allows BLAS2 operations when updating the solution vector $x$. The Hessenberg matrix is also pre-allocated.

Modified Gram-Schmidt is used to orthogonalize the columns of $V$.
By default, modified Gram-Schmidt is used to orthogonalize the columns of $V$, since it is numerically more stable than classical Gram-Schmidt. Modified Gram-Schmidt is, however, inherently sequential; if stability is not a concern, classical Gram-Schmidt, which is implemented using BLAS2 operations, can be used instead. As a compromise, the "DGKS criterion" can be used: it conditionally repeats classical Gram-Schmidt to stabilize it, and typically takes between one and two times as long as classical Gram-Schmidt.

The computation of the residual norm is implemented in a non-standard way, namely keeping track of a vector $\gamma$ in the null-space of $H_k^*$, which is the adjoint of the $(k + 1) \times k$ Hessenberg matrix $H_k$ at the $k$th iteration. Only when $x$ needs to be updated is the Hessenberg matrix mutated with Givens rotations.

!!! tip
GMRES can be used as an [iterator](@ref Iterators). This makes it possible to access the Hessenberg matrix and Krylov basis vectors during the iterations.
GMRES can be used as an [iterator](@ref Iterators). This makes it possible to access the Hessenberg matrix and Krylov basis vectors during the iterations.
20 changes: 14 additions & 6 deletions src/gmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Residual(order, T::Type) = Residual{T, real(T)}(
one(real(T))
)

mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real, orthmethT}
Pl::preclT
Pr::precrT
x::solT
Expand All @@ -44,6 +44,8 @@ mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: Arnol
maxiter::Int
tol::resT
β::resT

orth_meth::orthmethT
end

# Convergence check: true once the current residual value stored in `g.residual`
# is at or below the tolerance `g.tol`.
converged(g::GMRESIterable) = g.residual.current ≤ g.tol
Expand All @@ -66,7 +68,8 @@ function iterate(g::GMRESIterable, iteration::Int=start(g))
g.arnoldi.H[g.k + 1, g.k] = orthogonalize_and_normalize!(
view(g.arnoldi.V, :, 1 : g.k),
view(g.arnoldi.V, :, g.k + 1),
view(g.arnoldi.H, 1 : g.k, g.k)
view(g.arnoldi.H, 1 : g.k, g.k),
g.orth_meth
)

# Implicitly computes the residual
Expand Down Expand Up @@ -109,7 +112,8 @@ function gmres_iterable!(x, A, b;
reltol::Real = sqrt(eps(real(eltype(b)))),
restart::Int = min(20, size(A, 2)),
maxiter::Int = size(A, 2),
initially_zero::Bool = false)
initially_zero::Bool = false,
orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt())
T = eltype(x)

# Approximate solution
Expand All @@ -126,7 +130,8 @@ function gmres_iterable!(x, A, b;

GMRESIterable(Pl, Pr, x, b, Ax,
arnoldi, residual,
mv_products, restart, 1, maxiter, tolerance, residual.current
mv_products, restart, 1, maxiter, tolerance, residual.current,
orth_meth
)
end

Expand Down Expand Up @@ -163,6 +168,7 @@ Solves the problem ``Ax = b`` with restarted GMRES.
- `Pr`: right preconditioner;
- `log::Bool`: keep track of the residual norm in each iteration;
- `verbose::Bool`: print convergence information during the iterations.
- `orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt()`: orthogonalization method (`ModifiedGramSchmidt()`, `ClassicalGramSchmidt()`, or `DGKS()`).
# Return values
Expand All @@ -184,15 +190,17 @@ function gmres!(x, A, b;
maxiter::Int = size(A, 2),
log::Bool = false,
initially_zero::Bool = false,
verbose::Bool = false)
verbose::Bool = false,
orth_meth::OrthogonalizationMethod = ModifiedGramSchmidt())
history = ConvergenceHistory(partial = !log, restart = restart)
history[:abstol] = abstol
history[:reltol] = reltol
log && reserve!(history, :resnorm, maxiter)

iterable = gmres_iterable!(x, A, b; Pl = Pl, Pr = Pr,
abstol = abstol, reltol = reltol, maxiter = maxiter,
restart = restart, initially_zero = initially_zero)
restart = restart, initially_zero = initially_zero,
orth_meth = orth_meth)

verbose && @printf("=== gmres ===\n%4s\t%4s\t%7s\n","rest","iter","resnorm")

Expand Down
11 changes: 6 additions & 5 deletions src/orthogonalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ struct ClassicalGramSchmidt <: OrthogonalizationMethod end
# Field-less marker type; an instance is passed to `orthogonalize_and_normalize!`
# to select the modified Gram-Schmidt method via dispatch.
struct ModifiedGramSchmidt <: OrthogonalizationMethod end

# Default to MGS, good enough for solving linear systems.
@inline orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}) where {T} = orthogonalize_and_normalize!(V, w, h, ModifiedGramSchmidt)
@inline function orthogonalize_and_normalize!(
    V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}
) where {T}
    # No method supplied: forward to the four-argument method with an MGS instance.
    return orthogonalize_and_normalize!(V, w, h, ModifiedGramSchmidt())
end

function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{DGKS}) where {T}
function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::DGKS) where {T}
# Orthogonalize using BLAS-2 ops
mul!(h, adjoint(V), w)
mul!(w, V, h, -one(T), one(T))
Expand Down Expand Up @@ -37,7 +38,7 @@ function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T},
nrm
end

function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ClassicalGramSchmidt}) where {T}
function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::ClassicalGramSchmidt) where {T}
# Orthogonalize using BLAS-2 ops
mul!(h, adjoint(V), w)
mul!(w, V, h, -one(T), one(T))
Expand All @@ -49,7 +50,7 @@ function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T},
nrm
end

function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ModifiedGramSchmidt}) where {T}
function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVector{T}, h::StridedVector{T}, ::ModifiedGramSchmidt) where {T}
# Orthogonalize using BLAS-1 ops
for i = 1 : length(V)
h[i] = dot(V[i], w)
Expand All @@ -63,7 +64,7 @@ function orthogonalize_and_normalize!(V::StridedVector{Vector{T}}, w::StridedVec
nrm
end

function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::Type{ModifiedGramSchmidt}) where {T}
function orthogonalize_and_normalize!(V::StridedMatrix{T}, w::StridedVector{T}, h::StridedVector{T}, ::ModifiedGramSchmidt) where {T}
# Orthogonalize using BLAS-1 ops and column views.
for i = 1 : size(V, 2)
column = view(V, :, i)
Expand Down
4 changes: 2 additions & 2 deletions test/orthogonalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ m = 3
end

# Assuming V is a matrix
@testset "Using $method" for method = (DGKS, ClassicalGramSchmidt, ModifiedGramSchmidt)
@testset "Using $method" for method = (DGKS(), ClassicalGramSchmidt(), ModifiedGramSchmidt())

# Projection size
h = zeros(T, m)
Expand All @@ -55,7 +55,7 @@ m = 3

# Orthogonalize w in-place
w = copy(w_original)
nrm = orthogonalize_and_normalize!(V_vec, w, h, ModifiedGramSchmidt)
nrm = orthogonalize_and_normalize!(V_vec, w, h, ModifiedGramSchmidt())

is_orthonormalized(w, h, nrm)
end
Expand Down

0 comments on commit ae01dfe

Please sign in to comment.