From 88fe2b0eb4eb1fd342f37c2741629056155c56a2 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 30 Nov 2023 17:49:43 +0100 Subject: [PATCH] [SYSTEMDS-3653] Ultra Sparse Right MM Optimization Right side Ultra sparse optimizations goring from 8.525 to 4.575 on 100 repetitions of 100k by 1000 dense %*% 1000 by 1000 with 30 non zeros. Closes #1952 --- .../runtime/matrix/data/LibMatrixMult.java | 47 +++++++++++++++++-- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 41dc7f22642..e956f619060 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -49,6 +49,7 @@ import org.apache.sysds.runtime.data.SparseBlockCSR; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.functionobjects.SwapIndex; @@ -194,7 +195,7 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm)); boolean sparse = !fixedRet && !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2); - + // allocate output if(ret == null) ret = new MatrixBlock(m1.rlen, m2.clen, ultraSparse | sparse); @@ -1718,7 +1719,6 @@ else if( leftUS || m1Perm ) matrixMultUltraSparseLeft(m1, m2, ret, rl, ru); else matrixMultUltraSparseRight(m1, m2, ret, rl, ru); - //no need to recompute nonzeros because maintained internally } private static void matrixMultUltraSparseSelf(MatrixBlock m1, MatrixBlock ret, int rl, int ru) { @@ -1926,10 +1926,14 @@ private static void matrixMultUltraSparseSparseSparseLeftRowGeneric(int i, int a private static void matrixMultUltraSparseRight(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) { - if(!ret.isInSparseFormat() && ret.getDenseBlock().isContiguous()) + if(ret.isInSparseFormat()){ + if(m1.isInSparseFormat()) + matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru); + else + matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru); + } + else if(ret.getDenseBlock().isContiguous()) matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru); - else if(m1.isInSparseFormat() && ret.isInSparseFormat()) - matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru); else matrixMultUltraSparseRightGeneric(m1, m2, ret, rl, ru); } @@ -1990,6 +1994,39 @@ private static void matrixMultUltraSparseRightSparseMCSRLeftSparseOut(MatrixBloc } } + private static void matrixMultUltraSparseRightDenseLeftSparseOut(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) { + final int cd = m1.clen; + final DenseBlock a = m1.denseBlock; + final SparseBlock b = m2.sparseBlock; + final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock; + + for(int k = 0; k < cd; k++){ + if(b.isEmpty(k)) + continue; // skip emptry rows right side. + final int bpos = b.pos(k); + final int blen = b.size(k); + final int[] bixs = b.indexes(k); + final double[] bvals = b.values(k); + for(int i = rl; i < ru; i++) + mmDenseMatrixSparseRow(bpos, blen, bixs, bvals, k, i, a, c); + } + } + + private static void mmDenseMatrixSparseRow(int bpos, int blen, int[] bixs, double[] bvals, int k, int i, + DenseBlock a, SparseBlockMCSR c) { + final double[] aval = a.values(i); + final int apos = a.pos(i); + if(!c.isAllocated(i)) + c.allocate(i, Math.max(blen, 2)); + final SparseRowVector srv = (SparseRowVector) c.get(i); // guaranteed + for(int j = bpos; j < bpos + blen; j++) { // right side columns + final int bix = bixs[j]; + final double bval = bvals[j]; + srv.add(bix, bval * aval[apos + k]); + } + + } + private static void matrixMultUltraSparseRightGeneric(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) { final int cd = m1.clen;