Skip to content

Commit

Permalink
fix get_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-gelbrecht committed Dec 1, 2024
1 parent 34d0b0b commit 4e81d95
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/LowerTriangularMatrices/LowerTriangularMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using DocStringExtensions
# GPU
import Adapt
import GPUArrays
import KernelAbstractions
import KernelAbstractions: get_backend

# NUMERICS
import LinearAlgebra: tril!
Expand Down
6 changes: 3 additions & 3 deletions src/LowerTriangularMatrices/lower_triangular_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,10 +630,10 @@ function Base.similar(
return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, size(L; as=Matrix))
end

function GPUArrays.get_backend(
::Type{LowerTriangularArray{T, N, ArrayType}}
function KernelAbstractions.get_backend(
a::LowerTriangularArray{T, N, ArrayType}
) where {T, N, ArrayType <: GPUArrays.AbstractGPUArray}
return GPUArrays.get_backend(ArrayType)
return KernelAbstractions.get_backend(a.data)
end

Adapt.adapt_structure(to, L::LowerTriangularArray) =
Expand Down
1 change: 1 addition & 0 deletions src/RingGrids/RingGrids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import LinearAlgebra
# GPU
import Adapt
import GPUArrays
import KernelAbstractions

# ABSTRACT GRIDS (2D) AND GRIDARRAYS (3D+)
export AbstractGridArray,
Expand Down
6 changes: 3 additions & 3 deletions src/RingGrids/general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,10 @@ AbstractGPUGridArrayStyle{2, ArrayType, Grid}(::Val{1}) where {ArrayType, Grid}
AbstractGPUGridArrayStyle{3, ArrayType, Grid}(::Val{4}) where {ArrayType, Grid} = AbstractGPUGridArrayStyle{4, ArrayType, Grid}()
AbstractGPUGridArrayStyle{3, ArrayType, Grid}(::Val{2}) where {ArrayType, Grid} = AbstractGPUGridArrayStyle{3, ArrayType, Grid}()

function GPUArrays.get_backend(
::Type{Grid}
function KernelAbstractions.get_backend(
g::Grid
) where {Grid <: AbstractGridArray{T, N, ArrayType}} where {T, N, ArrayType <: GPUArrays.AbstractGPUArray}
return GPUArrays.backend(ArrayType)
return KernelAbstractions.get_backend(g.data)
end

function Base.similar(
Expand Down

0 comments on commit 4e81d95

Please sign in to comment.