Skip to content

Commit

Permalink
implement adapt
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwindiff committed Mar 22, 2023
1 parent ee7cfc1 commit cc7b801
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ import .StaticArrays: MArray
KernelAbstractions.get_backend(::Metal.MtlArray) = MetalBackend()
KernelAbstractions.synchronize(::MetalBackend) = Metal.synchronize()

# TODO: why are these not needed in https://github.com/JuliaGPU/CUDA.jl/pull/1772 ?
Adapt.adapt_storage(::MetalBackend, a::Array) = Adapt.adapt(Metal.MtlArray, a)
Adapt.adapt_storage(::MetalBackend, a::Metal.MtlArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::Metal.MtlArray) = convert(Array, a)

function KernelAbstractions.copyto!(::MetalBackend, A::Metal.MtlArray{T}, B::Metal.MtlArray{T}) where T
if Metal.device(dest) == Metal.device(src)
GC.@preserve A B unsafe_copyto!(Metal.device(A), pointer(A), pointer(B), length(A); async=true)
Expand Down

0 comments on commit cc7b801

Please sign in to comment.