Skip to content

Commit

Permalink
reExpand fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Feb 5, 2025
1 parent 5602e60 commit a3ec92a
Show file tree
Hide file tree
Showing 17 changed files with 148 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ public int getNumValues() {
}

/**
* Returns the counts of values inside the dictionary. If already calculated it will return the previous counts.
* This produce an overhead in cases where the count is calculated, but the overhead will be limited to number of
* distinct tuples in the dictionary.
* Returns the counts of values inside the dictionary. If already calculated it will return the previous counts. This
* produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct
* tuples in the dictionary.
*
* The returned counts always contains the number of zero tuples as well if there are some contained, even if they
* are not materialized.
Expand Down Expand Up @@ -197,12 +197,19 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
try {
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
if(d == null)
if(d == null) {
if(max <= 0)
return null;
return ColGroupEmpty.create(max);
else
return copyAndSet(ColIndexFactory.create(max), d);
}
else {
IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
return copyAndSet(outCols, d);
}
}
catch(DMLCompressionException e) {
if(max <= 0)
return null;
return ColGroupEmpty.create(max);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,11 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
if(d == null)
if(d == null){
if(max <= 0)
return null;
return ColGroupEmpty.create(max);
}
else
return create(max, d);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.List;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
Expand Down Expand Up @@ -392,33 +391,15 @@ public AColGroup extractCommon(double[] constV) {
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
final int def = (int) _reference[0];
IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, def);

if(d == null) {
if(def <= 0 || def > max)
return ColGroupEmpty.create(max);
else {
double[] retDef = new double[max];
retDef[def - 1] = 1;
return ColGroupConst.create(retDef);
}
if(max <= 0)
return null;
return ColGroupEmpty.create(max);
}
else {
IColIndex outCols = ColIndexFactory.create(max);
if(def <= 0) {
if(ignore)
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
else
throw new DMLRuntimeException("Invalid content of zero in rexpand");
}
else if(def > max)
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
else {
// double[] retDef = new double[max];
// retDef[def - 1] = 1;
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
}
IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
}

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,15 +500,24 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0]);
return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0],
_dict.getNumberOfValues(1));
}

protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows, IDictionary d,
AOffset indexes, AMapToData data, int[] counts, int def) {
AOffset indexes, AMapToData data, int[] counts, int def, int nVal) {

if(d == null) {
if(def <= 0 || def > max)
if(def <= 0){
if(max > 0)
return ColGroupEmpty.create(max);
else
return null;
}
else if(def > max && max > 0)
return ColGroupEmpty.create(max);
else if(max <= 0)
return null;
else {
double[] retDef = new double[max];
retDef[def - 1] = 1;
Expand All @@ -517,7 +526,7 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in
}
}
else {
final IColIndex outCols = ColIndexFactory.create(max);
final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(nVal));
if(def <= 0) {
if(ignore)
return ColGroupSDCZeros.create(outCols, nRows, d, indexes, data, counts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ public AColGroup extractCommon(double[] constV) {
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, (int) _reference[0]);
return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(),
(int) _reference[0]);
(int) _reference[0], _dict.getNumberOfValues(1));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public static AColGroup create(IColIndex colIndexes, int numRows, IDictionary di
if(offsets instanceof OffsetEmpty)
return ColGroupConst.create(colIndexes, defaultTuple);
final boolean allZero = ColGroupUtils.allZero(defaultTuple);
if(dict == null && allZero)
if(dict == null && allZero)
return new ColGroupEmpty(colIndexes);
else if(dict == null && offsets.getSize() * 2 > numRows + 2) {
AOffset rev = offsets.reverse(numRows);
Expand Down Expand Up @@ -469,27 +469,36 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
final int def = (int) _defaultTuple[0];
if(d == null) {
if(def <= 0 || def > max)
if(def <= 0){
if(max > 0)
return ColGroupEmpty.create(max);
else
return null;
}
else if(def > max && max > 0)
return ColGroupEmpty.create(max);
else if(max <= 0)
return null;
else {
double[] retDef = new double[max];
retDef[((int) _defaultTuple[0]) - 1] = 1;
return ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, null, retDef, _indexes, null);
}
}
else {
final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
if(def <= 0) {
if(ignore)
return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts());
return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts());
else
throw new DMLRuntimeException("Invalid content of zero in rexpand");
}
else if(def > max)
return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts());
return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts());
else {
double[] retDef = new double[max];
retDef[((int) _defaultTuple[0]) - 1] = 1;
return ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, d, retDef, _indexes, getCachedCounts());
return ColGroupSDCSingle.create(outCols, nRows, d, retDef, _indexes, getCachedCounts());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public int getNumberOfValues(int ncol) {
return _values.length / ncol;
}

@Override
public int getNumberOfColumns(int nrow){
return _values.length / nrow;
}

@Override
public String getString(int colIndexes) {
throw new NotImplementedException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.utils.MemoryEstimates;
Expand Down Expand Up @@ -388,6 +388,11 @@ public int getNumberOfValues(int nCol) {
return _values.length / nCol;
}

@Override
public int getNumberOfColumns(int nrow) {
return _values.length / nrow;
}

@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(nrColumns == 1)
Expand Down Expand Up @@ -1120,8 +1125,11 @@ public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cas
MatrixBlockDictionary m = getMBDict(1);
if(m == null)
return null;
IDictionary a = m.applyScalarOp(new LeftScalarOperator(Plus.getPlusFnObject(), reference));
return a == null ? null : a.rexpandCols(max, ignore, cast, 1);
IDictionary a = m.applyScalarOp(new RightScalarOperator(Plus.getPlusFnObject(), reference));
if(a == null)
return null; // second ending
a = a.rexpandCols(max, ignore, cast, 1);
return a;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIn
*/
public int getNumberOfValues(int ncol);

/**
* Get the number of columns in this dictionary, provided you know the number of values, or rows.
*
* @param nrow The number of rows/values known inside this dictionary
* @return The number of columns
*/
public int getNumberOfColumns(int nrow);

/**
* Method used as a pre-aggregate of each tuple in the dictionary, to single double values.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ public int getNumberOfValues(int ncol) {
return nRowCol + (withEmpty ? 1 : 0);
}

@Override
public int getNumberOfColumns(int nrow) {
if(nrow != (nRowCol + (withEmpty ? 1 : 0)))
throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns");
return nRowCol;
}

@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(withEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Arrays;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
Expand Down Expand Up @@ -237,6 +238,14 @@ public int getNumberOfValues(int ncol) {
return nRowCol + (withEmpty ? 1 : 0);
}

@Override
public int getNumberOfColumns(int nrow) {
if(nrow != (nRowCol + (withEmpty ? 1 : 0)))
throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns");
return u - l;
}


@Override
public void write(DataOutput out) throws IOException {
out.writeByte(DictionaryFactory.Type.IDENTITY_SLICE.ordinal());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,14 @@ public int getNumberOfValues(int ncol) {
return _data.getNumRows();
}

@Override
public int getNumberOfColumns(int nrow) {
if(nrow != _data.getNumRows())
throw new DMLCompressionException("Invalid call to get number of columns assuming wrong number of rows");
return _data.getNumColumns();
}


@Override
public double[] sumAllRowsToDouble(int nrColumns) {
double[] ret = new double[_data.getNumRows()];
Expand Down Expand Up @@ -2397,6 +2405,8 @@ public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol)
if(nCol > 1)
throw new DMLCompressionException("Invalid to rexpand the column groups if more than one column");
MatrixBlock ret = LibMatrixReorg.rexpand(_data, new MatrixBlock(), max, false, cast, ignore, 1);
if(ret.getNumColumns() == 0)
return null;
return MatrixBlockDictionary.create(ret);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public int getNumberOfValues(int nCol) {
return nVal;
}

@Override
public int getNumberOfColumns(int nrow) {
throw new RuntimeException("invalid to get number of columns for PlaceHolderDict");
}

@Override
public MatrixBlockDictionary getMBDict() {
throw new RuntimeException(errMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ public int getNumberOfValues(int nCol) {
return _values.length / nCol;
}

@Override
public int getNumberOfColumns(int nCol) {
return _values.length / nCol;
}

@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(nrColumns == 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,15 @@ else if(in.isOverlapping() || in.getColGroups().size() > 1)
cast, ignore, k);
else {
CompressedMatrixBlock retC = new CompressedMatrixBlock(nRows, max);
retC.allocateColGroup(in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows));
retC.recomputeNonZeros();

LOG.error(retC);
return retC;
AColGroup g = in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows);
if(g == null)
return new MatrixBlock(nRows,0,0);
else {
retC.setNumColumns(g.getNumCols());
retC.allocateColGroup(g);
retC.recomputeNonZeros();
return retC;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5152,7 +5152,6 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows,
return LibMatrixReorg.rexpand(this, result, max, rows, cast, ignore, k);
}


@Override
public final MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
return replaceOperations(result, pattern, replacement, 1);
Expand Down
Loading

0 comments on commit a3ec92a

Please sign in to comment.