Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Adapt.jl to change storage and element type #2212

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "Gregor
version = "0.9.15-DEV"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down Expand Up @@ -64,6 +65,7 @@ TrixiMakieExt = "Makie"
TrixiNLsolveExt = "NLsolve"

[compat]
Adapt = "3.7, 4.0"
Accessors = "0.1.12"
CodeTracking = "1.0.5"
ConstructionBase = "1.3"
Expand Down
1 change: 1 addition & 0 deletions src/Trixi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,

using DelimitedFiles: readdlm
using Downloads: Downloads
import Adapt
using CodeTracking: CodeTracking
using ConstructionBase: ConstructionBase
using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
Expand Down
22 changes: 22 additions & 0 deletions src/auxiliary/vector_of_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

# Wraps a Vector of Arrays, forwards `getindex` to the underlying Vector.
# Implements `Adapt.adapt_structure` to allow offloading to the GPU which is
# not possible for a plain Vector of Arrays.
struct VecOfArrays{T <: AbstractArray}
arrays::Vector{T}
end
Base.getindex(v::VecOfArrays, i::Int) = Base.getindex(v.arrays, i)
Base.IndexStyle(v::VecOfArrays) = Base.IndexStyle(v.arrays)
Base.size(v::VecOfArrays) = Base.size(v.arrays)
Base.length(v::VecOfArrays) = Base.length(v.arrays)
Base.eltype(v::VecOfArrays{T}) where {T} = T
function Adapt.adapt_structure(to, v::VecOfArrays)
return [Adapt.adapt(to, arr) for arr in v.arrays] |> VecOfArrays
end
end # @muladd
28 changes: 18 additions & 10 deletions src/semidiscretization/semidiscretization_hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,22 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,

function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
BoundaryConditions, SourceTerms, Solver,
Cache}(mesh::Mesh, equations::Equations,
Cache}(mesh::Mesh,
equations::Equations,
initial_condition::InitialCondition,
boundary_conditions::BoundaryConditions,
source_terms::SourceTerms,
solver::Solver,
cache::Cache) where {Mesh, Equations,
InitialCondition,
BoundaryConditions,
SourceTerms,
Solver,
Cache}
performance_counter = PerformanceCounter()

cache::Cache,
performance_counter::PerformanceCounter) where {
Mesh,
Equations,
InitialCondition,
BoundaryConditions,
SourceTerms,
Solver,
Cache
}
new(mesh, equations, initial_condition, boundary_conditions, source_terms,
solver, cache, performance_counter)
end
Expand Down Expand Up @@ -74,14 +77,17 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver

check_periodicity_mesh_boundary_conditions(mesh, _boundary_conditions)

performance_counter = PerformanceCounter()

SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
typeof(initial_condition),
typeof(_boundary_conditions), typeof(source_terms),
typeof(solver), typeof(cache)}(mesh, equations,
initial_condition,
_boundary_conditions,
source_terms, solver,
cache)
cache,
performance_counter)
end

# Create a new semidiscretization but change some parameters compared to the input.
Expand All @@ -103,6 +109,8 @@ function remake(semi::SemidiscretizationHyperbolic; uEltype = real(semi.solver),
source_terms, boundary_conditions, uEltype)
end

Adapt.@adapt_structure(SemidiscretizationHyperbolic)

# general fallback
function digest_boundary_conditions(boundary_conditions, mesh, solver, cache)
boundary_conditions
Expand Down
4 changes: 4 additions & 0 deletions src/solvers/dgsem/basis_lobatto_legendre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ In particular, not the nodes themselves are returned.

@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes

Adapt.@adapt_structure(LobattoLegendreBasis)

"""
integrate(f, u, basis::LobattoLegendreBasis)

Expand Down Expand Up @@ -209,6 +211,8 @@ end

@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1

Adapt.@adapt_structure(LobattoLegendreMortarL2)

# TODO: We can create EC mortars along the lines of the following implementation.
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end

Expand Down
Loading
Loading