Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3828] Parallel Compressed Replace #2209

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
Expand Down Expand Up @@ -307,7 +308,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
* @return The cached decompressed matrix, if it does not exist return null
*/
public MatrixBlock getCachedDecompressed() {
if( allowCachingUncompressed && decompressedVersion != null) {
if(allowCachingUncompressed && decompressedVersion != null) {
final MatrixBlock mb = decompressedVersion.get();
if(mb != null) {
DMLCompressionStatistics.addDecompressCacheCount();
Expand Down Expand Up @@ -401,8 +402,8 @@ public long estimateCompressedSizeInMemory() {
long total = baseSizeInMemory();
// take into consideration duplicate dictionaries
Set<IDictionary> dicts = new HashSet<>();
for(AColGroup grp : _colGroups){
if(grp instanceof ADictBasedColGroup){
for(AColGroup grp : _colGroups) {
if(grp instanceof ADictBasedColGroup) {
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
if(dicts.contains(dg))
total -= dg.getInMemorySize();
Expand Down Expand Up @@ -576,8 +577,7 @@ public void append(MatrixValue v2, ArrayList<IndexedMatrixValue> outlist, int bl
}

@Override
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype,
int k) {
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) {

checkMMChain(ctype, v, w);
// multi-threaded MMChain of single uncompressed ColGroup
Expand Down Expand Up @@ -629,27 +629,8 @@ public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType
}

@Override
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
if(Double.isInfinite(pattern)) {
LOG.info("Ignoring replace infinite in compression since it does not contain this value");
return this;
}
else if(isOverlapping()) {
final String message = "replaceOperations " + pattern + " -> " + replacement;
return getUncompressed(message).replaceOperations(result, pattern, replacement);
}
else {

CompressedMatrixBlock ret = new CompressedMatrixBlock(getNumRows(), getNumColumns());
final List<AColGroup> prev = getColGroups();
final int colGroupsLength = prev.size();
final List<AColGroup> retList = new ArrayList<>(colGroupsLength);
for(int i = 0; i < colGroupsLength; i++)
retList.add(prev.get(i).replace(pattern, replacement));
ret.allocateColGroupList(retList);
ret.recomputeNonZeros();
return ret;
}
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
return CLALibReplace.replace(this, (MatrixBlock) result, pattern, replacement, k);
}

@Override
Expand Down Expand Up @@ -710,10 +691,10 @@ public boolean containsValue(double pattern) {
return false;
}
}

@Override
public boolean containsValue(double pattern, int k) {
//TODO parallel contains value
// TODO parallel contains value
return containsValue(pattern);
}

Expand Down Expand Up @@ -775,8 +756,8 @@ public boolean isEmptyBlock(boolean safe) {
return false;
else if(_colGroups == null || nonZeros == 0)
return true;
else{
if(nonZeros == -1){
else {
if(nonZeros == -1) {
// try to use column groups
for(AColGroup g : _colGroups)
if(!g.isEmpty())
Expand Down Expand Up @@ -1177,8 +1158,7 @@ public void appendRow(int r, SparseRow row, boolean deep) {
}

@Override
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset,
boolean deep) {
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, boolean deep) {
throw new DMLCompressionException("Can't append row to compressed Matrix");
}

Expand Down Expand Up @@ -1238,7 +1218,7 @@ public void sparseToDense(int k) {
}

@Override
public void denseToSparse(boolean allowCSR, int k){
public void denseToSparse(boolean allowCSR, int k) {
// do nothing
}

Expand Down Expand Up @@ -1327,13 +1307,13 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
}

@Override
@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
}

@Override
public MatrixBlock reshape(int rows,int cols, boolean byRow){
@Override
public MatrixBlock reshape(int rows, int cols, boolean byRow) {
return CLALibReshape.reshape(this, rows, cols, byRow);
}

Expand Down
108 changes: 108 additions & 0 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibReplace {
private static final Log LOG = LogFactory.getLog(CLALibReplace.class.getName());

private CLALibReplace(){
// private constructor
}

public static MatrixBlock replace(CompressedMatrixBlock in, MatrixBlock out, double pattern, double replacement,
int k) {
try {

if(Double.isInfinite(pattern)) {
LOG.info("Ignoring replace infinite in compression since it does not contain this value");
return in;
}
else if(in.isOverlapping()) {
final String message = "replaceOperations " + pattern + " -> " + replacement;
return in.getUncompressed(message).replaceOperations(out, pattern, replacement);
}
else
return replaceNormal(in, out, pattern, replacement, k);
}
catch(Exception e) {
throw new RuntimeException("Failed replace pattern: " + pattern + " replacement: " + replacement, e);
}
}

private static MatrixBlock replaceNormal(CompressedMatrixBlock in, MatrixBlock out, double pattern,
double replacement, int k) throws Exception {
CompressedMatrixBlock ret = new CompressedMatrixBlock(in.getNumRows(), in.getNumColumns());
final List<AColGroup> prev = in.getColGroups();
final int colGroupsLength = prev.size();
final List<AColGroup> retList = new ArrayList<>(colGroupsLength);

if(k <= 1)
replaceSingleThread(pattern, replacement, prev, colGroupsLength, retList);
else
replaceMultiThread(pattern, replacement, k, prev, colGroupsLength, retList);

ret.allocateColGroupList(retList);
if(replacement == 0) // have to recompute!
ret.recomputeNonZeros();
else if(pattern == 0) // always fully dense.
ret.setNonZeros(((long) in.getNumRows()) * in.getNumColumns());
else // same nonzeros as input
ret.setNonZeros(in.getNonZeros());
return ret;
}

private static void replaceMultiThread(double pattern, double replacement, int k, final List<AColGroup> prev,
final int colGroupsLength, final List<AColGroup> retList) throws InterruptedException, ExecutionException {
ExecutorService pool = CommonThreadPool.get(k);

try {
List<Future<AColGroup>> tasks = new ArrayList<>(colGroupsLength);
for(int i = 0; i < colGroupsLength; i++) {
final int j = i;
tasks.add(pool.submit(() -> prev.get(j).replace(pattern, replacement)));
}
for(int i = 0; i < colGroupsLength; i++) {
retList.add(tasks.get(i).get());
}
}
finally {
pool.shutdown();
}
}

private static void replaceSingleThread(double pattern, double replacement, final List<AColGroup> prev,
final int colGroupsLength, final List<AColGroup> retList) {
for(int i = 0; i < colGroupsLength; i++)
retList.add(prev.get(i).replace(pattern, replacement));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;

public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
Expand Down Expand Up @@ -276,7 +277,8 @@ else if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
MatrixBlock target = targetObj.acquireRead();
double pattern = Double.parseDouble(params.get("pattern"));
double replacement = Double.parseDouble(params.get("replacement"));
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement,
InfrastructureAnalyzer.getLocalParallelism());
if( ret == target ) //shallow copy (avoid bufferpool pollution)
ec.setVariable(output.getName(), targetObj);
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5157,9 +5157,13 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows,


@Override
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
public final MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
return replaceOperations(result, pattern, replacement, 1);
}

public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
MatrixBlock ret = checkType(result);
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement);
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement, k);
}

public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, boolean diag, boolean values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.sysds.test.component.compress;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand All @@ -38,6 +39,7 @@
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
Expand Down Expand Up @@ -397,9 +399,18 @@ public void manyRowsButNotQuite() {
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
}

@Test(expected = Exception.class)
public void cbindWithError() {
CLALibCBind.cbind(null, new MatrixBlock[] {null}, 0);
}

@Test(expected = Exception.class)
public void cbindWithError(){
CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
public void replaceWithError() {
CLALibReplace.replace(null, null, 0, 0, 10);
}

@Test
public void replaceInf() {
assertNull(CLALibReplace.replace(null, null, Double.POSITIVE_INFINITY, 0, 10));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -329,38 +329,6 @@ public void testContainsValue_not() {
}
}

@Test
public void testReplaceNotContainedValue() {
double v = min - 1;
if(v != 0)
testReplace(v);
}

@Test
public void testReplace() {
if(min != 0)
testReplace(min);
}

@Test
public void testReplaceZero() {
testReplace(0);
}

private void testReplace(double value) {
try {
if(!(cmb instanceof CompressedMatrixBlock) || rows * cols > 10000)
return;
ucRet = mb.replaceOperations(ucRet, value, 1425);
MatrixBlock ret2 = cmb.replaceOperations(new MatrixBlock(), value, 1425);
compareResultMatrices(ucRet, ret2, 1);
}
catch(Exception e) {
e.printStackTrace();
throw new DMLRuntimeException(e);
}
}

@Test
public void testCompressedMatrixConstruction() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ public void appendCBindAlignedSelfMultiple() {
}
catch(AssertionError e) {
e.printStackTrace();
fail("failed Cbind: " + cmb.toString() );
fail("failed Cbind: " + cmb.toString());
}
}

Expand Down Expand Up @@ -1299,4 +1299,42 @@ protected static CompressionSettingsBuilder csb() {
return new CompressionSettingsBuilder().setSeed(compressionSeed).setMinimumSampleSize(100);
}

@Test
public void testReplaceNotContainedValue() {
double v = min - 1;
if(v != 0)
testReplace(v, 132);
}

@Test
public void testReplace() {
if(min != 0)
testReplace(min, 323);
}

@Test
public void testReplaceWithZero() {
if(min != 0)
testReplace(min, 0);
}

@Test
public void testReplaceZero() {
testReplace(0, 3232);
}

private void testReplace(double value, double replacements) {
try {
if(!(cmb instanceof CompressedMatrixBlock) || rows * cols > 10000)
return;
ucRet = mb.replaceOperations(ucRet, value, replacements, _k);
MatrixBlock ret2 = cmb.replaceOperations(new MatrixBlock(), value, replacements, _k);
compareResultMatrices(ucRet, ret2, 1);
}
catch(Exception e) {
e.printStackTrace();
throw new DMLRuntimeException(e);
}
}

}
Loading