Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MINOR] cleanups and optimizations to CLA MM primitives #2210

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,30 @@ public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, i

/**
 * Right multiplication that writes its result into {@code ret}, dispatching on the dictionary type.
 * <p>
 * When {@code _dict} is an {@code IdentityDictionary} the identity-specific kernel is used; every other
 * dictionary goes through the default kernel.
 *
 * @param right the right-hand side matrix
 * @param ret   the output matrix
 * @param rl    first row to process (inclusive)
 * @param ru    last row to process (exclusive)
 * @param nRows not used by this implementation
 * @param crl   lower column bound forwarded to the kernel
 * @param cru   upper column bound forwarded to the kernel
 */
@Override
public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) {
	final boolean identity = _dict instanceof IdentityDictionary;
	if(!identity) {
		defaultRightDecompressingMult(right, ret, rl, ru, crl, cru);
		return;
	}
	identityRightDecompressingMult(right, ret, rl, ru, crl, cru);
}

/**
 * Kernel of the right decompressing multiplication for the case where the dictionary is an identity
 * dictionary.
 * <p>
 * For every row in {@code [rl, ru)} the dictionary index of that row is read from {@code _data}, translated
 * through {@code _colIndexes}, and handed to {@code vectMM} together with a constant factor of 1.
 *
 * @param right the right-hand side matrix; its dense values are read directly
 * @param ret   the output matrix; its dense values are passed to {@code vectMM} as the target array
 * @param rl    first row to process (inclusive)
 * @param ru    last row to process (exclusive)
 * @param crl   lower column bound
 * @param cru   upper column bound
 */
private void identityRightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int crl, int cru) {
	final double[] rightValues = right.getDenseBlockValues();
	final double[] outValues = ret.getDenseBlockValues();
	final int nColRight = right.getNumColumns();
	final int vLen = 8;
	// cru minus the remainder of the column range that does not fill a whole group of vLen columns
	final int end = cru - ((cru - crl) % vLen);
	// the factor handed to vectMM is fixed to 1 on the identity path
	final double one = 1;
	for(int row = rl; row < ru; row++) {
		final int dictIdx = _data.getIndex(row);
		// NOTE(review): dictIdx is used directly as a position in _colIndexes; if the identity dictionary can
		// encode an additional "empty" tuple this lookup could run out of range - confirm against callers.
		final int rightRow = _colIndexes.get(dictIdx);
		// output offset uses the column count of right as the row stride
		final int offOut = row * nColRight + crl;
		vectMM(one, rightValues, outValues, end, nColRight, crl, cru, offOut, rightRow, vLen);
	}
}

private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int crl, int cru) {
final double[] a = _dict.getValues();
final double[] b = right.getDenseBlockValues();
final double[] c = ret.getDenseBlockValues();
Expand Down Expand Up @@ -930,8 +954,6 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret
}
}



private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, double[] values2, int pos2, int cl,
int cu) {
IdentityDictionary a = (IdentityDictionary) _dict;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.utils.stats.Timing;

/**
* Support compressed MM chain operation to fuse the following cases :
Expand All @@ -53,6 +54,9 @@
public final class CLALibMMChain {
/** Logger of this class; debug level enables the timing breakdown written by mmChain. */
static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());

/**
 * Reusable per-thread double array used as the backing storage when an intermediate is temporarily
 * decompressed.
 * <p>
 * The holder starts out as null and is created lazily by decompressIntermediate on first use. Each thread
 * keeps its own array, which is replaced only when a larger one is needed.
 * <p>
 * NOTE(review): the lazy assignment of this static field is not synchronized in the visible code, and nothing
 * visible here ever releases the cached arrays - confirm this is acceptable for the callers.
 */
private static ThreadLocal<double[]> cacheIntermediate = null;

/** Private constructor: prevents instantiation from outside this class. */
private CLALibMMChain() {
// private constructor
}
Expand Down Expand Up @@ -87,20 +91,31 @@ private CLALibMMChain() {
public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out,
ChainType ctype, int k) {
// Overall flow: pre-filter the column groups of x, right-multiply x by v, multiply the intermediate with w
// for the XtwXv chain type, decompress the intermediate when it may not stay compressed, left-multiply
// into out, and finally transpose out when it has more than one column.

// Timer feeding the debug breakdown that is logged at the end of this method.
// NOTE(review): every t.stop() result is stored as-is; whether stop() reports the time since construction
// or a lap time is not visible here - confirm before reading the logged values as per-phase durations.
Timing t = new Timing();
// Empty input: delegate to returnEmpty without doing any multiplication.
if(x.isEmpty())
return returnEmpty(x, out);

// Morph the columns to efficient types for the operation.
x = filterColGroups(x);
double preFilterTime = t.stop();

// Allow overlapping intermediate if the intermediate is guaranteed not to be overlapping.
final boolean allowOverlap = x.getColGroups().size() == 1 && isOverlappingAllowed();

// Right hand side multiplication
// NOTE(review): the next two lines both declare tmp. This text comes from a diff view, so the first one is
// presumably the removed pre-change line (passing allowOverlap) and the second the added line (passing
// true) - confirm and keep only one of them.
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, allowOverlap);
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, true);

double rmmTime = t.stop();

// NOTE(review): the if-header below appears twice (once without and once with braces); presumably the
// same pre-/post-change diff residue as above - confirm and keep only one of them.
if(ctype == ChainType.XtwXv) // Multiply intermediate with vector if needed
if(ctype == ChainType.XtwXv) { // Multiply intermediate with vector if needed
tmp = binaryMultW(tmp, w, k);
}

// The right multiplication above is called with true as its last argument; when overlap is not allowed
// and the intermediate is still compressed, it is decompressed here before the left multiplication.
if(!allowOverlap && tmp instanceof CompressedMatrixBlock) {
tmp = decompressIntermediate((CompressedMatrixBlock) tmp, k);
}

double decompressTime = t.stop();

if(tmp instanceof CompressedMatrixBlock)
// Compressed Compressed Matrix Multiplication
// NOTE(review): the diff view elides lines at this point (see the hunk marker on the next line); the
// statement executed for the compressed branch is not visible here.
Expand All @@ -109,12 +124,50 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
// LMM with Compressed - uncompressed multiplication.
CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, out, k);

double lmmTime = t.stop();
if(out.getNumColumns() != 1) // transpose the output to make it a row output if needed
out = LibMatrixReorg.transposeInPlace(out, k);

// Timing breakdown, assembled only when debug logging is enabled. The last entry calls t.stop() while the
// message is being built.
if(LOG.isDebugEnabled()) {
StringBuilder sb = new StringBuilder("\n");
sb.append("\nPreFilter Time : " + preFilterTime);
sb.append("\nChain RMM : " + rmmTime);
sb.append("\nChain RMM Decompress: " + decompressTime);
sb.append("\nChain LMM : " + lmmTime);
sb.append("\nChain Transpose : " + t.stop());
LOG.debug(sb.toString());
}

return out;
}

/**
 * Decompress a compressed intermediate into a matrix block that is backed by a per-thread reusable array.
 * <p>
 * The backing array is taken from {@code cacheIntermediate} when the calling thread already holds one that
 * is large enough; otherwise a new array of exactly {@code rows * cols} cells is allocated and cached for
 * later calls on the same thread.
 *
 * @param tmp the compressed intermediate to decompress
 * @param k   forwarded unchanged to {@code CLALibDecompress.decompressTo}
 * @return a new matrix block with the dimensions of {@code tmp}, wrapping the (possibly reused) array
 * @throws ArithmeticException if {@code rows * cols} does not fit into an {@code int}
 */
private static MatrixBlock decompressIntermediate(CompressedMatrixBlock tmp, int k) {
	final int rows = tmp.getNumRows();
	final int cols = tmp.getNumColumns();
	// fail loudly instead of silently wrapping around for very large intermediates
	final int nCells = Math.multiplyExact(rows, cols);

	// read the static field once and lazily create the thread-local holder on first use
	ThreadLocal<double[]> cache = cacheIntermediate;
	if(cache == null) {
		cache = new ThreadLocal<>();
		cacheIntermediate = cache;
	}

	// reuse the array of this thread when it is large enough, otherwise replace it
	double[] tmpArr = cache.get();
	if(tmpArr == null || tmpArr.length < nCells) {
		tmpArr = new double[nCells];
		cache.set(tmpArr);
	}

	// NOTE(review): a reused array can be longer than rows * cols and can still hold values from an earlier
	// call; this relies on decompressTo (trailing flags false, true) overwriting or resetting the target -
	// confirm against CLALibDecompress.
	final MatrixBlock tmpV = new MatrixBlock(rows, cols, tmpArr);
	CLALibDecompress.decompressTo(tmp, tmpV, 0, 0, k, false, true);
	return tmpV;
}

/**
 * Look up the {@code DMLConfig.COMPRESSED_OVERLAPPING} flag in the DML configuration returned by the
 * configuration manager.
 *
 * @return the boolean value configured for {@code DMLConfig.COMPRESSED_OVERLAPPING}
 */
private static boolean isOverlappingAllowed() {
	final DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
	return dmlConfig.getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
}
Expand Down Expand Up @@ -146,6 +199,8 @@ private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
final List<AColGroup> groups = x.getColGroups();
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
if(shouldFilter) {
if(CLALibUtils.alreadyPreFiltered(groups, x.getNumColumns()))
return x;
final int nCol = x.getNumColumns();
final double[] constV = new double[nCol];
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
Expand Down
Loading
Loading