Skip to content

Commit

Permalink
[MINOR] Compressed Dictionary Tests
Browse files Browse the repository at this point in the history
This commit adds some (apparently) much needed tests primarily focusing
on the Dictionary abstractions used for most of the dictionaries.

These changes resulted in going from 100 lines of tests, to 3.3k
changes to many files in the compression framework.

Closes #2183
  • Loading branch information
Baunsgaard committed Jan 28, 2025
1 parent 6b37c85 commit b96cf25
Show file tree
Hide file tree
Showing 41 changed files with 3,353 additions and 1,832 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import java.io.ObjectOutput;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

Expand All @@ -42,9 +44,11 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.compress.lib.CLALibCMOps;
Expand Down Expand Up @@ -99,14 +103,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName());
private static final long serialVersionUID = 73193720143154058L;

/**
* Debugging flag for Compressed Matrices
*/
/** Debugging flag for Compressed Matrices */
public static boolean debug = false;

/**
* Column groups
*/
/** Disallow caching of uncompressed Block */
public static boolean allowCachingUncompressed = true;

/** Column groups */
protected transient List<AColGroup> _colGroups;

/**
Expand All @@ -119,6 +122,9 @@ public class CompressedMatrixBlock extends MatrixBlock {
*/
protected transient SoftReference<MatrixBlock> decompressedVersion;

/** Cached Memory size */
protected transient long cachedMemorySize = -1;

public CompressedMatrixBlock() {
super(true);
sparse = false;
Expand Down Expand Up @@ -169,7 +175,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) {
clen = uncompressedMatrixBlock.getNumColumns();
sparse = false;
nonZeros = uncompressedMatrixBlock.getNonZeros();
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
if(!(uncompressedMatrixBlock instanceof CompressedMatrixBlock)) {
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
}
}

/**
Expand All @@ -189,6 +197,7 @@ public CompressedMatrixBlock(int rl, int cl, long nnz, boolean overlapping, List
this.nonZeros = nnz;
this.overlappingColGroups = overlapping;
this._colGroups = groups;
getInMemorySize(); // cache memory size
}

@Override
Expand All @@ -204,6 +213,7 @@ public void reset(int rl, int cl, boolean sp, long estnnz, double val) {
* @param cg The column group to use after.
*/
public void allocateColGroup(AColGroup cg) {
cachedMemorySize = -1;
_colGroups = new ArrayList<>(1);
_colGroups.add(cg);
}
Expand Down Expand Up @@ -270,6 +280,12 @@ public synchronized MatrixBlock decompress(int k) {

ret = CLALibDecompress.decompress(this, k);

if(ret.getNonZeros() <= 0) {
LOG.warn("Decompress incorrectly set nnz to 0 or -1");
ret.recomputeNonZeros(k);
}
ret.examSparsity(k);

// Set soft reference to the decompressed version
decompressedVersion = new SoftReference<>(ret);

Expand All @@ -290,7 +306,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
* @return The cached decompressed matrix, if it does not exist return null
*/
public MatrixBlock getCachedDecompressed() {
if(decompressedVersion != null) {
if( allowCachingUncompressed && decompressedVersion != null) {
final MatrixBlock mb = decompressedVersion.get();
if(mb != null) {
DMLCompressionStatistics.addDecompressCacheCount();
Expand All @@ -302,6 +318,7 @@ public MatrixBlock getCachedDecompressed() {
}

public CompressedMatrixBlock squash(int k) {
cachedMemorySize = -1;
return CLALibSquash.squash(this, k);
}

Expand Down Expand Up @@ -377,12 +394,27 @@ public long estimateSizeInMemory() {
* @return an upper bound on the memory used to store this compressed block considering class overhead.
*/
public long estimateCompressedSizeInMemory() {
long total = baseSizeInMemory();

for(AColGroup grp : _colGroups)
total += grp.estimateInMemorySize();
if(cachedMemorySize <= -1L) {

long total = baseSizeInMemory();
// take into consideration duplicate dictionaries
Set<IDictionary> dicts = new HashSet<>();
for(AColGroup grp : _colGroups){
if(grp instanceof ADictBasedColGroup){
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
if(dicts.contains(dg))
total -= dg.getInMemorySize();
dicts.add(dg);
}
total += grp.estimateInMemorySize();
}
cachedMemorySize = total;
return total;

return total;
}
else
return cachedMemorySize;
}

public static long baseSizeInMemory() {
Expand All @@ -392,6 +424,7 @@ public static long baseSizeInMemory() {
total += 8; // Col Group Ref
total += 8; // v reference
total += 8; // soft reference to decompressed version
total += 8; // long cached memory size
total += 1 + 7; // Booleans plus padding

total += 40; // Col Group Array List
Expand Down Expand Up @@ -431,6 +464,7 @@ public long estimateSizeOnDisk() {

@Override
public void readFields(DataInput in) throws IOException {
cachedMemorySize = -1;
// deserialize compressed block
rlen = in.readInt();
clen = in.readInt();
Expand Down Expand Up @@ -736,8 +770,22 @@ public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows,

@Override
public boolean isEmptyBlock(boolean safe) {
final long nonZeros = getNonZeros();
return _colGroups == null || nonZeros == 0 || (nonZeros == -1 && recomputeNonZeros() == 0);
if(nonZeros > 1)
return false;
else if(_colGroups == null || nonZeros == 0)
return true;
else{
if(nonZeros == -1){
// try to use column groups
for(AColGroup g : _colGroups)
if(!g.isEmpty())
return false;
// Otherwise recompute non zeros.
recomputeNonZeros();
}

return getNonZeros() == 0;
}
}

@Override
Expand Down Expand Up @@ -1045,6 +1093,7 @@ public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareD
}

private void copyCompressedMatrix(CompressedMatrixBlock that) {
cachedMemorySize = -1;
this.rlen = that.getNumRows();
this.clen = that.getNumColumns();
this.sparseBlock = null;
Expand All @@ -1059,7 +1108,7 @@ private void copyCompressedMatrix(CompressedMatrixBlock that) {
}

public SoftReference<MatrixBlock> getSoftReferenceToDecompressed() {
return decompressedVersion;
return allowCachingUncompressed ? decompressedVersion : null;
}

public void clearSoftReferenceToDecompressed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Set;

import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
Expand Down Expand Up @@ -63,8 +64,8 @@ public IDictionary getDictionary() {

@Override
public final void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) {
if(_dict instanceof IdentityDictionary) {
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
if(_dict instanceof AIdentityDictionary) {
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
final MatrixBlock mb = md.getMatrixBlock();
// The dictionary is never empty.
if(mb.isInSparseFormat())
Expand All @@ -87,8 +88,8 @@ else if(_dict instanceof MatrixBlockDictionary) {

@Override
public void decompressToSparseBlockTransposed(SparseBlockMCSR sb, int nColOut) {
if(_dict instanceof IdentityDictionary) {
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
if(_dict instanceof AIdentityDictionary) {
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
final MatrixBlock mb = md.getMatrixBlock();
// The dictionary is never empty.
if(mb.isInSparseFormat())
Expand Down Expand Up @@ -123,8 +124,8 @@ protected abstract void decompressToSparseBlockTransposedDenseDictionary(SparseB

@Override
public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
if(_dict instanceof IdentityDictionary) {
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
if(_dict instanceof AIdentityDictionary) {
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
final MatrixBlock mb = md.getMatrixBlock();
// The dictionary is never empty.
if(mb.isInSparseFormat())
Expand All @@ -147,9 +148,8 @@ else if(_dict instanceof MatrixBlockDictionary) {

@Override
public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) {
if(_dict instanceof IdentityDictionary) {

final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
if(_dict instanceof AIdentityDictionary) {
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
final MatrixBlock mb = md.getMatrixBlock();
// The dictionary is never empty.
if(mb.isInSparseFormat())
Expand Down Expand Up @@ -249,8 +249,8 @@ public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols, i
return null;

// is candidate for identity mm.
if(_dict instanceof IdentityDictionary //
&& !((IdentityDictionary) _dict).withEmpty()
if(_dict instanceof AIdentityDictionary //
&& !((AIdentityDictionary) _dict).withEmpty()
&& right.getNumRows() == _colIndexes.size() //
&& allowShallowIdentityRightMult()){

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
Expand Down Expand Up @@ -327,8 +327,8 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa
* @param constV The output columns.
*/
public final void addToCommon(double[] constV) {
if(_dict instanceof IdentityDictionary) {
MatrixBlock mb = ((IdentityDictionary) _dict).getMBDict().getMatrixBlock();
if(_dict instanceof AIdentityDictionary) {
MatrixBlock mb = ((AIdentityDictionary) _dict).getMBDict().getMatrixBlock();
if(mb.isInSparseFormat())
addToCommonSparse(constV, mb.getSparseBlock());
else
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* O
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.colgroup.dictionary;

import java.lang.ref.SoftReference;

public abstract class ACachingMBDictionary extends ADictionary {

/** A Cache to contain a materialized version of the identity matrix. */
protected volatile SoftReference<MatrixBlockDictionary> cache = null;

@Override
public final MatrixBlockDictionary getMBDict(int nCol) {
if(cache != null) {
MatrixBlockDictionary r = cache.get();
if(r != null)
return r;
}
MatrixBlockDictionary ret = createMBDict(nCol);
cache = new SoftReference<>(ret);
return ret;
}

public abstract MatrixBlockDictionary createMBDict(int nCol);
}
Loading

0 comments on commit b96cf25

Please sign in to comment.