[SYSTEMDS-3824] Decompressing Transpose
Sebastian Baunsgaard <[email protected]> introduced a new CLALib for Reorg, specifically Transpose

e-strauss <[email protected]> applied minor changes:
- a manual rewrite in the builtin kmeans script to use argmin (reduced runtime by 18%)
- added a new decompressing transpose from SparseBlock dictionaries into a DenseBlock for ColGroupDDC (see the usage sketch below)
- fixed a bug in the sparsity evaluation of the decompressed transpose (switched nrow with ncol)
- minor bug fix regarding the cached decompression count
- fixed the ctable-with-seq fuse rewrite to also fuse ctable with given output dimensions (disabled for now: performance decrease, needs to be fixed first)
- fixed null handling in the fused seq ctable
- fixed tests that passed for the wrong reason

Co-authored-by: e-strauss <[email protected]>
Co-authored-by: Sebastian Baunsgaard <[email protected]>
Baunsgaard and e-strauss committed Feb 5, 2025
1 parent fd1ba7c commit dc3947a
Showing 12 changed files with 268 additions and 87 deletions.
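
A minimal usage sketch of the new decompressing transpose; the CompressedMatrixBlockFactory.compress(MatrixBlock) entry point (returning a Pair of block and statistics), MatrixBlock.randOperations, and the example class name are assumptions based on the existing codebase and are not part of this diff. With this commit, transpose on a compressed block routes through CLALibReorg instead of forcing a full decompression followed by a separate transpose.

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class CompressedTransposeSketch {
	public static void main(String[] args) {
		// generate a small dense test input
		MatrixBlock mb = MatrixBlock.randOperations(1000, 10, 1.0, 0, 5, "uniform", 7);
		// compress returns the (possibly) compressed block together with statistics
		Pair<MatrixBlock, CompressionStatistics> res = CompressedMatrixBlockFactory.compress(mb);
		MatrixBlock cmb = res.getLeft();
		// with SYSTEMDS-3824, transpose on a CompressedMatrixBlock dispatches to
		// CLALibReorg.reorg and decompresses directly into the transposed output
		MatrixBlock t = cmb.transpose(4);
		System.out.println(t.getNumRows() + " x " + t.getNumColumns());
	}
}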
1 change: 1 addition & 0 deletions scripts/builtin/kmeans.dml
@@ -148,6 +148,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
P = D <= minD;
# If some records belong to multiple centroids, share them equally
P = P / rowSums (P);
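# alternative hard assignment (manual argmin rewrite, see commit message):
# assign each record to its single closest centroid via a contingency table over rowIndexMin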
# P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids)
# Compute the column normalization factor for P
P_denom = colSums (P);
# Compute new centroids as weighted averages over the records
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left )

try
{
// TODO: the rewrite is currently not triggered if output dims are given --> getInput().size()>=3
// currently disabled due to performance decrease
if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) )
{
Hop input1 = getInput().get(0);
@@ -59,6 +59,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double

@Override
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
MatrixBlock tmp = decompress(op.getNumThreads());
long nz = tmp.setNonZeros(tmp.getNonZeros());
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
tmp.setNonZeros(nz);
return tmp;
}
else {
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}

return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
}

public boolean isOverlapping() {
@@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype

@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
}

@Override
@@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i

@Override
protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
throw new NotImplementedException();
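// For each row i in [rl, ru): look up its dictionary row via the DDC mapping (_data),
// then scatter that dictionary row's sparse non-zeros into the transposed dense output,
// where the output row is the group's global column index and the output column is i.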
for(int i = rl; i < ru; i++) {
final int vr = _data.getIndex(i);
if(sb.isEmpty(vr))
continue;
final int apos = sb.pos(vr);
final int alen = sb.size(vr) + apos;
final int[] aix = sb.indexes(vr);
final double[] aval = sb.values(vr);
for(int j = apos; j < alen; j++) {
final int rowOut = _colIndexes.get(aix[j]);
final double[] c = db.values(rowOut);
final int off = db.pos(rowOut);
c[off + i] += aval[j];
}
}
}

@Override
158 changes: 158 additions & 0 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibReorg {

protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());

public static boolean warned = false;

public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
int startColumn, int length) {
// SwapIndex is transpose
if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
MatrixBlock tmp = cmb.decompress(op.getNumThreads());
long nz = tmp.setNonZeros(tmp.getNonZeros());
if(tmp.isInSparseFormat())
return LibMatrixReorg.transpose(tmp); // edge case...
else
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
tmp.setNonZeros(nz);
return tmp;
}
else if(op.fn instanceof SwapIndex) {
MatrixBlock tmp = cmb.getCachedDecompressed();
if(tmp != null)
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
return transpose(cmb, ret, op.getNumThreads());
}
else {
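// non-transpose reorg ops fall back to full decompression; attach the op name to the
// decompression message only once to avoid flooding the log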
String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
warned = true;
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}
}

private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {

final long nnz = cmb.getNonZeros();
final int nRow = cmb.getNumRows();
final int nCol = cmb.getNumColumns();
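// the output is transposed, so evaluate the sparse format with nCol as rows and nRow as
// columns (the nrow/ncol swap mentioned in the commit message)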
final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol, nRow, nnz);
if(sparseOut)
return transposeSparse(cmb, ret, k, nRow, nCol, nnz);
else
return transposeDense(cmb, ret, k, nRow, nCol, nnz);
}

private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
long nnz) {
if(ret == null)
ret = new MatrixBlock(nCol, nRow, true, nnz);
else
ret.reset(nCol, nRow, true, nnz);

ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);

final int nColOut = ret.getNumColumns();

if(k > 1 && cmb.getColGroups().size() > 1)
decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k);
else
decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut);

return ret;
}

private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
long nnz) {
if(ret == null)
ret = new MatrixBlock(nCol, nRow, false, nnz);
else
ret.reset(nCol, nRow, false, nnz);

// TODO: parallelize
ret.allocateDenseBlock();

decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);
return ret;
}

private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
for(int i = 0; i < groups.size(); i++) {
AColGroup g = groups.get(i);
g.decompressToDenseBlockTransposed(ret, rl, ru);
}
}

private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
int nColOut) {
for(int i = 0; i < groups.size(); i++) {
AColGroup g = groups.get(i);
g.decompressToSparseBlockTransposed(ret, nColOut);
}
}

private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
int k) {
final ExecutorService pool = CommonThreadPool.get(k);
try {
final List<Future<?>> tasks = new ArrayList<>(groups.size());

for(int i = 0; i < groups.size(); i++) {
final AColGroup g = groups.get(i);
tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
}

for(Future<?> f : tasks)
f.get();

}
catch(Exception e) {
throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
}
finally {
pool.shutdown();
}
}
}
@@ -39,6 +39,7 @@
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

@@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int

try {
final int[] map = new int[seqHeight];
int maxCol = constructInitialMapping(map, A, k);
Pair<Integer, Integer> meta = constructInitialMapping(map, A, k, nColOut);
int maxCol = meta.getKey();
int nZeros = meta.getValue();
boolean containsNull = maxCol < 0;
maxCol = Math.abs(maxCol);

boolean cutOff = false;
if(nColOut == -1)
nColOut = maxCol;
else if(nColOut < maxCol)
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
cutOff = true;

final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
if(containsNull)
correctNulls(map, nColOut);
if(nColOut == 0) // edge case of empty zero dimension block.
return new MatrixBlock(seqHeight, 0, 0.0);
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
}
catch(Exception e) {
throw new RuntimeException("Failed table seq operator", e);
@@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) {
return nNulls;
}

private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
private static Pair<Integer,Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
if(A.isEmpty() || A.isInSparseFormat())
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
final MatrixBlock Ac;
@@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
try {

int blkz = Math.max((map.length / k), 1000);
List<Future<Integer>> tasks = new ArrayList<>();
List<Future<Pair<Integer,Integer>>> tasks = new ArrayList<>();
for(int i = 0; i < map.length; i += blkz) {
final int start = i;
final int end = Math.min(i + blkz, map.length);
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end)));
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol)));
}

int maxCol = 0;
for(Future<Integer> f : tasks) {
int tmp = f.get();
if(Math.abs(tmp) > Math.abs(maxCol))
maxCol = tmp;
int zeros = 0;
for(Future<Pair<Integer,Integer>> f : tasks) {
int tmpMaxCol = f.get().getKey();
int tmpZeros = f.get().getValue();
if(Math.abs(tmpMaxCol) > Math.abs(maxCol))
maxCol = tmpMaxCol;
zeros += tmpZeros;
}
return maxCol;
return new Pair<Integer,Integer>(maxCol, zeros);
}
catch(Exception e) {
throw new DMLRuntimeException(e);
Expand All @@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {

}

private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {

int maxCol = 0;
boolean containsNull = false;

int zeros = 0;
final double[] aVals = A.getDenseBlockValues();

for(int i = start; i < end; i++) {
final double v2 = aVals[i];
if(Double.isNaN(v2)) {
map[i] = -1; // assign temporarily to -1
containsNull = true;
}
else {
// safe casts to long for consistent behavior with indexing
int col = UtilFunctions.toInt(v2);
if(col <= 0)
throw new DMLRuntimeException(
final int colUnsafe = UtilFunctions.toInt(v2);
if(!Double.isNaN(v2) && colUnsafe < 0)
throw new DMLRuntimeException(
"Erroneous input while computing the contingency table (value <= zero): " + v2);
// Boolean to int conversion to avoid branch
final int invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol) ? 1 : 0;
// if invalid -> maxOutCol else -> colUnsafe - 1
final int colSafe = maxOutCol*invalid + (colUnsafe - 1)*(1 - invalid);
zeros += invalid;
maxCol = Math.max(colUnsafe, maxCol);
map[i] = colSafe;
}

map[i] = col - 1;
// maintain max seen col
maxCol = Math.max(col, maxCol);
}
if (maxOutCol == -1 && zeros > 0){
maxCol *= -1;
}

return containsNull ? maxCol * -1 : maxCol;
return new Pair<Integer, Integer>(maxCol, zeros);
}

public static boolean compressedTableSeq() {
@@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) {
else if(ec.isMatrixObject(input1.getName()))
processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root);
else {
throw new NotImplementedException("Not supported other types of input for compression than frame and matrix");
LOG.warn("Compression on Scalar should not happen");
ScalarObject scalar = ec.getScalarInput(input1);
ec.setScalarOutput(output.getName(), scalar);
}
}
