From 83a4fe503be32aa9891d9206cbd077a1aed9c9b3 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 5 Sep 2023 16:02:51 +0200 Subject: [PATCH 1/3] ldiv --- lib/mps/linalg.jl | 142 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index cf2c8f360..43e4da978 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T} commit!(cmdbuf) + wait_completed(cmdbuf) + + return B +end + + +function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + C = deepcopy(B) + LinearAlgebra.ldiv!(A, C) + return C +end + + +function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + At = similar(A.factors) + Bt = similar(B, (N, M)) + P = reshape((A.ipiv .- UInt32(1)), (1, M)) + X = similar(B, (N, M)) + + transpose!(At, A.factors) + transpose!(Bt, B) + + mps_a = MPSMatrix(At) + mps_b = MPSMatrix(Bt) + mps_p = MPSMatrix(P) + mps_x = MPSMatrix(X) + + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveLU(dev, false, M, N) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x) + end + + transpose!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A') + Br = similar(B, (M, M)) + X = similar(Br) + + transpose!(Br, B) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end + + +function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) return B end + + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end \ No newline at end of file From af07cd65fbbbab0a7211a29203c887ec1e21ebfa Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 5 Sep 2023 16:03:46 +0200 Subject: [PATCH 2/3] add test --- test/mps.jl | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/mps.jl b/test/mps.jl index 1e28807f5..fd3482eeb 100644 --- a/test/mps.jl +++ b/test/mps.jl @@ -175,4 +175,39 @@ end @test_throws SingularException lu(A) end +@testset "solves" begin + b = MtlVector(rand(Float32, 1024)) + B = MtlMatrix(rand(Float32, 1024, 1024)) + + A = MtlMatrix(rand(Float32, 1024, 512)) + x = lu(A) \ b + @test A * x ≈ b + X = lu(A) \ B + @test A * X ≈ B + + A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B +end + end \ No newline at end of file From a081497fea5c13fe086034a49e1d1b87cb14fda7 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 15 Apr 2024 13:51:54 +0200 Subject: [PATCH 3/3] add solvers --- lib/mps/MPS.jl | 2 ++ lib/mps/solve.jl | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 lib/mps/solve.jl diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index c38f4bdab..9bee3dd92 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -27,6 +27,8 @@ include("linalg.jl") # decompositions include("decomposition.jl") +include("solve.jl") + # matrix copy include("copy.jl") diff --git a/lib/mps/solve.jl b/lib/mps/solve.jl new file mode 100644 index 000000000..62a282164 --- /dev/null +++ b/lib/mps/solve.jl @@ -0,0 +1,77 @@ + +export MPSMatrixSolveTriangular + +@objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixUnaryKernel + +function MPSMatrixSolveTriangular(device, right, upper, unit, order, numberOfRightHandSides, alpha) + kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular} + obj = MPSMatrixSolveTriangular(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice} + right:right::Bool + upper:upper::Bool + transpose:transpose::Bool + unit:unit::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger + alpha:alpha::Float64]::id{MPSMatrixSolveTriangular} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, resultMatrix, pivotIndices, status) + @objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + resultMatrix:resultMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + status:status::id{MPSMatrix}]::Nothing +end + + +export MPSMatrixSolveLU + +@objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixUnaryKernel + +function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU} + obj = MPSMatrixSolveLU(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice} + transpose:transpose::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end + + + + +export MPSMatrixSolveCholesky + +@objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixUnaryKernel + +function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky} + obj = MPSMatrixSolveCholesky(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice} + upper:upper::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end