Skip to content

Commit

Permalink
[SYSTEMDS-3828] Parallel Compressed Replace
Browse files Browse the repository at this point in the history
This commit adds the parallel kernel for compressed
replace of values.

Closes apache#2209
  • Loading branch information
Baunsgaard authored and saminbassiri committed Feb 3, 2025
1 parent e979eab commit 5932fa4
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
Expand Down Expand Up @@ -307,7 +308,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
* @return The cached decompressed matrix, if it does not exist return null
*/
public MatrixBlock getCachedDecompressed() {
if( allowCachingUncompressed && decompressedVersion != null) {
if(allowCachingUncompressed && decompressedVersion != null) {
final MatrixBlock mb = decompressedVersion.get();
if(mb != null) {
DMLCompressionStatistics.addDecompressCacheCount();
Expand Down Expand Up @@ -401,8 +402,8 @@ public long estimateCompressedSizeInMemory() {
long total = baseSizeInMemory();
// take into consideration duplicate dictionaries
Set<IDictionary> dicts = new HashSet<>();
for(AColGroup grp : _colGroups){
if(grp instanceof ADictBasedColGroup){
for(AColGroup grp : _colGroups) {
if(grp instanceof ADictBasedColGroup) {
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
if(dicts.contains(dg))
total -= dg.getInMemorySize();
Expand Down Expand Up @@ -576,8 +577,7 @@ public void append(MatrixValue v2, ArrayList<IndexedMatrixValue> outlist, int bl
}

@Override
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype,
int k) {
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) {

checkMMChain(ctype, v, w);
// multi-threaded MMChain of single uncompressed ColGroup
Expand Down Expand Up @@ -629,27 +629,8 @@ public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType
}

@Override
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
if(Double.isInfinite(pattern)) {
LOG.info("Ignoring replace infinite in compression since it does not contain this value");
return this;
}
else if(isOverlapping()) {
final String message = "replaceOperations " + pattern + " -> " + replacement;
return getUncompressed(message).replaceOperations(result, pattern, replacement);
}
else {

CompressedMatrixBlock ret = new CompressedMatrixBlock(getNumRows(), getNumColumns());
final List<AColGroup> prev = getColGroups();
final int colGroupsLength = prev.size();
final List<AColGroup> retList = new ArrayList<>(colGroupsLength);
for(int i = 0; i < colGroupsLength; i++)
retList.add(prev.get(i).replace(pattern, replacement));
ret.allocateColGroupList(retList);
ret.recomputeNonZeros();
return ret;
}
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
return CLALibReplace.replace(this, (MatrixBlock) result, pattern, replacement, k);
}

@Override
Expand Down Expand Up @@ -710,10 +691,10 @@ public boolean containsValue(double pattern) {
return false;
}
}

@Override
public boolean containsValue(double pattern, int k) {
//TODO parallel contains value
// TODO parallel contains value
return containsValue(pattern);
}

Expand Down Expand Up @@ -775,8 +756,8 @@ public boolean isEmptyBlock(boolean safe) {
return false;
else if(_colGroups == null || nonZeros == 0)
return true;
else{
if(nonZeros == -1){
else {
if(nonZeros == -1) {
// try to use column groups
for(AColGroup g : _colGroups)
if(!g.isEmpty())
Expand Down Expand Up @@ -1177,8 +1158,7 @@ public void appendRow(int r, SparseRow row, boolean deep) {
}

@Override
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset,
boolean deep) {
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, boolean deep) {
throw new DMLCompressionException("Can't append row to compressed Matrix");
}

Expand Down Expand Up @@ -1238,7 +1218,7 @@ public void sparseToDense(int k) {
}

@Override
public void denseToSparse(boolean allowCSR, int k){
public void denseToSparse(boolean allowCSR, int k) {
// do nothing
}

Expand Down Expand Up @@ -1327,13 +1307,13 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
}

@Override
@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
}

@Override
public MatrixBlock reshape(int rows,int cols, boolean byRow){
@Override
public MatrixBlock reshape(int rows, int cols, boolean byRow) {
return CLALibReshape.reshape(this, rows, cols, byRow);
}

Expand Down
108 changes: 108 additions & 0 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Library for replace operations on compressed matrix blocks: every cell equal to a pattern value is substituted
 * with a replacement value. Non-overlapping inputs are processed column group by column group, optionally in
 * parallel.
 */
public class CLALibReplace {
	private static final Log LOG = LogFactory.getLog(CLALibReplace.class.getName());

	private CLALibReplace() {
		// private constructor
	}

	/**
	 * Replace all occurrences of {@code pattern} in the compressed input with {@code replacement}.
	 * 
	 * <ul>
	 * <li>An infinite pattern is ignored and the input instance itself is returned.</li>
	 * <li>An overlapping input is decompressed and the replace is delegated to the uncompressed block.</li>
	 * <li>Otherwise each column group is replaced individually and a new compressed block is returned.</li>
	 * </ul>
	 * 
	 * @param in          The compressed input matrix
	 * @param out         The output block, only forwarded to the uncompressed replace in the overlapping case
	 * @param pattern     The value to search for
	 * @param replacement The value to put in place of the pattern
	 * @param k           The parallelization degree
	 * @return The matrix with replaced values, or {@code in} itself if the pattern is infinite
	 * @throws RuntimeException if any part of the replace fails; the original exception is attached as the cause
	 */
	public static MatrixBlock replace(CompressedMatrixBlock in, MatrixBlock out, double pattern, double replacement,
		int k) {
		try {
			if(Double.isInfinite(pattern)) {
				LOG.info("Ignoring replace infinite in compression since it does not contain this value");
				return in;
			}
			else if(in.isOverlapping()) {
				final String message = "replaceOperations " + pattern + " -> " + replacement;
				// forward k, the three argument variant would silently fall back to a single thread
				return in.getUncompressed(message).replaceOperations(out, pattern, replacement, k);
			}
			else
				return replaceNormal(in, out, pattern, replacement, k);
		}
		catch(InterruptedException e) {
			// keep the interrupt visible to the caller of this thread before wrapping the exception
			Thread.currentThread().interrupt();
			throw new RuntimeException("Failed replace pattern: " + pattern + " replacement: " + replacement, e);
		}
		catch(Exception e) {
			throw new RuntimeException("Failed replace pattern: " + pattern + " replacement: " + replacement, e);
		}
	}

	/**
	 * Replace on a non-overlapping compressed matrix by replacing in each column group and collecting the resulting
	 * groups in a newly allocated compressed block of the same dimensions.
	 * 
	 * @param in          The compressed input matrix
	 * @param out         Not used by this method, a new compressed block is always allocated
	 * @param pattern     The value to search for
	 * @param replacement The value to put in place of the pattern
	 * @param k           The parallelization degree, values of one or less run single threaded
	 * @return A new compressed matrix block containing the replaced column groups
	 * @throws Exception if a column group replace or a parallel task fails
	 */
	private static MatrixBlock replaceNormal(CompressedMatrixBlock in, MatrixBlock out, double pattern,
		double replacement, int k) throws Exception {
		CompressedMatrixBlock ret = new CompressedMatrixBlock(in.getNumRows(), in.getNumColumns());
		final List<AColGroup> prev = in.getColGroups();
		final int colGroupsLength = prev.size();
		final List<AColGroup> retList = new ArrayList<>(colGroupsLength);

		if(k <= 1)
			replaceSingleThread(pattern, replacement, prev, colGroupsLength, retList);
		else
			replaceMultiThread(pattern, replacement, k, prev, colGroupsLength, retList);

		ret.allocateColGroupList(retList);
		if(replacement == 0) // have to recompute!
			ret.recomputeNonZeros();
		else if(pattern == 0) // always fully dense.
			ret.setNonZeros(((long) in.getNumRows()) * in.getNumColumns());
		else // same nonzeros as input
			ret.setNonZeros(in.getNonZeros());
		return ret;
	}

	/**
	 * Replace in all column groups using a thread pool, one task per column group. The results are appended to
	 * {@code retList} in the same order as the groups appear in {@code prev}.
	 * 
	 * @param pattern         The value to search for
	 * @param replacement     The value to put in place of the pattern
	 * @param k               The parallelization degree used to obtain the thread pool
	 * @param prev            The column groups of the input
	 * @param colGroupsLength The number of column groups to process
	 * @param retList         The list the replaced column groups are appended to
	 * @throws InterruptedException if the calling thread is interrupted while waiting for a task
	 * @throws ExecutionException   if a replace task failed
	 */
	private static void replaceMultiThread(double pattern, double replacement, int k, final List<AColGroup> prev,
		final int colGroupsLength, final List<AColGroup> retList) throws InterruptedException, ExecutionException {
		final ExecutorService pool = CommonThreadPool.get(k);

		try {
			final List<Future<AColGroup>> tasks = new ArrayList<>(colGroupsLength);
			for(int i = 0; i < colGroupsLength; i++) {
				final AColGroup g = prev.get(i);
				tasks.add(pool.submit(() -> g.replace(pattern, replacement)));
			}
			// collect in submission order so the output groups line up with the input groups
			for(Future<AColGroup> t : tasks)
				retList.add(t.get());
		}
		finally {
			pool.shutdown();
		}
	}

	/**
	 * Replace in all column groups sequentially on the calling thread, appending the results to {@code retList} in
	 * input order.
	 * 
	 * @param pattern         The value to search for
	 * @param replacement     The value to put in place of the pattern
	 * @param prev            The column groups of the input
	 * @param colGroupsLength The number of column groups to process
	 * @param retList         The list the replaced column groups are appended to
	 */
	private static void replaceSingleThread(double pattern, double replacement, final List<AColGroup> prev,
		final int colGroupsLength, final List<AColGroup> retList) {
		for(int i = 0; i < colGroupsLength; i++)
			retList.add(prev.get(i).replace(pattern, replacement));
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;

public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
Expand Down Expand Up @@ -276,7 +277,8 @@ else if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
MatrixBlock target = targetObj.acquireRead();
double pattern = Double.parseDouble(params.get("pattern"));
double replacement = Double.parseDouble(params.get("replacement"));
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement,
InfrastructureAnalyzer.getLocalParallelism());
if( ret == target ) //shallow copy (avoid bufferpool pollution)
ec.setVariable(output.getName(), targetObj);
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5157,9 +5157,13 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows,


@Override
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
public final MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
return replaceOperations(result, pattern, replacement, 1);
}

public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
MatrixBlock ret = checkType(result);
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement);
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement, k);
}

public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, boolean diag, boolean values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.sysds.test.component.compress;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand All @@ -38,6 +39,7 @@
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
Expand Down Expand Up @@ -397,9 +399,18 @@ public void manyRowsButNotQuite() {
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
}

@Test(expected = Exception.class)
public void cbindWithError() {
CLALibCBind.cbind(null, new MatrixBlock[] {null}, 0);
}

@Test(expected = Exception.class)
public void cbindWithError(){
CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
public void replaceWithError() {
CLALibReplace.replace(null, null, 0, 0, 10);
}

@Test
public void replaceInf() {
assertNull(CLALibReplace.replace(null, null, Double.POSITIVE_INFINITY, 0, 10));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -329,38 +329,6 @@ public void testContainsValue_not() {
}
}

@Test
public void testReplaceNotContainedValue() {
double v = min - 1;
if(v != 0)
testReplace(v);
}

@Test
public void testReplace() {
if(min != 0)
testReplace(min);
}

@Test
public void testReplaceZero() {
testReplace(0);
}

private void testReplace(double value) {
try {
if(!(cmb instanceof CompressedMatrixBlock) || rows * cols > 10000)
return;
ucRet = mb.replaceOperations(ucRet, value, 1425);
MatrixBlock ret2 = cmb.replaceOperations(new MatrixBlock(), value, 1425);
compareResultMatrices(ucRet, ret2, 1);
}
catch(Exception e) {
e.printStackTrace();
throw new DMLRuntimeException(e);
}
}

@Test
public void testCompressedMatrixConstruction() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ public void appendCBindAlignedSelfMultiple() {
}
catch(AssertionError e) {
e.printStackTrace();
fail("failed Cbind: " + cmb.toString() );
fail("failed Cbind: " + cmb.toString());
}
}

Expand Down Expand Up @@ -1299,4 +1299,42 @@ protected static CompressionSettingsBuilder csb() {
return new CompressionSettingsBuilder().setSeed(compressionSeed).setMinimumSampleSize(100);
}

/** Replace a value just below the matrix minimum, skipped when that value happens to be zero. */
@Test
public void testReplaceNotContainedValue() {
	final double belowMin = min - 1;
	if(belowMin == 0)
		return;
	testReplace(belowMin, 132);
}

/** Replace the minimum value with a non-zero constant, skipped when the minimum is zero. */
@Test
public void testReplace() {
	if(min == 0)
		return;
	testReplace(min, 323);
}

/** Replace the minimum value with zero, skipped when the minimum is already zero. */
@Test
public void testReplaceWithZero() {
	if(min == 0)
		return;
	testReplace(min, 0);
}

/** Replace zero with a non-zero constant. */
@Test
public void testReplaceZero() {
	final double pattern = 0;
	final double replacement = 3232;
	testReplace(pattern, replacement);
}

/**
 * Run replace on both the uncompressed and the compressed matrix with the configured parallelism and compare the
 * results. Nothing is done unless the compressed matrix is a CompressedMatrixBlock and the matrix has at most
 * 10000 cells.
 * 
 * @param value        The pattern to search for
 * @param replacements The value to put in place of the pattern
 */
private void testReplace(double value, double replacements) {
	try {
		final boolean isCompressed = cmb instanceof CompressedMatrixBlock;
		final boolean smallEnough = rows * cols <= 10000;
		if(isCompressed && smallEnough) {
			ucRet = mb.replaceOperations(ucRet, value, replacements, _k);
			final MatrixBlock compressedResult = cmb.replaceOperations(new MatrixBlock(), value, replacements, _k);
			compareResultMatrices(ucRet, compressedResult, 1);
		}
	}
	catch(Exception e) {
		e.printStackTrace();
		throw new DMLRuntimeException(e);
	}
}

}

0 comments on commit 5932fa4

Please sign in to comment.