Skip to content

Commit

Permalink
Fix MPSNDArrayDescriptor wrapper (#502)
Browse files Browse the repository at this point in the history
Don't reverse dimensions automatically
  • Loading branch information
christiangnrd authored Dec 19, 2024
1 parent b9610e3 commit 84447c4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 14 deletions.
3 changes: 3 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import GPUArrays

const MtlFloat = Union{Float32, Float16}

const MPSShape = NSArray#{NSNumber}
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

include("size.jl")
Expand Down
2 changes: 1 addition & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end
"""
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
transpose_left=false, transpose_right=false)
A `MPSMatrixMultiplication` kernel thay computes:
A `MPSMatrixMultiplication` kernel that computes:
`c = alpha * op(a) * beta * op(b) + beta * C`
This function should not typically be used. Rather, use the normal `LinearAlgebra` interface
Expand Down
37 changes: 28 additions & 9 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes
end

function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
revshape = collect(reverse(shape))
obj = GC.@preserve revshape begin
shapeptr = pointer(revshape)
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
obj = GC.@preserve shape begin
shapeptr = pointer(shape)
MPSNDArrayDescriptor(dataType, length(shape), shapeptr)
end
return obj
end
Expand Down Expand Up @@ -75,6 +74,11 @@ else
end
end

function Base.size(ndarr::MPSNDArray)
ndims = Int(ndarr.numberOfDimensions)
Tuple([Int(lengthOfDimension(ndarr,i)) for i in 0:ndims-1])
end

@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray

@objcproperties MPSTemporaryNDArray begin
Expand Down Expand Up @@ -130,20 +134,23 @@ end

function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
@assert arrsize[1]*sizeof(T) % 16 == 0 "First dimension of arr must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
arrsize = size(ndarr)
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
arr = MtlArray{T,length(arrsize),storage}(undef, (arrsize)...)
return exportToMtlArray!(arr, ndarr; async)
end

function exportToMtlArray!(arr::MtlArray{T}, ndarr::MPSNDArray; async=false) where T
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, arr.offset)
end

async || wait_completed(cmdBuf)
Expand All @@ -157,6 +164,12 @@ exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffe
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset) =
@objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
toBuffer:toBuffer::id{MTLBuffer}
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# rowStrides in Bytes
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
Expand All @@ -165,6 +178,12 @@ importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBu
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset) =
@objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
fromBuffer:fromBuffer::id{MTLBuffer}
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# TODO
# exportDataWithCommandBuffer(toImages, offset)
Expand Down
9 changes: 5 additions & 4 deletions test/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
T = Float32
DT = convert(MPSDataType, T)

desc1 = MPSNDArrayDescriptor(T, 5,4,3,2,1)
desc1 = MPSNDArrayDescriptor(T,1,2,3,4,5)
@test desc1 isa MPSNDArrayDescriptor
@test desc1.dataType == DT
@test desc1.preferPackedRows == false
Expand All @@ -19,7 +19,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
@test lengthOfDimension(desc1,4) == 4
@test lengthOfDimension(desc1,3) == 5

desc2 = MPSNDArrayDescriptor(T, (4,3,2,1))
desc2 = MPSNDArrayDescriptor(T, (1,2,3,4))
@test desc2 isa MPSNDArrayDescriptor
@test desc2.dataType == DT
@test desc2.numberOfDimensions == 4
Expand Down Expand Up @@ -51,6 +51,7 @@ using .MPS: MPSNDArray
@test ndarr1.label == "Test1"
@test ndarr1.numberOfDimensions == 5
@test ndarr1.parent === nothing
@test size(ndarr1) == (5,4,3,2,1)

ndarr2 = MPSNDArray(dev, 4)
@test ndarr2 isa MPSNDArray
Expand All @@ -63,9 +64,9 @@ using .MPS: MPSNDArray
@test ndarr2.parent === nothing

arr3 = MtlArray(ones(Float16, 2,3,4))
@test_throws "Final dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)
@test_throws "First dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)

arr4 = MtlArray(ones(Float16, 2,3,8))
arr4 = MtlArray(ones(Float16, 8,3,2))

@static if Metal.macos_version() >= v"15"
@test ndarr1.descriptor isa MPSNDArrayDescriptor
Expand Down

0 comments on commit 84447c4

Please sign in to comment.