diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index f3b37daa109..5e6bdb34b29 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -55,9 +55,9 @@ public int getNumValues() { } /** - * Returns the counts of values inside the dictionary. If already calculated it will return the previous counts. - * This produce an overhead in cases where the count is calculated, but the overhead will be limited to number of - * distinct tuples in the dictionary. + * Returns the counts of values inside the dictionary. If already calculated it will return the previous counts. This + * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct + * tuples in the dictionary. * * The returned counts always contains the number of zero tuples as well if there are some contained, even if they * are not materialized. @@ -197,12 +197,19 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) { public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { try { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - if(d == null) + if(d == null) { + if(max <= 0) + return null; return ColGroupEmpty.create(max); - else - return copyAndSet(ColIndexFactory.create(max), d); + } + else { + IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1))); + return copyAndSet(outCols, d); + } } catch(DMLCompressionException e) { + if(max <= 0) + return null; return ColGroupEmpty.create(max); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index a493b14f04b..21c6a0e1d80 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -527,8 +527,11 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) { @Override public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - if(d == null) + if(d == null){ + if(max <= 0) + return null; return ColGroupEmpty.create(max); + } else return create(max, d); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index a6b200dd99c..70191a27936 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -26,7 +26,6 @@ import java.util.List; import org.apache.commons.lang3.NotImplementedException; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -392,33 +391,15 @@ public AColGroup extractCommon(double[] constV) { public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { final int def = (int) _reference[0]; IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, def); - if(d == null) { - if(def <= 0 || def > max) - return ColGroupEmpty.create(max); - else { - double[] retDef = new double[max]; - retDef[def - 1] = 1; - return ColGroupConst.create(retDef); - } + if(max <= 0) + return null; + return ColGroupEmpty.create(max); } else { - IColIndex outCols = ColIndexFactory.create(max); - if(def <= 0) { - if(ignore) - return ColGroupDDC.create(outCols, d, _data, getCachedCounts()); - else - throw new DMLRuntimeException("Invalid content of zero in rexpand"); - } - else if(def > max) - return ColGroupDDC.create(outCols, d, _data, getCachedCounts()); - else { - // double[] retDef = new double[max]; - // retDef[def - 1] = 1; - return ColGroupDDC.create(outCols, d, _data, getCachedCounts()); - } + IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1))); + return ColGroupDDC.create(outCols, d, _data, getCachedCounts()); } - } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 541c2487d55..1270823bfdc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -500,15 +500,24 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) { @Override public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0]); + return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0], + _dict.getNumberOfValues(1)); } protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows, IDictionary d, - AOffset indexes, AMapToData data, int[] counts, int def) { + AOffset indexes, AMapToData data, int[] counts, int def, int nVal) { if(d == null) { - if(def <= 0 || def > max) + if(def <= 0){ + if(max > 0) + return ColGroupEmpty.create(max); + else + return null; + } + else if(def > max && max > 0) return ColGroupEmpty.create(max); + else if(max <= 0) + return null; else { double[] retDef = new double[max]; retDef[def - 1] = 1; @@ -517,7 +526,7 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in } } else { - final IColIndex outCols = ColIndexFactory.create(max); + final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(nVal)); if(def <= 0) { if(ignore) return ColGroupSDCZeros.create(outCols, nRows, d, indexes, data, counts); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 4c4b2e20a50..41fb7ac5709 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -427,7 +427,7 @@ public AColGroup extractCommon(double[] constV) { public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, (int) _reference[0]); return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), - (int) _reference[0]); + (int) _reference[0], _dict.getNumberOfValues(1)); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index f63df96fa73..fa5772c0c3e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -85,7 +85,7 @@ public static AColGroup create(IColIndex colIndexes, int numRows, IDictionary di if(offsets instanceof OffsetEmpty) return ColGroupConst.create(colIndexes, defaultTuple); final boolean allZero = ColGroupUtils.allZero(defaultTuple); - if(dict == null && allZero) + if(dict == null && allZero) return new ColGroupEmpty(colIndexes); else if(dict == null && offsets.getSize() * 2 > numRows + 2) { AOffset rev = offsets.reverse(numRows); @@ -469,8 +469,16 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); final int def = (int) _defaultTuple[0]; if(d == null) { - if(def <= 0 || def > max) + if(def <= 0){ + if(max > 0) + return ColGroupEmpty.create(max); + else + return null; + } + else if(def > max && max > 0) return ColGroupEmpty.create(max); + else if(max <= 0) + return null; else { double[] retDef = new double[max]; retDef[((int) _defaultTuple[0]) - 1] = 1; @@ -478,18 +486,19 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { } } else { + final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1))); if(def <= 0) { if(ignore) - return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts()); + return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts()); else throw new DMLRuntimeException("Invalid content of zero in rexpand"); } else if(def > max) - return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts()); + return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts()); else { double[] retDef = new double[max]; retDef[((int) _defaultTuple[0]) - 1] = 1; - return ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, d, retDef, _indexes, getCachedCounts()); + return ColGroupSDCSingle.create(outCols, nRows, d, retDef, _indexes, getCachedCounts()); } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index 5bbc1af5942..d67ab95f824 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -97,6 +97,11 @@ public int getNumberOfValues(int ncol) { return _values.length / ncol; } + @Override + public int getNumberOfColumns(int nrow){ + return _values.length / nrow; + } + @Override public String getString(int colIndexes) { throw new NotImplementedException(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 139254b5341..939b48bf424 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -41,7 +41,7 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; +import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.utils.MemoryEstimates; @@ -388,6 +388,11 @@ public int getNumberOfValues(int nCol) { return _values.length / nCol; } + @Override + public int getNumberOfColumns(int nrow) { + return _values.length / nrow; + } + @Override public double[] sumAllRowsToDouble(int nrColumns) { if(nrColumns == 1) @@ -1120,8 +1125,11 @@ public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cas MatrixBlockDictionary m = getMBDict(1); if(m == null) return null; - IDictionary a = m.applyScalarOp(new LeftScalarOperator(Plus.getPlusFnObject(), reference)); - return a == null ? null : a.rexpandCols(max, ignore, cast, 1); + IDictionary a = m.applyScalarOp(new RightScalarOperator(Plus.getPlusFnObject(), reference)); + if(a == null) + return null; // second ending + a = a.rexpandCols(max, ignore, cast, 1); + return a; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index 54b7cc809da..dddea0eec7a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -327,6 +327,14 @@ public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIn */ public int getNumberOfValues(int ncol); + /** + * Get the number of columns in this dictionary, provided you know the number of values, or rows. + * + * @param nrow The number of rows/values known inside this dictionary + * @return The number of columns + */ + public int getNumberOfColumns(int nrow); + /** * Method used as a pre-aggregate of each tuple in the dictionary, to single double values. * diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 41982a6842f..40e1b065653 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -194,6 +194,13 @@ public int getNumberOfValues(int ncol) { return nRowCol + (withEmpty ? 1 : 0); } + @Override + public int getNumberOfColumns(int nrow) { + if(nrow != (nRowCol + (withEmpty ? 1 : 0))) + throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns"); + return nRowCol; + } + @Override public double[] sumAllRowsToDouble(int nrColumns) { if(withEmpty) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index 0f07e1eac74..87c9c91826a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -25,6 +25,7 @@ import java.util.Arrays; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -237,6 +238,14 @@ public int getNumberOfValues(int ncol) { return nRowCol + (withEmpty ? 1 : 0); } + @Override + public int getNumberOfColumns(int nrow) { + if(nrow != (nRowCol + (withEmpty ? 1 : 0))) + throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns"); + return u - l; + } + + @Override public void write(DataOutput out) throws IOException { out.writeByte(DictionaryFactory.Type.IDENTITY_SLICE.ordinal()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 57f3a80e03a..ce52857138d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -936,6 +936,14 @@ public int getNumberOfValues(int ncol) { return _data.getNumRows(); } + @Override + public int getNumberOfColumns(int nrow) { + if(nrow != _data.getNumRows()) + throw new DMLCompressionException("Invalid call to get number of columns assuming wrong number of rows"); + return _data.getNumColumns(); + } + + @Override public double[] sumAllRowsToDouble(int nrColumns) { double[] ret = new double[_data.getNumRows()]; @@ -2397,6 +2405,8 @@ public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) if(nCol > 1) throw new DMLCompressionException("Invalid to rexpand the column groups if more than one column"); MatrixBlock ret = LibMatrixReorg.rexpand(_data, new MatrixBlock(), max, false, cast, ignore, 1); + if(ret.getNumColumns() == 0) + return null; return MatrixBlockDictionary.create(ret); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index f5c140e5227..f5746647a37 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -53,6 +53,11 @@ public int getNumberOfValues(int nCol) { return nVal; } + @Override + public int getNumberOfColumns(int nrow) { + throw new RuntimeException("invalid to get number of columns for PlaceHolderDict"); + } + @Override public MatrixBlockDictionary getMBDict() { throw new RuntimeException(errMessage); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 35a08b8d14b..6802d920b49 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -142,6 +142,11 @@ public int getNumberOfValues(int nCol) { return _values.length / nCol; } + @Override + public int getNumberOfColumns(int nCol) { + return _values.length / nCol; + } + @Override public double[] sumAllRowsToDouble(int nrColumns) { if(nrColumns == 1) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java index 43ab7cd2019..1bf43c49e5a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java @@ -111,11 +111,15 @@ else if(in.isOverlapping() || in.getColGroups().size() > 1) cast, ignore, k); else { CompressedMatrixBlock retC = new CompressedMatrixBlock(nRows, max); - retC.allocateColGroup(in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows)); - retC.recomputeNonZeros(); - - LOG.error(retC); - return retC; + AColGroup g = in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows); + if(g == null) + return new MatrixBlock(nRows,0,0); + else { + retC.setNumColumns(g.getNumCols()); + retC.allocateColGroup(g); + retC.recomputeNonZeros(); + return retC; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index af068d25233..be5d7ba5dda 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -5152,7 +5152,6 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows, return LibMatrixReorg.rexpand(this, result, max, rows, cast, ignore, k); } - @Override public final MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) { return replaceOperations(result, pattern, replacement, 1); diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java index f30f3401c79..5d91ddb2739 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java @@ -158,14 +158,41 @@ public void testReExpandCol() { testReExpand(true); } + @Test + public void testReExpandColNoIgnore() { + testReExpand(true, 0, false, true); + } + + @Test + public void testReExpandColNoCast() { + testReExpand(true, 0, false, false); + } + public void testReExpand(boolean col) { + testReExpand(col, 50, true, true); + } + + public void testReExpand(boolean col, int max, boolean ignore, boolean cast) { try { if(cmb instanceof CompressedMatrixBlock) { - MatrixBlock ret1 = cmb.rexpandOperations(new MatrixBlock(), 50, !col, true, true, _k); - MatrixBlock ret2 = mb.rexpandOperations(new MatrixBlock(), 50, !col, true, true, _k); + MatrixBlock ret1 = null; + try{ + ret1 = cmb.rexpandOperations(new MatrixBlock(), max, !col, cast, ignore, _k); + } + catch(RuntimeException re){ + if(! re.getMessage().contains("Invalid input value <= 0 for ignore=false:")) + throw re; + else + return; // great! + } + MatrixBlock ret2 = mb.rexpandOperations(new MatrixBlock(), max, !col, cast, ignore, _k); compareResultMatrices(ret2, ret1, 0); } } + catch(AssertionError e){ + LOG.error(cmb); + throw e; + } catch(Exception e) { e.printStackTrace(); throw e;