Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
import cudaconvert
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Sep 27, 2017
1 parent de26433 commit 54e550b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/CuArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __precompile__()
module CuArrays

using CUDAdrv, CUDAnative
import CUDAnative: cudaconvert

export CuArray, CuVector, CuMatrix, cu

Expand Down
2 changes: 1 addition & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function Base.convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::CuArray{T,N}) whe
CuDeviceArray{T,N,AS.Global}(a.dims, DevicePtr{T,AS.Global}(ptr))
end

CUDAnative.cudaconvert(a::CuArray{T,N}) where {T,N} = convert(CuDeviceArray{T,N,AS.Global}, a)
cudaconvert(a::CuArray{T,N}) where {T,N} = convert(CuDeviceArray{T,N,AS.Global}, a)

# Utils

Expand Down
2 changes: 1 addition & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ end
# Hacky interop

Base.Broadcast._containertype(::Type{<:RowVector{<:Any,<:CuArray}}) = CuArray
CUDAnative.cudaconvert(x::RowVector{<:Any,<:CuArray}) = RowVector(cudaconvert(x.vec))
cudaconvert(x::RowVector{<:Any,<:CuArray}) = RowVector(cudaconvert(x.vec))

# Hack to work with cuda's arithmetic functions

Expand Down

0 comments on commit 54e550b

Please sign in to comment.