Skip to content

Commit

Permalink
RMM cleanup exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Feb 3, 2025
1 parent 71b6c22 commit 659f819
Showing 1 changed file with 48 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
Expand All @@ -57,47 +55,53 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k,
boolean allowOverlap) {

final int rr = m1.getNumRows();
final int rc = m2.getNumColumns();

if(m1.isEmpty() || m2.isEmpty()) {
LOG.trace("Empty right multiply");
if(ret == null)
ret = new MatrixBlock(rr, rc, 0);
else
ret.reset(rr, rc, 0);
return ret;
}
else {
if(m2 instanceof CompressedMatrixBlock)
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
try {
final int rr = m1.getNumRows();
final int rc = m2.getNumColumns();

if(betterIfDecompressed(m1)) {
// perform uncompressed multiplication.
return decompressingMatrixMult(m1, m2, k);
if(m1.isEmpty() || m2.isEmpty()) {
LOG.trace("Empty right multiply");
if(ret == null)
ret = new MatrixBlock(rr, rc, 0);
else
ret.reset(rr, rc, 0);
return ret;
}
else {
if(m2 instanceof CompressedMatrixBlock)
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);

if(!allowOverlap) {
LOG.trace("Overlapping output not allowed in call to Right MM");
return RMM(m1, m2, k);
}
if(betterIfDecompressed(m1)) {
// perform uncompressed multiplication.
return decompressingMatrixMult(m1, m2, k);
}

final CompressedMatrixBlock retC = RMMOverlapping(m1, m2, k);
if(!allowOverlap) {
LOG.trace("Overlapping output not allowed in call to Right MM");
return RMM(m1, m2, k);
}

if(retC.isEmpty())
return retC;
else {
if(retC.isOverlapping())
retC.setNonZeros((long) rr * rc); // set non zeros to fully dense in case of overlapping.
else
retC.recomputeNonZeros(k); // recompute if non overlapping compressed out.
return retC;
final CompressedMatrixBlock retC = RMMOverlapping(m1, m2, k);

if(retC.isEmpty())
return retC;
else {
if(retC.isOverlapping())
retC.setNonZeros((long) rr * rc); // set non zeros to fully dense in case of overlapping.
else
retC.recomputeNonZeros(k); // recompute if non overlapping compressed out.
return retC;
}
}
}
catch(Exception e) {
throw new RuntimeException("Failed Right MM", e);
}
}

private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k) {
ExecutorService pool = CommonThreadPool.get(k);
private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k)
throws Exception {
final ExecutorService pool = CommonThreadPool.get(k);
try {
final int rl = m1.getNumRows();
final int cr = m2.getNumColumns();
Expand All @@ -113,13 +117,13 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat
for(int i = 0; i < rl; i += blkI) {
final int startI = i;
final int endI = Math.min(i + blkI, rl);
for(int j = 0; j < cr; j += blkJ){
for(int j = 0; j < cr; j += blkJ) {
final int startJ = j;
final int endJ = Math.min(j + blkJ, cr);
tasks.add(pool.submit(() -> {
for(AColGroup g : groups)
g.rightDecompressingMult(m2, ret, startI, endI, rl, startJ, endJ);
return ret.recomputeNonZeros(startI, endI - 1, startJ, endJ-1);
return ret.recomputeNonZeros(startI, endI - 1, startJ, endJ - 1);
}));
}
}
Expand All @@ -131,9 +135,6 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat
ret.examSparsity();
return ret;
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
finally {
pool.shutdown();
}
Expand All @@ -149,7 +150,8 @@ private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
return false;
}

private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k)
throws Exception {

final int rl = m1.getNumRows();
final int cr = that.getNumColumns();
Expand Down Expand Up @@ -199,7 +201,7 @@ private static void addConstant(MatrixBlock constantRow, List<AColGroup> out) {
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
}

private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) {
private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) throws Exception {

// Timing t = new Timing();
// this version returns a decompressed result.
Expand Down Expand Up @@ -232,36 +234,25 @@ private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k
constV = null;
}


final List<AColGroup> retCg = new ArrayList<>(filteredGroups.size());
if(k == 1)
RMMSingle(filteredGroups, that, retCg);
else
RMMParallel(filteredGroups, that, retCg, k);


if(constV != null) {
MatrixBlock constVMB = new MatrixBlock(1, constV.length, constV);
MatrixBlock mmTemp = new MatrixBlock(1, cr, false);
LibMatrixMult.matrixMult(constVMB, that, mmTemp);
constV = mmTemp.isEmpty() ? null : mmTemp.getDenseBlockValues();
}

ret = asyncRet(f);
ret = f.get();
CLALibDecompress.decompressDense(ret, retCg, constV, 0, k, true);

return ret;
}

private static <T> T asyncRet(Future<T> in) {
try {
return in.get();
}
catch(Exception e) {
throw new DMLRuntimeException(e);
}
}

private static boolean RMMSingle(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg) {
boolean containsNull = false;
final IColIndex allCols = ColIndexFactory.create(that.getNumColumns());
Expand All @@ -275,7 +266,8 @@ private static boolean RMMSingle(List<AColGroup> filteredGroups, MatrixBlock tha
return containsNull;
}

private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg, int k) {
private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg, int k)
throws Exception {
final ExecutorService pool = CommonThreadPool.get(k);
boolean containsNull = false;
try {
Expand All @@ -291,9 +283,6 @@ private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock t
containsNull = true;
}
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
finally {
pool.shutdown();
}
Expand All @@ -314,13 +303,8 @@ protected RightMatrixMultTask(AColGroup colGroup, MatrixBlock b, IColIndex allCo
}

@Override
public AColGroup call() {
try {
return _colGroup.rightMultByMatrix(_b, _allCols, _k);
}
catch(Exception e) {
throw new DMLRuntimeException(e);
}
public AColGroup call() throws Exception {
return _colGroup.rightMultByMatrix(_b, _allCols, _k);
}
}
}

0 comments on commit 659f819

Please sign in to comment.