diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 001e11dcd4b..d16e738a313 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -26,8 +26,10 @@ import java.io.ObjectOutput; import java.lang.ref.SoftReference; import java.util.ArrayList; +import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -42,9 +44,11 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupIO; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.lib.CLALibAppend; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; import org.apache.sysds.runtime.compress.lib.CLALibCMOps; @@ -99,14 +103,13 @@ public class CompressedMatrixBlock extends MatrixBlock { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName()); private static final long serialVersionUID = 73193720143154058L; - /** - * Debugging flag for Compressed Matrices - */ + /** Debugging flag for Compressed Matrices */ public static boolean debug = false; - /** - * Column groups - */ + /** Disallow caching of uncompressed Block */ + public static boolean allowCachingUncompressed = true; + + /** Column groups */ protected transient List _colGroups; /** @@ -119,6 +122,9 @@ public class CompressedMatrixBlock extends MatrixBlock { */ protected transient SoftReference decompressedVersion; + /** Cached Memory size */ + protected transient long cachedMemorySize = -1; + public CompressedMatrixBlock() { super(true); sparse = false; @@ -169,7 +175,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) { clen = uncompressedMatrixBlock.getNumColumns(); sparse = false; nonZeros = uncompressedMatrixBlock.getNonZeros(); - decompressedVersion = new SoftReference<>(uncompressedMatrixBlock); + if(!(uncompressedMatrixBlock instanceof CompressedMatrixBlock)) { + decompressedVersion = new SoftReference<>(uncompressedMatrixBlock); + } } /** @@ -189,6 +197,7 @@ public CompressedMatrixBlock(int rl, int cl, long nnz, boolean overlapping, List this.nonZeros = nnz; this.overlappingColGroups = overlapping; this._colGroups = groups; + getInMemorySize(); // cache memory size } @Override @@ -204,6 +213,7 @@ public void reset(int rl, int cl, boolean sp, long estnnz, double val) { * @param cg The column group to use after. */ public void allocateColGroup(AColGroup cg) { + cachedMemorySize = -1; _colGroups = new ArrayList<>(1); _colGroups.add(cg); } @@ -270,6 +280,12 @@ public synchronized MatrixBlock decompress(int k) { ret = CLALibDecompress.decompress(this, k); + if(ret.getNonZeros() <= 0) { + LOG.warn("Decompress incorrectly set nnz to 0 or -1"); + ret.recomputeNonZeros(k); + } + ret.examSparsity(k); + // Set soft reference to the decompressed version decompressedVersion = new SoftReference<>(ret); @@ -290,7 +306,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp * @return The cached decompressed matrix, if it does not exist return null */ public MatrixBlock getCachedDecompressed() { - if(decompressedVersion != null) { + if( allowCachingUncompressed && decompressedVersion != null) { final MatrixBlock mb = decompressedVersion.get(); if(mb != null) { DMLCompressionStatistics.addDecompressCacheCount(); @@ -302,6 +318,7 @@ public MatrixBlock getCachedDecompressed() { } public CompressedMatrixBlock squash(int k) { + cachedMemorySize = -1; return CLALibSquash.squash(this, k); } @@ -377,12 +394,27 @@ public long estimateSizeInMemory() { * @return an upper bound on the memory used to store this compressed block considering class overhead. */ public long estimateCompressedSizeInMemory() { - long total = baseSizeInMemory(); - for(AColGroup grp : _colGroups) - total += grp.estimateInMemorySize(); + if(cachedMemorySize <= -1L) { + + long total = baseSizeInMemory(); + // take into consideration duplicate dictionaries + Set dicts = new HashSet<>(); + for(AColGroup grp : _colGroups){ + if(grp instanceof ADictBasedColGroup){ + IDictionary dg = ((ADictBasedColGroup) grp).getDictionary(); + if(dicts.contains(dg)) + total -= dg.getInMemorySize(); + dicts.add(dg); + } + total += grp.estimateInMemorySize(); + } + cachedMemorySize = total; + return total; - return total; + } + else + return cachedMemorySize; } public static long baseSizeInMemory() { @@ -392,6 +424,7 @@ public static long baseSizeInMemory() { total += 8; // Col Group Ref total += 8; // v reference total += 8; // soft reference to decompressed version + total += 8; // long cached memory size total += 1 + 7; // Booleans plus padding total += 40; // Col Group Array List @@ -431,6 +464,7 @@ public long estimateSizeOnDisk() { @Override public void readFields(DataInput in) throws IOException { + cachedMemorySize = -1; // deserialize compressed block rlen = in.readInt(); clen = in.readInt(); @@ -736,8 +770,22 @@ public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows, @Override public boolean isEmptyBlock(boolean safe) { - final long nonZeros = getNonZeros(); - return _colGroups == null || nonZeros == 0 || (nonZeros == -1 && recomputeNonZeros() == 0); + if(nonZeros > 1) + return false; + else if(_colGroups == null || nonZeros == 0) + return true; + else{ + if(nonZeros == -1){ + // try to use column groups + for(AColGroup g : _colGroups) + if(!g.isEmpty()) + return false; + // Otherwise recompute non zeros. + recomputeNonZeros(); + } + + return getNonZeros() == 0; + } } @Override @@ -1045,6 +1093,7 @@ public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareD } private void copyCompressedMatrix(CompressedMatrixBlock that) { + cachedMemorySize = -1; this.rlen = that.getNumRows(); this.clen = that.getNumColumns(); this.sparseBlock = null; @@ -1059,7 +1108,7 @@ private void copyCompressedMatrix(CompressedMatrixBlock that) { } public SoftReference getSoftReferenceToDecompressed() { - return decompressedVersion; + return allowCachingUncompressed ? decompressedVersion : null; } public void clearSoftReferenceToDecompressed() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 24a52ccecbf..21bddd102ef 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Set; +import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -63,8 +64,8 @@ public IDictionary getDictionary() { @Override public final void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { - if(_dict instanceof IdentityDictionary) { - final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); + if(_dict instanceof AIdentityDictionary) { + final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict(); final MatrixBlock mb = md.getMatrixBlock(); // The dictionary is never empty. if(mb.isInSparseFormat()) @@ -87,8 +88,8 @@ else if(_dict instanceof MatrixBlockDictionary) { @Override public void decompressToSparseBlockTransposed(SparseBlockMCSR sb, int nColOut) { - if(_dict instanceof IdentityDictionary) { - final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); + if(_dict instanceof AIdentityDictionary) { + final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict(); final MatrixBlock mb = md.getMatrixBlock(); // The dictionary is never empty. if(mb.isInSparseFormat()) @@ -123,8 +124,8 @@ protected abstract void decompressToSparseBlockTransposedDenseDictionary(SparseB @Override public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) { - if(_dict instanceof IdentityDictionary) { - final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); + if(_dict instanceof AIdentityDictionary) { + final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict(); final MatrixBlock mb = md.getMatrixBlock(); // The dictionary is never empty. if(mb.isInSparseFormat()) @@ -147,9 +148,8 @@ else if(_dict instanceof MatrixBlockDictionary) { @Override public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) { - if(_dict instanceof IdentityDictionary) { - - final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); + if(_dict instanceof AIdentityDictionary) { + final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict(); final MatrixBlock mb = md.getMatrixBlock(); // The dictionary is never empty. if(mb.isInSparseFormat()) @@ -249,8 +249,8 @@ public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols, i return null; // is candidate for identity mm. - if(_dict instanceof IdentityDictionary // - && !((IdentityDictionary) _dict).withEmpty() + if(_dict instanceof AIdentityDictionary // + && !((AIdentityDictionary) _dict).withEmpty() && right.getNumRows() == _colIndexes.size() // && allowShallowIdentityRightMult()){ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 9faf43b6095..d3ca10445c2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -26,10 +26,10 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; +import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -327,8 +327,8 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa * @param constV The output columns. */ public final void addToCommon(double[] constV) { - if(_dict instanceof IdentityDictionary) { - MatrixBlock mb = ((IdentityDictionary) _dict).getMBDict().getMatrixBlock(); + if(_dict instanceof AIdentityDictionary) { + MatrixBlock mb = ((AIdentityDictionary) _dict).getMBDict().getMatrixBlock(); if(mb.isInSparseFormat()) addToCommonSparse(constV, mb.getSparseBlock()); else diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ACachingMBDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ACachingMBDictionary.java new file mode 100644 index 00000000000..8117bd345cc --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ACachingMBDictionary.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * O + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.dictionary; + +import java.lang.ref.SoftReference; + +public abstract class ACachingMBDictionary extends ADictionary { + + /** A Cache to contain a materialized version of the identity matrix. */ + protected volatile SoftReference cache = null; + + @Override + public final MatrixBlockDictionary getMBDict(int nCol) { + if(cache != null) { + MatrixBlockDictionary r = cache.get(); + if(r != null) + return r; + } + MatrixBlockDictionary ret = createMBDict(nCol); + cache = new SoftReference<>(ret); + return ret; + } + + public abstract MatrixBlockDictionary createMBDict(int nCol); +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index 8d4a9d6cc89..31f2c9fb3c4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -24,8 +24,13 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.ScalarOperator; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; /** * This dictionary class aims to encapsulate the storage and operations over unique tuple values of a column group. @@ -49,7 +54,7 @@ public final CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] co @Override public final boolean equals(Object o) { - if(o instanceof IDictionary) + if(o != null && o instanceof IDictionary) return equals((IDictionary) o); return false; } @@ -65,7 +70,7 @@ public final boolean equals(double[] v) { * @param v The value * @return The string */ - public static String doubleToString(double v) { + protected static String doubleToString(double v) { if(v == (long) v) return Long.toString(((long) v)); else @@ -96,104 +101,468 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi return rightMMPreAggSparseAllColsRight(numVals, b, thisCols, nColRight); } - protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, - IColIndex aggregateColumns) { + @Override + public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + for(int i = 0; i < nCol; i++) + sb.append(rowOut, columns.get(i), getValue(idx, i, nCol)); + } + + @Override + public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) { + double[] dv = dr.values(rowOut); + int off = dr.pos(rowOut); + for(int i = 0; i < nCol; i++) + dv[off + columns.get(i)] += getValue(idx, i, nCol); + } + + @Override + public double[] getRow(int i, int nCol) { + double[] ret = new double[nCol]; + for(int c = 0; c < nCol; c++) { + ret[c] = getValue(i, c, nCol); + } + return ret; + } - final int thisColsSize = thisCols.size(); - final int aggColSize = aggregateColumns.size(); - final double[] ret = new double[numVals * aggColSize]; + public MatrixBlockDictionary getMBDict() { + throw new RuntimeException("Invalid call to getMBDict for " + getClass().getSimpleName()); + } - for(int h = 0; h < thisColsSize; h++) { - // chose row in right side matrix via column index of the dictionary - final int colIdx = thisCols.get(h); - if(b.isEmpty(colIdx)) - continue; + @Override + public void product(double[] ret, int[] counts, int nCol) { + getMBDict().product(ret, counts, nCol); + } - // extract the row values on the right side. - final double[] sValues = b.values(colIdx); - final int[] sIndexes = b.indexes(colIdx); - final int sPos = b.pos(colIdx); - final int sEnd = b.size(colIdx) + sPos; + @Override + public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { + getMBDict().productWithDefault(ret, counts, def, defCount); + } - for(int j = 0; j < numVals; j++) { // rows left - final int offOut = j * aggColSize; - final double v = getValue(j, h, thisColsSize); - sparseAddSelected(sPos, sEnd, aggColSize, aggregateColumns, sIndexes, sValues, ret, offOut, v); - } + @Override + public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { + getMBDict().productWithReference(ret, counts, reference, refCount); + } - } - return Dictionary.create(ret); + @Override + public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { + return getMBDict().centralMoment(ret, fn, counts, nRows); + } + + @Override + public double getSparsity() { + return getMBDict().getSparsity(); + } + + @Override + public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, + int nRows) { + return getMBDict().centralMomentWithDefault(ret, fn, counts, def, nRows); + } + + @Override + public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, + int nRows) { + return getMBDict().centralMomentWithReference(ret, fn, counts, reference, nRows); + } + + @Override + public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { + return getMBDict().rexpandCols(max, ignore, cast, nCol); + } + + @Override + public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { + return getMBDict().rexpandColsWithReference(max, ignore, cast, reference); + } + + @Override + public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { + getMBDict().TSMMWithScaling(counts, rows, cols, ret); + } + + @Override + public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + getMBDict().MMDict(right, rowsLeft, colsRight, result); + } + + @Override + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + getMBDict().MMDictScaling(right, rowsLeft, colsRight, result, scaling); + } + + @Override + public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + getMBDict().MMDictSparse(left, rowsLeft, colsRight, result); + } + + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + getMBDict().MMDictScalingSparse(left, rowsLeft, colsRight, result, scaling); + } + + @Override + public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight, result); + } + + @Override + public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + getMBDict().TSMMToUpperTriangleDense(left, rowsLeft, colsRight, result); + } + + @Override + public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, + MatrixBlock result) { + getMBDict().TSMMToUpperTriangleSparse(left, rowsLeft, colsRight, result); + } + + @Override + public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + getMBDict().TSMMToUpperTriangleScaling(right, rowsLeft, colsRight, scale, result); + } + + @Override + public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + getMBDict().TSMMToUpperTriangleDenseScaling(left, rowsLeft, colsRight, scale, result); + } + + @Override + public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + getMBDict().TSMMToUpperTriangleSparseScaling(left, rowsLeft, colsRight, scale, result); + } + + @Override + public IDictionary reorder(int[] reorder) { + return getMBDict().reorder(reorder); + } + + @Override + public IDictionary cbind(IDictionary that, int nCol) { + return getMBDict().cbind(that, nCol); + } + + @Override + public IDictionary append(double[] row) { + return getMBDict().append(row); + } + + @Override + public IDictionary replace(double pattern, double replace, int nCol) { + if(containsValue(pattern)) + return getMBDict().replace(pattern, replace, nCol); + else + return this; } - private final void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex aggregateColumns, int[] sIndexes, - double[] sValues, double[] ret, int offOut, double v) { + @Override + public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { + if(containsValueWithReference(pattern, reference)) + return getMBDict().replaceWithReference(pattern, replace, reference); + else + return this; + } + + @Override + public IDictionary subtractTuple(double[] tuple) { + return getMBDict().subtractTuple(tuple); + } + + @Override + public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { + return getMBDict().getNumberNonZerosWithReference(counts, reference, nRows); + } - int retIdx = 0; - for(int i = sPos; i < sEnd; i++) { - // skip through the retIdx. - while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i]) - retIdx++; - if(retIdx == aggColSize) - break; - ret[offOut + retIdx] += v * sValues[i]; + @Override + public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)) { + for(int i = 0; i < reference.length; i++) + if(Double.isNaN(reference[i])) + return true; + return containsValue(pattern); } - retIdx = 0; + return getMBDict().containsValueWithReference(pattern, reference); + } + + @Override + public double sumSqWithReference(int[] counts, double[] reference) { + return getMBDict().sumSqWithReference(counts, reference); + } + + @Override + public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { + getMBDict().colProductWithReference(res, counts, colIndexes, reference); + + } + + @Override + public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { + getMBDict().colSumSqWithReference(c, counts, colIndexes, reference); + } + + @Override + public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { + getMBDict().multiplyScalar(v, ret, off, dictIdx, cols); + } + + @Override + public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + getMBDict().MMDictDense(left, rowsLeft, colsRight, result); } protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, int nColRight) { - final int thisColsSize = thisCols.size(); - final double[] ret = new double[numVals * nColRight]; - - for(int h = 0; h < thisColsSize; h++) { // common dim - // chose row in right side matrix via column index of the dictionary - final int colIdx = thisCols.get(h); - if(b.isEmpty(colIdx)) - continue; - - // extract the row values on the right side. - final double[] sValues = b.values(colIdx); - final int[] sIndexes = b.indexes(colIdx); - final int sPos = b.pos(colIdx); - final int sEnd = b.size(colIdx) + sPos; - - for(int i = 0; i < numVals; i++) { // rows left - final int offOut = i * nColRight; - final double v = getValue(i, h, thisColsSize); - SparseAdd(sPos, sEnd, ret, offOut, sIndexes, sValues, v); - } - } - return Dictionary.create(ret); + return getMBDict().rightMMPreAggSparseAllColsRight(numVals, b, thisCols, nColRight); } - private final void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { - if(v != 0) { - for(int k = sPos; k < sEnd; k++) { // cols right with value - ret[offOut + sIdx[k]] += v * sVals[k]; - } - } + protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, + IColIndex aggregateColumns) { + return getMBDict().rightMMPreAggSparseSelectedCols(numVals, b, thisCols, aggregateColumns); } @Override - public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { - for(int i = 0; i < nCol; i++) - sb.append(rowOut, columns.get(i), getValue(idx, i, nCol)); + public double[] productAllRowsToDoubleWithReference(double[] reference) { + return getMBDict().productAllRowsToDoubleWithReference(reference); } @Override - public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) { - double[] dv = dr.values(rowOut); - int off = dr.pos(rowOut); - for(int i = 0; i < nCol; i++) - dv[off + columns.get(i)] += getValue(idx, i, nCol); + public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { + return getMBDict().sumAllRowsToDoubleSqWithDefault(defaultTuple); } @Override - public double[] getRow(int i, int nCol) { - double[] ret = new double[nCol]; - for(int c = 0; c < nCol; c++) { - ret[c] = getValue(i, c, nCol); - } - return ret; + public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { + return getMBDict().sumAllRowsToDoubleSqWithReference(reference); + } + + @Override + public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, + double[] newReference) { + return getMBDict().binOpRightWithReference(op, v, colIndexes, reference, newReference); + } + + @Override + public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { + return getMBDict().binOpRightAndAppend(op, v, colIndexes); + } + + @Override + public IDictionary binOpRight(BinaryOperator op, double[] v) { + return getMBDict().binOpRight(op, v); + } + + @Override + public IDictionary applyScalarOp(ScalarOperator op) { + return getMBDict().applyScalarOp(op); + } + + @Override + public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { + return getMBDict().applyScalarOpAndAppend(op, v0, nCol); + } + + @Override + public IDictionary applyUnaryOp(UnaryOperator op) { + return getMBDict().applyUnaryOp(op); + } + + @Override + public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { + return getMBDict().applyUnaryOpAndAppend(op, v0, nCol); + } + + @Override + public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { + return getMBDict().applyScalarOpWithReference(op, reference, newReference); + } + + @Override + public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { + return getMBDict().applyUnaryOpWithReference(op, reference, newReference); + } + + @Override + public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { + return getMBDict().binOpLeft(op, v, colIndexes); + } + + @Override + public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { + return getMBDict().binOpLeftAndAppend(op, v, colIndexes); + } + + @Override + public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, + double[] newReference) { + return getMBDict().binOpLeftWithReference(op, v, colIndexes, reference, newReference); + } + + @Override + public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, + boolean def) { + getMBDict().aggregateColsWithReference(c, fn, colIndexes, reference, def); + } + + @Override + public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { + return getMBDict().aggregateRowsWithDefault(fn, defaultTuple); + } + + @Override + public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { + return getMBDict().aggregateRowsWithReference(fn, reference); + } + + @Override + public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { + return getMBDict().aggregateWithReference(init, fn, reference, def); + } + + @Override + public double aggregate(double init, Builtin fn) { + return getMBDict().aggregate(init, fn); + } + + @Override + public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { + getMBDict().colSumSq(c, counts, colIndexes); + } + + @Override + public void addToEntry(double[] v, int fr, int to, int nCol) { + getMBDict().addToEntry(v, fr, to, nCol); + } + + @Override + public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { + getMBDict().colProduct(res, counts, colIndexes); + } + + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + getMBDict().MMDictScalingDense(left, rowsLeft, colsRight, result, scaling); + } + + @Override + public int[] countNNZZeroColumns(int[] counts) { + return getMBDict().countNNZZeroColumns(counts); } + + @Override + public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { + return getMBDict().sliceOutColumnRange(idxStart, idxEnd, previousNumberOfColumns); + } + + @Override + public IDictionary scaleTuples(int[] scaling, int nCol) { + return getMBDict().scaleTuples(scaling, nCol); + } + + @Override + public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { + return getMBDict().binOpRight(op, v, colIndexes); + } + + @Override + public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colIndexes, + final IColIndex aggregateColumns, final double[] b, final int cut) { + return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut); + } + + @Override + public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, + int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { + getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); + } + + @Override + public double[] getValues() { + return getMBDict().getValues(); + } + + @Override + public double getValue(int i) { + return getMBDict().getValue(i); + } + + @Override + public double getValue(int r, int col, int nCol) { + return getMBDict().getValue(r, col, nCol); + } + + @Override + public double[] aggregateRows(Builtin fn, int nCol) { + return getMBDict().aggregateRows(fn, nCol); + } + + @Override + public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { + getMBDict().aggregateCols(c, fn, colIndexes); + } + + @Override + public double[] sumAllRowsToDouble(int nrColumns) { + return getMBDict().sumAllRowsToDouble(nrColumns); + } + + @Override + public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { + return getMBDict().sumAllRowsToDoubleWithDefault(defaultTuple); + } + + @Override + public double[] sumAllRowsToDoubleWithReference(double[] reference) { + return getMBDict().sumAllRowsToDoubleWithReference(reference); + } + + @Override + public double[] sumAllRowsToDoubleSq(int nrColumns) { + return getMBDict().sumAllRowsToDoubleSq(nrColumns); + } + + @Override + public double[] productAllRowsToDouble(int nrColumns) { + return getMBDict().productAllRowsToDouble(nrColumns); + } + + @Override + public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { + return getMBDict().productAllRowsToDoubleWithDefault(defaultTuple); + } + + @Override + public void colSum(double[] c, int[] counts, IColIndex colIndexes) { + getMBDict().colSum(c, counts, colIndexes); + } + + @Override + public double sum(int[] counts, int nCol) { + return getMBDict().sum(counts, nCol); + } + + @Override + public double sumSq(int[] counts, int nCol) { + return getMBDict().sumSq(counts, nCol); + } + + @Override + public boolean containsValue(double pattern) { + return getMBDict().containsValue(pattern); + } + + @Override + public void addToEntry(double[] v, int fr, int to, int nCol, int rep) { + getMBDict().addToEntry(v, fr, to, nCol, rep); + } + + @Override + public MatrixBlockDictionary getMBDict(int nCol) { + return getMBDict(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java new file mode 100644 index 00000000000..2bc10b1b062 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.dictionary; + +import org.apache.sysds.runtime.compress.DMLCompressionException; + +public abstract class AIdentityDictionary extends ACachingMBDictionary { + /** The number of rows or columns, rows can be +1 if withEmpty is set. */ + protected final int nRowCol; + /** Specify if the Identity matrix should contain an empty row in the end. */ + protected final boolean withEmpty; + + /** + * Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the + * structure is known to have certain properties. + * + * @param nRowCol The number of rows and columns in this identity matrix. + */ + public AIdentityDictionary(int nRowCol) { + if(nRowCol <= 0) + throw new DMLCompressionException("Invalid Identity Dictionary"); + this.nRowCol = nRowCol; + this.withEmpty = false; + } + + public AIdentityDictionary(int nRowCol, boolean withEmpty) { + if(nRowCol <= 0) + throw new DMLCompressionException("Invalid Identity Dictionary"); + this.nRowCol = nRowCol; + this.withEmpty = withEmpty; + } + + public boolean withEmpty() { + return withEmpty; + } + + public static long getInMemorySize(int numberColumns) { + return 4 + 4 + 8; // int + padding + softReference + } + + @Override + public final boolean containsValue(double pattern) { + return pattern == 0.0 || pattern == 1.0; + } + + @Override + public double[] productAllRowsToDouble(int nCol) { + return new double[nRowCol + (withEmpty ? 1 : 0)]; + } + + @Override + public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { + double[] ret = new double[nRowCol + (withEmpty ? 1 : 0) + 1]; + ret[ret.length - 1] = 1; + for(int i = 0; i < defaultTuple.length; i++) + ret[ret.length - 1] *= defaultTuple[i]; + return ret; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index a990b689b99..5bbc1af5942 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; +import java.io.DataOutput; +import java.io.IOException; + import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -30,16 +33,23 @@ * This dictionary class is a specialization for the DeltaDDCColgroup. Here the adjustments for operations for the delta * encoded values are implemented. */ -public class DeltaDictionary extends Dictionary { +public class DeltaDictionary extends ADictionary { private static final long serialVersionUID = -5700139221491143705L; - + private final int _numCols; + protected final double[] _values; + public DeltaDictionary(double[] values, int numCols) { - super(values); + _values = values; _numCols = numCols; } + @Override + public double[] getValues(){ + return _values; + } + @Override public DeltaDictionary applyScalarOp(ScalarOperator op) { final double[] retV = new double[_values.length]; @@ -61,4 +71,49 @@ else if(op.fn instanceof Plus || op.fn instanceof Minus) { return new DeltaDictionary(retV, _numCols); } + + @Override + public long getInMemorySize() { + return Dictionary.getInMemorySize(_values.length); + } + + @Override + public void write(DataOutput out) throws IOException { + throw new NotImplementedException(); + } + + @Override + public long getExactSizeOnDisk() { + throw new NotImplementedException(); + } + + @Override + public DictType getDictType() { + throw new NotImplementedException(); + } + + @Override + public int getNumberOfValues(int ncol) { + return _values.length / ncol; + } + + @Override + public String getString(int colIndexes) { + throw new NotImplementedException(); + } + + @Override + public long getNumberNonZeros(int[] counts, int nCol) { + throw new NotImplementedException(); + } + + @Override + public boolean equals(IDictionary o) { + throw new NotImplementedException(); + } + + @Override + public IDictionary clone() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 4bbabe2f926..139254b5341 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.lang.ref.SoftReference; import java.math.BigDecimal; import java.math.MathContext; import java.util.Arrays; @@ -52,13 +51,11 @@ * group. The primary reason for its introduction was to provide an entry point for specialization such as shared * dictionaries, which require additional information. */ -public class Dictionary extends ADictionary { +public class Dictionary extends ACachingMBDictionary { private static final long serialVersionUID = -6517136537249507753L; protected final double[] _values; - /** A Cache to contain a MatrixBlock version of the dictionary. */ - protected volatile SoftReference cache = null; protected Dictionary(double[] values) { _values = values; @@ -292,7 +289,7 @@ public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex final int lenV = colIndexes.size(); for(int i = 0; i < _values.length; i++) retVals[i] = fn.execute(_values[i], v[colIndexes.get(i % lenV)]); - for(int i = _values.length; i < _values.length; i++) + for(int i = _values.length; i < retVals.length; i++) retVals[i] = fn.execute(0, v[colIndexes.get(i % lenV)]); return create(retVals); @@ -332,7 +329,7 @@ public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex c final int lenV = colIndexes.size(); for(int i = 0; i < _values.length; i++) retVals[i] = fn.execute(v[colIndexes.get(i % lenV)], _values[i]); - for(int i = _values.length; i < _values.length; i++) + for(int i = _values.length; i < retVals.length; i++) retVals[i] = fn.execute(v[colIndexes.get(i % lenV)], 0); return create(retVals); @@ -468,9 +465,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { final double[] ret = new double[numVals + 1]; for(int k = 0; k < numVals; k++) ret[k] = prodRow(k, nCol); - ret[ret.length - 1] = defaultTuple[0]; + ret[numVals] = defaultTuple[0]; for(int i = 1; i < nCol; i++) - ret[ret.length - 1] *= defaultTuple[i]; + ret[numVals] *= defaultTuple[i]; return ret; } @@ -525,9 +522,10 @@ private double sumRowSq(int k, int nrColumns) { private double prodRow(int k, int nrColumns) { final int valOff = k * nrColumns; + final int end = valOff + nrColumns; double res = _values[valOff]; - for(int i = 1; i < nrColumns; i++) - res *= _values[valOff + i]; + for(int i = valOff + 1; i < end && res != 0; i++) // early abort on zero + res *= _values[i]; return res; } @@ -729,6 +727,8 @@ public boolean containsValue(double pattern) { @Override public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)) + return super.containsValueWithReference(pattern, reference); final int nCol = reference.length; for(int i = 0; i < _values.length; i++) if(_values[i] + reference[i % nCol] == pattern) @@ -840,14 +840,8 @@ public IDictionary subtractTuple(double[] tuple) { } @Override - public MatrixBlockDictionary getMBDict(int nCol) { - if(cache != null) { - MatrixBlockDictionary r = cache.get(); - if(r != null) - return r; - } + public MatrixBlockDictionary createMBDict(int nCol) { MatrixBlockDictionary ret = MatrixBlockDictionary.createDictionary(_values, nCol, true); - cache = new SoftReference<>(ret); return ret; } @@ -922,46 +916,7 @@ public IDictionary replaceWithReference(double pattern, double replace, double[] final int nCol = reference.length; final int nRow = _values.length / nCol; if(Util.eq(pattern, Double.NaN)) { - Set colsWithNan = null; - for(int i = 0; i < reference.length; i++) { - if(Util.eq(reference[i], Double.NaN)) { - if(colsWithNan == null) - colsWithNan = new HashSet<>(); - colsWithNan.add(i); - reference[i] = replace; - } - } - - if(colsWithNan != null) { - final double[] retV = new double[_values.length]; - for(int i = 0; i < nRow; i++) { - final int off = i * reference.length; - for(int j = 0; j < nCol; j++) { - final int cell = off + j; - if(colsWithNan.contains(j)) - retV[cell] = 0; - else if(Util.eq(_values[cell], Double.NaN)) - retV[cell] = replace; - else - retV[cell] = _values[cell]; - } - } - return create(retV); - } - else { - final double[] retV = new double[_values.length]; - for(int i = 0; i < nRow; i++) { - final int off = i * reference.length; - for(int j = 0; j < nCol; j++) { - final int cell = off + j; - if(Util.eq(_values[cell], Double.NaN)) - retV[cell] = replace; - else - retV[cell] = _values[cell]; - } - } - return create(retV); - } + return replaceWithReferenceNaN(replace, reference, nCol, nRow); } else { final double[] retV = new double[_values.length]; @@ -978,6 +933,62 @@ else if(Util.eq(_values[cell], Double.NaN)) } } + private IDictionary replaceWithReferenceNaN(double replace, double[] reference, final int nCol, final int nRow) { + final Set colsWithNan = getColsWithNan(replace, reference); + final double[] retV; + if(colsWithNan != null) { + if(colsWithNan.size() == nCol && replace == 0) + return null; + retV = new double[_values.length]; + replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, _values, retV); + } + else { + retV = new double[_values.length]; + replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, _values); + } + return create(retV); + } + + protected static Set getColsWithNan(double replace, double[] reference) { + Set colsWithNan = null; + for(int i = 0; i < reference.length; i++) { + if(Util.eq(reference[i], Double.NaN)) { + if(colsWithNan == null) + colsWithNan = new HashSet<>(); + colsWithNan.add(i); + reference[i] = replace; + } + } + return colsWithNan; + } + + protected static void replaceWithReferenceNanDenseWithoutNanCols(final double replace, final double[] reference, + final int nRow, final int nCol, final double[] retV, final double[] values) { + int off = 0; + for(int i = 0; i < nRow; i++) { + for(int j = 0; j < nCol; j++) { + final double v = values[off]; + retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; + } + } + } + + protected static void replaceWithReferenceNanDenseWithNanCols(final double replace, final double[] reference, + final int nRow, final int nCol, Set colsWithNan, final double[] values, final double[] retV) { + int off = 0; + for(int i = 0; i < nRow; i++) { + for(int j = 0; j < nCol; j++) { + final double v = values[off]; + if(colsWithNan.contains(j)) + retV[off++] = 0; + else if(Util.eq(v, Double.NaN)) + retV[off++] = replace - reference[j]; + else + retV[off++] = v; + } + } + } + @Override public void product(double[] ret, int[] counts, int nCol) { if(ret[0] == 0) @@ -1033,17 +1044,22 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, if(ret[0] == 0) return; final MathContext cont = MathContext.DECIMAL128; - final int len = counts.length; + final int nRow = counts.length; final int nCol = reference.length; + BigDecimal tmp = BigDecimal.ONE; int off = 0; - for(int i = 0; i < len; i++) { + for(int i = 0; i < nRow; i++) { for(int j = 0; j < nCol; j++) { final double v = _values[off++] + reference[j]; if(v == 0) { ret[0] = 0; return; } + else if(!Double.isFinite(v)) { + ret[0] = v; + return; + } tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont); } } @@ -1053,6 +1069,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, ret[0] = 0; else if(!Double.isInfinite(ret[0])) ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue(); + } @Override @@ -1201,25 +1218,8 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef public boolean equals(IDictionary o) { if(o instanceof Dictionary) return Arrays.equals(_values, ((Dictionary) o)._values); - else if(o instanceof IdentityDictionary) - return ((IdentityDictionary) o).equals(this); - else if(o instanceof MatrixBlockDictionary) { - final MatrixBlock mb = ((MatrixBlockDictionary) o).getMatrixBlock(); - if(mb.isEmpty()) { - for(int i = 0; i < _values.length; i++) { - if(_values[i] != 0) - return false; - } - return true; - } - else if(mb.isInSparseFormat()) - return mb.getSparseBlock().equals(_values, mb.getNumColumns()); - final double[] dv = mb.getDenseBlockValues(); - return Arrays.equals(_values, dv); - } - else if(o instanceof IdentityDictionary) { + else if(o != null) return o.equals(this); - } return false; } @@ -1245,6 +1245,86 @@ public IDictionary reorder(int[] reorder) { return ret; } + @Override + protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, + IColIndex aggregateColumns) { + + final int thisColsSize = thisCols.size(); + final int aggColSize = aggregateColumns.size(); + final double[] ret = new double[numVals * aggColSize]; + + for(int h = 0; h < thisColsSize; h++) { + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int j = 0; j < numVals; j++) { // rows left + final int offOut = j * aggColSize; + final double v = getValue(j, h, thisColsSize); + sparseAddSelected(sPos, sEnd, aggColSize, aggregateColumns, sIndexes, sValues, ret, offOut, v); + } + + } + return Dictionary.create(ret); + } + + private void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex aggregateColumns, int[] sIndexes, + double[] sValues, double[] ret, int offOut, double v) { + + int retIdx = 0; + for(int i = sPos; i < sEnd; i++) { + // skip through the retIdx. + while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i]) + retIdx++; + if(retIdx == aggColSize) + break; + ret[offOut + retIdx] += v * sValues[i]; + } + retIdx = 0; + } + + @Override + protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, + int nColRight) { + final int thisColsSize = thisCols.size(); + final double[] ret = new double[numVals * nColRight]; + + for(int h = 0; h < thisColsSize; h++) { // common dim + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int i = 0; i < numVals; i++) { // rows left + final int offOut = i * nColRight; + final double v = getValue(i, h, thisColsSize); + SparseAdd(sPos, sEnd, ret, offOut, sIndexes, sValues, v); + } + } + return Dictionary.create(ret); + } + + private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { + if(v != 0) { + for(int k = sPos; k < sEnd; k++) { // cols right with value + ret[offOut + sIdx[k]] += v * sVals[k]; + } + } + } + @Override public IDictionary append(double[] row) { double[] retV = new double[_values.length + row.length]; @@ -1252,4 +1332,5 @@ public IDictionary append(double[] row) { System.arraycopy(row, 0, retV, _values.length, row.length); return new Dictionary(retV); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java index 72becbae0c0..f88ac99b87b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java @@ -72,7 +72,6 @@ public static IDictionary read(DataInput in) throws IOException { default: return MatrixBlockDictionary.read(in); } - } public static long getInMemorySize(int nrValues, int nrColumns, double tupleSparsity, boolean lossy) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 7a76f308523..41982a6842f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.lang.ref.SoftReference; import java.util.Arrays; import org.apache.commons.lang3.NotImplementedException; @@ -36,38 +35,35 @@ import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.functionobjects.ValueFunction; -import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.matrix.operators.UnaryOperator; /** * A specialized dictionary that exploits the fact that the contained dictionary is an Identity Matrix. */ -public class IdentityDictionary extends ADictionary { +public class IdentityDictionary extends AIdentityDictionary { - private static final long serialVersionUID = 2535887782150955098L; + private static final long serialVersionUID = 2535887782153955098L; - /** The number of rows or columns, rows can be +1 if withEmpty is set. */ - protected final int nRowCol; - /** Specify if the Identity matrix should contain an empty row in the end. */ - protected final boolean withEmpty; - /** A Cache to contain a materialized version of the identity matrix. */ - protected volatile SoftReference cache = null; + /** + * Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the + * structure is known to have certain properties. + * + * @param nRowCol The number of rows and columns in this identity matrix. + */ + private IdentityDictionary(int nRowCol) { + super(nRowCol); + } /** * Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the * structure is known to have certain properties. * * @param nRowCol The number of rows and columns in this identity matrix. + * @return a Dictionary instance. */ - public IdentityDictionary(int nRowCol) { - if(nRowCol <= 0) - throw new DMLCompressionException("Invalid Identity Dictionary"); - this.nRowCol = nRowCol; - this.withEmpty = false; + public static IDictionary create(int nRowCol) { + return create(nRowCol, false); } /** @@ -77,11 +73,26 @@ public IdentityDictionary(int nRowCol) { * @param nRowCol The number of rows and columns in this identity matrix. * @param withEmpty If the matrix should contain an empty row in the end. */ - public IdentityDictionary(int nRowCol, boolean withEmpty) { - if(nRowCol <= 0) - throw new DMLCompressionException("Invalid Identity Dictionary"); - this.nRowCol = nRowCol; - this.withEmpty = withEmpty; + private IdentityDictionary(int nRowCol, boolean withEmpty) { + super(nRowCol, withEmpty); + } + + /** + * Create an identity matrix dictionary, It behaves as if allocated a Sparse Matrix block but exploits that the + * structure is known to have certain properties. + * + * @param nRowCol The number of rows and columns in this identity matrix. + * @param withEmpty If the matrix should contain an empty row in the end. + * @return a Dictionary instance. + */ + public static IDictionary create(int nRowCol, boolean withEmpty) { + if(nRowCol == 1) { + if(withEmpty) + return new Dictionary(new double[] {1, 0}); + else + return new Dictionary(new double[] {1}); + } + return new IdentityDictionary(nRowCol, withEmpty); } @Override @@ -107,10 +118,6 @@ public double getValue(int i) { return row == col ? 1 : 0; } - public boolean withEmpty() { - return withEmpty; - } - @Override public double getValue(int r, int c, int nCol) { return r == c ? 1 : 0; @@ -118,11 +125,11 @@ public double getValue(int r, int c, int nCol) { @Override public long getInMemorySize() { - return 4 + 4 + 8; // int + padding + softReference + return getInMemorySize(-1); // int + padding + softReference } public static long getInMemorySize(int numberColumns) { - return 4 + 4 + 8; + return AIdentityDictionary.getInMemorySize(numberColumns); } @Override @@ -135,11 +142,6 @@ else if(fn.getBuiltinCode() == BuiltinCode.MIN) throw new NotImplementedException(); } - @Override - public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { - return getMBDict().aggregateWithReference(init, fn, reference, def); - } - @Override public double[] aggregateRows(Builtin fn, int nCol) { double[] ret = new double[nRowCol]; @@ -147,16 +149,6 @@ public double[] aggregateRows(Builtin fn, int nCol) { return ret; } - @Override - public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { - return getMBDict().aggregateRowsWithDefault(fn, defaultTuple); - } - - @Override - public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { - return getMBDict().aggregateRowsWithReference(fn, reference); - } - @Override public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { for(int i = 0; i < nRowCol; i++) { @@ -166,60 +158,6 @@ public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { } } - @Override - public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, - boolean def) { - getMBDict().aggregateColsWithReference(c, fn, colIndexes, reference, def); - } - - @Override - public IDictionary applyScalarOp(ScalarOperator op) { - return getMBDict().applyScalarOp(op); - } - - @Override - public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { - - return getMBDict().applyScalarOpAndAppend(op, v0, nCol); - } - - @Override - public IDictionary applyUnaryOp(UnaryOperator op) { - return getMBDict().applyUnaryOp(op); - } - - @Override - public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { - return getMBDict().applyUnaryOpAndAppend(op, v0, nCol); - } - - @Override - public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { - return getMBDict().applyScalarOpWithReference(op, reference, newReference); - } - - @Override - public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { - return getMBDict().applyUnaryOpWithReference(op, reference, newReference); - } - - @Override - public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { - return getMBDict().binOpLeft(op, v, colIndexes); - } - - @Override - public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - return getMBDict().binOpLeftAndAppend(op, v, colIndexes); - } - - @Override - public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - return getMBDict().binOpLeftWithReference(op, v, colIndexes, reference, newReference); - - } - @Override public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { boolean same = false; @@ -239,22 +177,6 @@ public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexe return mb.binOpRight(op, v, colIndexes); } - @Override - public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - return getMBDict().binOpRightAndAppend(op, v, colIndexes); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v) { - return getMBDict().binOpRight(op, v); - } - - @Override - public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - return getMBDict().binOpRightWithReference(op, v, colIndexes, reference, newReference); - } - @Override public IDictionary clone() { return new IdentityDictionary(nRowCol, withEmpty); @@ -321,31 +243,6 @@ public double[] sumAllRowsToDoubleSq(int nrColumns) { return ret; } - @Override - public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { - return getMBDict().sumAllRowsToDoubleSqWithDefault(defaultTuple); - } - - @Override - public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { - return getMBDict().sumAllRowsToDoubleSqWithReference(reference); - } - - @Override - public double[] productAllRowsToDouble(int nCol) { - return new double[nRowCol]; - } - - @Override - public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - return new double[nRowCol]; - } - - @Override - public double[] productAllRowsToDoubleWithReference(double[] reference) { - return getMBDict().productAllRowsToDoubleWithReference(reference); - } - @Override public void colSum(double[] c, int[] counts, IColIndex colIndexes) { for(int i = 0; i < colIndexes.size(); i++) @@ -364,17 +261,6 @@ public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { } } - @Override - public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { - getMBDict().colProductWithReference(res, counts, colIndexes, reference); - - } - - @Override - public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { - getMBDict().colSumSqWithReference(c, counts, colIndexes, reference); - } - @Override public double sum(int[] counts, int ncol) { // number of rows, change this. @@ -391,27 +277,12 @@ public double sumSq(int[] counts, int ncol) { return sum(counts, ncol); } - @Override - public double sumSqWithReference(int[] counts, double[] reference) { - return getMBDict().sumSqWithReference(counts, reference); - } - @Override public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { if(idxStart == 0 && idxEnd == nRowCol) return new IdentityDictionary(nRowCol, withEmpty); else - return new IdentityDictionarySlice(nRowCol, withEmpty, idxStart, idxEnd); - } - - @Override - public boolean containsValue(double pattern) { - return pattern == 0.0 || pattern == 1.0; - } - - @Override - public boolean containsValueWithReference(double pattern, double[] reference) { - return getMBDict().containsValueWithReference(pattern, reference); + return IdentityDictionarySlice.create(nRowCol, withEmpty, idxStart, idxEnd); } @Override @@ -421,14 +292,11 @@ public long getNumberNonZeros(int[] counts, int nCol) { @Override public int[] countNNZZeroColumns(int[] counts) { + if(withEmpty) + return Arrays.copyOf(counts, nRowCol); // one less. return counts; // interesting ... but true. } - @Override - public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { - return getMBDict().getNumberNonZerosWithReference(counts, reference, nRows); - } - @Override public final void addToEntry(final double[] v, final int fr, final int to, final int nCol) { addToEntry(v, fr, to, nCol, 1); @@ -484,46 +352,24 @@ private void addToEntryVectorizedNorm(double[] v, int f1, int f2, int f3, int f4 } @Override - public IDictionary subtractTuple(double[] tuple) { - return getMBDict().subtractTuple(tuple); - } - public MatrixBlockDictionary getMBDict() { return getMBDict(nRowCol); } @Override - public MatrixBlockDictionary getMBDict(int nCol) { - if(cache != null) { - MatrixBlockDictionary r = cache.get(); - if(r != null) - return r; - } - MatrixBlockDictionary ret = createMBDict(); - cache = new SoftReference<>(ret); - return ret; - } - - private MatrixBlockDictionary createMBDict() { - + public MatrixBlockDictionary createMBDict(int nCol) { if(withEmpty) { final SparseBlock sb = SparseBlockFactory.createIdentityMatrixWithEmptyRow(nRowCol); final MatrixBlock identity = new MatrixBlock(nRowCol + 1, nRowCol, nRowCol, sb); return new MatrixBlockDictionary(identity); } else { - final SparseBlock sb = SparseBlockFactory.createIdentityMatrix(nRowCol); final MatrixBlock identity = new MatrixBlock(nRowCol, nRowCol, nRowCol, sb); return new MatrixBlockDictionary(identity); } } - @Override - public IDictionary scaleTuples(int[] scaling, int nCol) { - return getMBDict().scaleTuples(scaling, nCol); - } - @Override public void write(DataOutput out) throws IOException { out.writeByte(DictionaryFactory.Type.IDENTITY.ordinal()); @@ -566,64 +412,6 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI return new MatrixBlockDictionary(db); } - @Override - public IDictionary replace(double pattern, double replace, int nCol) { - if(containsValue(pattern)) - return getMBDict().replace(pattern, replace, nCol); - else - return this; - } - - @Override - public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - if(containsValueWithReference(pattern, reference)) - return getMBDict().replaceWithReference(pattern, replace, reference); - else - return this; - } - - @Override - public void product(double[] ret, int[] counts, int nCol) { - getMBDict().product(ret, counts, nCol); - } - - @Override - public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { - getMBDict().productWithDefault(ret, counts, def, defCount); - } - - @Override - public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { - getMBDict().productWithReference(ret, counts, reference, refCount); - } - - @Override - public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { - return getMBDict().centralMoment(ret, fn, counts, nRows); - } - - @Override - public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, - int nRows) { - return getMBDict().centralMomentWithDefault(ret, fn, counts, def, nRows); - } - - @Override - public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, - int nRows) { - return getMBDict().centralMomentWithReference(ret, fn, counts, reference, nRows); - } - - @Override - public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { - return getMBDict().rexpandCols(max, ignore, cast, nCol); - } - - @Override - public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { - return getMBDict().rexpandColsWithReference(max, ignore, cast, reference); - } - @Override public double getSparsity() { if(withEmpty) @@ -638,25 +426,10 @@ public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColInd ret[off + cols.get(dictIdx)] += v; } - @Override - public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { - getMBDict().TSMMWithScaling(counts, rows, cols, ret); - } - - @Override - public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - getMBDict().MMDict(right, rowsLeft, colsRight, result); - } - - public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - getMBDict().MMDictScaling(right, rowsLeft, colsRight, result, scaling); - } - @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { // similar to fused transpose left into right locations. - + final int leftSide = rowsLeft.size(); final int colsOut = result.getNumColumns(); final int commonDim = Math.min(left.length / leftSide, nRowCol); @@ -673,7 +446,6 @@ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, @Override public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) { - // getMBDict().MMDictScalingDense(left, rowsLeft, colsRight, result, scaling); final int leftSide = rowsLeft.size(); final int resCols = result.getNumColumns(); final double[] resV = result.getDenseBlockValues(); @@ -685,52 +457,6 @@ public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex cols } } - @Override - public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - getMBDict().MMDictSparse(left, rowsLeft, colsRight, result); - } - - @Override - public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - getMBDict().MMDictScalingSparse(left, rowsLeft, colsRight, result, scaling); - } - - @Override - public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight, result); - } - - @Override - public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - getMBDict().TSMMToUpperTriangleDense(left, rowsLeft, colsRight, result); - } - - @Override - public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result) { - getMBDict().TSMMToUpperTriangleSparse(left, rowsLeft, colsRight, result); - } - - @Override - public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - getMBDict().TSMMToUpperTriangleScaling(right, rowsLeft, colsRight, scale, result); - } - - @Override - public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - getMBDict().TSMMToUpperTriangleDenseScaling(left, rowsLeft, colsRight, scale, result); - } - - @Override - public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - - getMBDict().TSMMToUpperTriangleSparseScaling(left, rowsLeft, colsRight, scale, result); - } - @Override public boolean equals(IDictionary o) { if(o instanceof IdentityDictionary && // @@ -740,16 +466,6 @@ public boolean equals(IDictionary o) { return getMBDict().equals(o); } - @Override - public IDictionary cbind(IDictionary that, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary reorder(int[] reorder) { - return getMBDict().reorder(reorder); - } - @Override protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, int nColRight) { @@ -812,11 +528,6 @@ protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b return MatrixBlockDictionary.create(retB, false); } - @Override - public IDictionary append(double[] row) { - return getMBDict().append(row); - } - @Override public String getString(int colIndexes) { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index b397f655f4d..0f07e1eac74 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -22,20 +22,16 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.lang.ref.SoftReference; import java.util.Arrays; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; -import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -public class IdentityDictionarySlice extends IdentityDictionary { +public class IdentityDictionarySlice extends AIdentityDictionary { - private static final long serialVersionUID = 2535887782150955098L; + private static final long serialVersionUID = 2535887782153555098L; /** Lower index for the slice */ private final int l; @@ -53,12 +49,35 @@ public class IdentityDictionarySlice extends IdentityDictionary { */ public IdentityDictionarySlice(int nRowCol, boolean withEmpty, int l, int u) { super(nRowCol, withEmpty); - if(u > nRowCol || l < 0 || l >= u) - throw new DMLRuntimeException("Invalid slice Identity: " + nRowCol + " range: " + l + "--" + u); this.l = l; this.u = u; } + /** + * Create a Identity matrix dictionary slice (if other groups are not more applicable). It behaves as if allocated a + * Sparse Matrix block but exploits that the structure is known to have certain properties. + * + * @param nRowCol the number of rows and columns in this identity matrix. + * @param withEmpty If the matrix should contain an empty row in the end. + * @param l the index lower to start at + * @param u the index upper to end at (not inclusive) + * @return a Dictionary instance. + */ + public static IDictionary create(int nRowCol, boolean withEmpty, int l, int u) { + if(u > nRowCol || l < 0 || l >= u) + throw new DMLRuntimeException("Invalid slice Identity: " + nRowCol + " range: " + l + "--" + u); + if(nRowCol == 1) { + if(withEmpty) + return new Dictionary(new double[] {1, 0}); + else + return new Dictionary(new double[] {1}); + } + else if(l == 0 && u == nRowCol) + return IdentityDictionary.create(nRowCol, withEmpty); + else + return new IdentityDictionarySlice(nRowCol, withEmpty, l, u); + } + @Override public double[] getValues() { LOG.warn("Should not call getValues on Identity Dictionary"); @@ -72,14 +91,20 @@ public double[] getValues() { @Override public double getValue(int i) { - throw new NotImplementedException(); + final int nCol = u - l; + final int vRow = i / nCol; + if(vRow < l || vRow >= u) + return 0; + final int oRow = vRow - l; + final int col = i % nCol; + return oRow == col ? 1 : 0; } @Override public final double getValue(int r, int c, int nCol) { if(r < l || r > u) return 0; - return super.getValue(r - l, c, nCol); + return (r - l) == c ? 1 : 0; } @Override @@ -88,15 +113,21 @@ public long getInMemorySize() { } public static long getInMemorySize(int numberColumns) { - // int * 3 + padding + softReference - return 12 + 4 + 8; + // 2 more ints, no padding. + return AIdentityDictionary.getInMemorySize(numberColumns) + 8; } @Override public double[] aggregateRows(Builtin fn, int nCol) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, l, u, fn.execute(1, 0)); - return ret; + double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)]; + if(l + 1 == u) { + ret[l] = 1; + return ret; + } + else { + Arrays.fill(ret, l, u, fn.execute(1, 0)); + return ret; + } } @Override @@ -120,60 +151,73 @@ public DictType getDictType() { @Override public double[] sumAllRowsToDouble(int nrColumns) { - double[] ret = new double[nRowCol]; + double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)]; Arrays.fill(ret, l, u, 1.0); return ret; } @Override public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, l, u, 1.0); + double[] ret = new double[getNumberOfValues(defaultTuple.length) + 1]; + for(int i = l; i < u; i++) + ret[i] = 1; for(int i = 0; i < defaultTuple.length; i++) - ret[i] += defaultTuple[i]; + ret[ret.length - 1] += defaultTuple[i]; return ret; } @Override public double[] sumAllRowsToDoubleWithReference(double[] reference) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, l, u, 1.0); + final double[] ret = new double[getNumberOfValues(reference.length)]; + double refSum = 0; for(int i = 0; i < reference.length; i++) - ret[i] += reference[i] * nRowCol; + refSum += reference[i]; + for(int i = 0; i < l; i++) + ret[i] = refSum; + for(int i = l; i < u; i++) + ret[i] = 1 + refSum; + for(int i = u; i < ret.length; i++) + ret[i] = refSum; return ret; } @Override public double[] sumAllRowsToDoubleSq(int nrColumns) { - double[] ret = new double[nRowCol]; + double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)]; Arrays.fill(ret, l, u, 1); return ret; } @Override public double[] productAllRowsToDouble(int nCol) { - return new double[nRowCol]; + double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)]; + if(u - l - 1 == 0) + ret[l] = 1; + return ret; } @Override public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - return new double[nRowCol]; + int nVal = nRowCol + (withEmpty ? 1 : 0); + double[] ret = new double[nVal + 1]; + if(u - l - 1 == 0) + ret[l] = 1; + ret[nVal] = defaultTuple[0]; + for(int i = 1; i < defaultTuple.length; i++) + ret[nVal] *= defaultTuple[i]; + return ret; } @Override public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - for(int i = 0; i < colIndexes.size(); i++) { - // very nice... - final int idx = colIndexes.get(i); - c[idx] = counts[i]; - } + for(int i = l; i < u; i++) + c[colIndexes.get(i - l)] = counts[i]; } @Override public double sum(int[] counts, int ncol) { - int end = withEmpty && u == ncol ? u - 1 : u; double s = 0.0; - for(int i = l; i < end; i++) + for(int i = l; i < u; i++) s += counts[i]; return s; } @@ -183,16 +227,6 @@ public double sumSq(int[] counts, int ncol) { return sum(counts, ncol); } - @Override - public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { - return getMBDict().sliceOutColumnRange(idxStart, idxEnd, previousNumberOfColumns); - } - - @Override - public boolean containsValue(double pattern) { - return pattern == 0.0 || pattern == 1.0; - } - @Override public long getNumberNonZeros(int[] counts, int nCol) { return (long) sum(counts, nCol); @@ -203,40 +237,6 @@ public int getNumberOfValues(int ncol) { return nRowCol + (withEmpty ? 1 : 0); } - @Override - public MatrixBlockDictionary getMBDict(int nCol) { - if(cache != null) { - MatrixBlockDictionary r = cache.get(); - if(r != null) - return r; - } - MatrixBlockDictionary ret = createMBDict(); - cache = new SoftReference<>(ret); - return ret; - } - - private MatrixBlockDictionary createMBDict() { - MatrixBlock identity = new MatrixBlock(nRowCol, u - l, true); - for(int i = l; i < u; i++) - identity.set(i, i - l, 1.0); - return new MatrixBlockDictionary(identity); - } - - @Override - public String getString(int colIndexes) { - return "IdentityMatrix of size: " + nRowCol; - } - - @Override - public String toString() { - return "IdentityMatrix of size: " + nRowCol; - } - - @Override - public IDictionary scaleTuples(int[] scaling, int nCol) { - return getMBDict().scaleTuples(scaling, nCol); - } - @Override public void write(DataOutput out) throws IOException { out.writeByte(DictionaryFactory.Type.IDENTITY_SLICE.ordinal()); @@ -261,47 +261,14 @@ public long getExactSizeOnDisk() { return 1 + 4 * 3 + 1; } - @Override - public IDictionary replace(double pattern, double replace, int nCol) { - if(containsValue(pattern)) - return getMBDict().replace(pattern, replace, nCol); - else - return this; - } - - @Override - public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - if(containsValueWithReference(pattern, reference)) - return getMBDict().replaceWithReference(pattern, replace, reference); - else - return this; - } - @Override public double getSparsity() { - return (double) (u - l) / ((u -l) * (nRowCol + (withEmpty ? 1 : 0))); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { - return getMBDict().binOpRight(op, v); - } - - @Override - public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colIndexes, - final IColIndex aggregateColumns, final double[] b, final int cut) { - return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut); - } - - @Override - public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, - int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); + return (double) (u - l) / ((u - l) * (nRowCol + (withEmpty ? 1 : 0))); } @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { - if(fr >= l && fr < u) + if(fr >= l && fr < u) v[to * nCol + fr - l] += rep; } @@ -309,7 +276,7 @@ public void addToEntry(final double[] v, final int fr, final int to, final int n public boolean equals(IDictionary o) { if(o instanceof IdentityDictionarySlice) { IdentityDictionarySlice os = ((IdentityDictionarySlice) o); - return os.nRowCol == nRowCol && os.l == l && os.u == u; + return os.nRowCol == nRowCol && os.l == l && os.u == u && withEmpty == os.withEmpty; } else if(o instanceof IdentityDictionary) return false; @@ -319,22 +286,25 @@ else if(o instanceof IdentityDictionary) @Override public MatrixBlockDictionary getMBDict() { - final int nCol = u - l; - MatrixBlock mb = new MatrixBlock(nRowCol + (withEmpty ? 1 : 0), nCol, true); - mb.allocateSparseRowsBlock(); + return getMBDict(nRowCol); + } - SparseBlock sb = mb.getSparseBlock(); - for(int i = l; i < u; i++) { - sb.append(i, i - l, 1); - } + @Override + public MatrixBlockDictionary createMBDict(int nCol) { + MatrixBlock identity = new MatrixBlock(nRowCol + (withEmpty ? 1 : 0), u - l, true); + for(int i = l; i < u; i++) + identity.set(i, i - l, 1.0); + return new MatrixBlockDictionary(identity); + } - mb.setNonZeros(nCol); - return new MatrixBlockDictionary(mb); + @Override + public String getString(int colIndexes) { + return toString(); } @Override - public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - getMBDict().multiplyScalar(v, ret, off, dictIdx, cols); + public String toString() { + return "IdentityMatrixSlice of size: " + nRowCol + " l " + l + " u " + u; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 12a063ad2a8..57f3a80e03a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -25,7 +25,6 @@ import java.math.BigDecimal; import java.math.MathContext; import java.util.Arrays; -import java.util.HashSet; import java.util.Set; import org.apache.commons.lang3.NotImplementedException; @@ -36,11 +35,14 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; @@ -128,7 +130,7 @@ public double[] getValues() { if(_data.isInSparseFormat()) { LOG.warn("Inefficient call to getValues for a MatrixBlockDictionary because it was sparse"); throw new DMLCompressionException("Should not call this function"); - // _data.sparseToDense(); + } return _data.getDenseBlockValues(); } @@ -436,28 +438,45 @@ public IDictionary applyScalarOp(ScalarOperator op) { public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { // guaranteed to be densifying since this only is called if op(0) is != 0 // set the entire output to v0. - final MatrixBlock ret = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), v0); + final MatrixBlock ret = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), false); + ret.allocateDenseBlock(); final double[] retV = ret.getDenseBlockValues(); if(_data.isInSparseFormat()) { final int nRow = _data.getNumRows(); final SparseBlock sb = _data.getSparseBlock(); + final double v0r = op.executeScalar(0.0d); for(int i = 0; i < nRow; i++) { - if(sb.isEmpty(i)) - continue; - - final int apos = sb.pos(i); - final int alen = sb.size(i) + apos; - final int[] aix = sb.indexes(i); - final double[] avals = sb.values(i); - for(int k = apos; k < alen; k++) - retV[i * nCol + aix[k]] = op.executeScalar(avals[i]); + final int off = i * nCol; + if(sb.isEmpty(i)) { + for(int j = 0; j < nCol; j++) + retV[off + j] = v0r; + } + else { + final int apos = sb.pos(i); + final int alen = sb.size(i) + apos; + final int[] aix = sb.indexes(i); + final double[] avals = sb.values(i); + int k = apos; + int j = 0; + for(; j < nCol && k < alen; j++) { + double v = aix[k] == j ? avals[k++] : 0; + retV[off + j] = op.executeScalar(v); + } + for(; j < nCol; j++) { + retV[off + j] = v0r; + } + } } + for(int i = nCol * nRow; i < retV.length; i++) + retV[i] = v0; } else { final double[] v = _data.getDenseBlockValues(); for(int i = 0; i < v.length; i++) retV[i] = op.executeScalar(v[i]); + for(int i = v.length; i < retV.length; i++) + retV[i] = v0; } ret.recomputeNonZeros(); @@ -472,14 +491,12 @@ public IDictionary applyUnaryOp(UnaryOperator op) { @Override public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { - // guaranteed to be densifying since this only is called if op(0) is != 0 - // set the entire output to v0. - final MatrixBlock ret = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), v0); - final double[] retV = ret.getDenseBlockValues(); - - if(_data.isInSparseFormat()) { - final int nRow = _data.getNumRows(); + final int nRow = _data.getNumRows(); + final MatrixBlock ret = new MatrixBlock(nRow + 1, nCol, op.sparseSafe && _data.isInSparseFormat()); + if(op.sparseSafe && _data.isInSparseFormat()) { + ret.allocateSparseRowsBlock(); final SparseBlock sb = _data.getSparseBlock(); + final SparseBlock sbr = ret.getSparseBlock(); for(int i = 0; i < nRow; i++) { if(sb.isEmpty(i)) continue; @@ -489,16 +506,56 @@ public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) final int[] aix = sb.indexes(i); final double[] avals = sb.values(i); for(int k = apos; k < alen; k++) - retV[i * nCol + aix[k]] = op.fn.execute(avals[i]); + sbr.append(i, aix[k], op.fn.execute(avals[k])); } + + for(int i = 0; i < nCol; i++) + sbr.append(nRow, i, v0); + } + else if(_data.isInSparseFormat()) { + ret.allocateDenseBlock(); + final double[] retV = ret.getDenseBlockValues(); + final SparseBlock sb = _data.getSparseBlock(); + double v0r = op.fn.execute(0); + for(int i = 0; i < nRow; i++) { + final int off = i * nCol; + if(sb.isEmpty(i)) { + for(int j = 0; j < nCol; j++) + retV[off + j] = v0r; + } + else { + + final int apos = sb.pos(i); + final int alen = sb.size(i) + apos; + final int[] aix = sb.indexes(i); + final double[] avals = sb.values(i); + int k = apos; + int j = 0; + for(; j < nCol && k < alen; j++) { + double v = aix[k] == j ? avals[k++] : 0; + retV[off + j] = op.fn.execute(v); + } + for(; j < nCol; j++) { + retV[off + j] = v0r; + } + } + + } + for(int i = nRow * nCol; i < retV.length; i++) + retV[i] = v0; } else { + ret.allocateDenseBlock(); + final double[] retV = ret.getDenseBlockValues(); final double[] v = _data.getDenseBlockValues(); for(int i = 0; i < v.length; i++) retV[i] = op.fn.execute(v[i]); + for(int i = nRow * nCol; i < retV.length; i++) + retV[i] = v0; } ret.recomputeNonZeros(); + ret.examSparsity(); return MatrixBlockDictionary.create(ret); } @@ -655,7 +712,7 @@ public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex c for(int i = 0; i < nRow; i++) { if(sb.isEmpty(i)) for(int j = 0; j < nCol; j++) - retV[off++] = op.fn.execute(v[j], 0); + retV[off++] = op.fn.execute(v[colIndexes.get(j)], 0); else { final int apos = sb.pos(i); final int alen = sb.size(i) + apos; @@ -664,10 +721,10 @@ public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex c int j = 0; for(int k = apos; j < nCol && k < alen; j++) { final double vx = aix[k] == j ? avals[k++] : 0; - retV[off++] = op.fn.execute(v[j], vx); + retV[off++] = op.fn.execute(v[colIndexes.get(j)], vx); } for(; j < nCol; j++) - retV[off++] = op.fn.execute(v[j], 0); + retV[off++] = op.fn.execute(v[colIndexes.get(j)], 0); } } } @@ -675,13 +732,13 @@ public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex c final double[] values = _data.getDenseBlockValues(); for(int i = 0; i < nRow; i++) { for(int j = 0; j < nCol; j++) { - retV[off] = op.fn.execute(v[j], values[off]); + retV[off] = op.fn.execute(v[colIndexes.get(j)], values[off]); off++; } } } for(int j = 0; j < nCol; j++) { - retV[off] = op.fn.execute(v[j], 0); + retV[off] = op.fn.execute(v[colIndexes.get(j)], 0); off++; } @@ -769,7 +826,7 @@ public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex for(int i = 0; i < nRow; i++) { if(sb.isEmpty(i)) for(int j = 0; j < nCol; j++) - retV[off++] = op.fn.execute(0, v[j]); + retV[off++] = op.fn.execute(0, v[colIndexes.get(j)]); else { final int apos = sb.pos(i); final int alen = sb.size(i) + apos; @@ -778,10 +835,10 @@ public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex int j = 0; for(int k = apos; j < nCol && k < alen; j++) { final double vx = aix[k] == j ? avals[k++] : 0; - retV[off++] = op.fn.execute(vx, v[j]); + retV[off++] = op.fn.execute(vx, v[colIndexes.get(j)]); } for(; j < nCol; j++) - retV[off++] = op.fn.execute(0, v[j]); + retV[off++] = op.fn.execute(0, v[colIndexes.get(j)]); } } } @@ -789,13 +846,13 @@ public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex final double[] values = _data.getDenseBlockValues(); for(int i = 0; i < nRow; i++) { for(int j = 0; j < nCol; j++) { - retV[off] = op.fn.execute(values[off], v[j]); + retV[off] = op.fn.execute(values[off], v[colIndexes.get(j)]); off++; } } } for(int j = 0; j < nCol; j++) { - retV[off] = op.fn.execute(0, v[j]); + retV[off] = op.fn.execute(0, v[colIndexes.get(j)]); off++; } @@ -1088,17 +1145,31 @@ public double[] productAllRowsToDouble(int nCol) { } private final void productAllRowsToDouble(double[] ret, int nCol) { + final int nRow = _data.getNumRows(); + if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { - if(!sb.isEmpty(i) && sb.size(i) == nCol) { + for(int i = 0; i < nRow; i++) { + if(!sb.isEmpty(i)) { // if not equal to nCol ... skip final int apos = sb.pos(i); final int alen = sb.size(i) + apos; + final int[] aix = sb.indexes(i); final double[] avals = sb.values(i); ret[i] = 1; - for(int j = apos; j < alen; j++) { + int pj = 0; + // many extra cases to handle NaN... + for(int j = apos; j < alen && !Double.isNaN(ret[i]); j++) { + if(aix[j] - pj >= 1) { + ret[i] = 0; + break; + } ret[i] *= avals[j]; + pj = aix[j]; + } + + if(!Double.isNaN(ret[i]) && sb.size(i) != nCol) { + ret[i] = 0; } } else @@ -1107,10 +1178,10 @@ private final void productAllRowsToDouble(double[] ret, int nCol) { } else { double[] values = _data.getDenseBlockValues(); - int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < nRow; k++) { + int off = k * nCol; ret[k] = 1; - for(int j = 0; j < _data.getNumColumns(); j++) { + for(int j = 0; j < nCol && ret[k] != 0; j++) { // early abort on zero final double v = values[off++]; ret[k] *= v; } @@ -1120,11 +1191,12 @@ private final void productAllRowsToDouble(double[] ret, int nCol) { @Override public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - double[] ret = new double[_data.getNumRows() + 1]; + final int nRow = _data.getNumRows(); + double[] ret = new double[nRow + 1]; productAllRowsToDouble(ret, defaultTuple.length); - ret[_data.getNumRows()] = defaultTuple[0]; + ret[nRow] = defaultTuple[0]; for(int j = 1; j < defaultTuple.length; j++) - ret[_data.getNumRows()] *= defaultTuple[j]; + ret[nRow] *= defaultTuple[j]; return ret; } @@ -1234,7 +1306,7 @@ public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { final int[] aix = sb.indexes(i); final double[] avals = sb.values(i); for(int j = apos; j < alen; j++) { - c[colIndexes.get(aix[j])] += count * avals[j] * avals[j]; + c[colIndexes.get(aix[j])] += avals[j] * avals[j] * count; } } } @@ -1495,6 +1567,8 @@ public boolean containsValue(double pattern) { @Override public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)) + return super.containsValueWithReference(pattern, reference); if(_data.isInSparseFormat()) { final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < _data.getNumRows(); i++) { @@ -1580,7 +1654,7 @@ public int[] countNNZZeroColumns(int[] counts) { final int aix[] = sb.indexes(i); for(int j = apos; j < alen; j++) { - ret[aix[i]] += counts[i]; + ret[aix[j]] += counts[i]; } } } @@ -1790,12 +1864,12 @@ public IDictionary scaleTuples(int[] scaling, int nCol) { if(!sbThis.isEmpty(i)) { sbRet.set(i, sbThis.get(i), true); - final int count = scaling[i]; + final int sc = scaling[i]; final int apos = sbRet.pos(i); final int alen = sbRet.size(i) + apos; final double[] avals = sbRet.values(i); for(int j = apos; j < alen; j++) - avals[j] = count * avals[j]; + avals[j] = sc * avals[j]; } } retBlock.setNonZeros(_data.getNonZeros()); @@ -2059,9 +2133,8 @@ public IDictionary replace(double pattern, double replace, int nCol) { @Override public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - if(Util.eq(pattern, Double.NaN)) { + if(Util.eq(pattern, Double.NaN)) return replaceWithReferenceNan(replace, reference); - } final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); @@ -2108,27 +2181,19 @@ public IDictionary replaceWithReference(double pattern, double replace, double[] } private IDictionary replaceWithReferenceNan(double replace, double[] reference) { - + final Set colsWithNan = Dictionary.getColsWithNan(replace, reference); final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); + if(colsWithNan != null && colsWithNan.size() == nCol && replace == 0) + return null; + final MatrixBlock ret = new MatrixBlock(nRow, nCol, false); ret.allocateDenseBlock(); - - Set colsWithNan = null; - for(int i = 0; i < reference.length; i++) { - if(Util.eq(reference[i], Double.NaN)) { - if(colsWithNan == null) - colsWithNan = new HashSet<>(); - colsWithNan.add(i); - reference[i] = replace; - } - } + final double[] retV = ret.getDenseBlockValues(); if(colsWithNan == null) { - - final double[] retV = ret.getDenseBlockValues(); - int off = 0; if(_data.isInSparseFormat()) { + final DenseBlock db = ret.getDenseBlock(); final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < nRow; i++) { if(sb.isEmpty(i)) @@ -2137,30 +2202,22 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference) final int apos = sb.pos(i); final int alen = sb.size(i) + apos; final double[] avals = sb.values(i); + final int[] aix = sb.indexes(i); int j = 0; + int off = db.pos(i); for(int k = apos; k < alen; k++) { final double v = avals[k]; - retV[off++] = Util.eq(Double.NaN, v) ? -reference[j] : v; + retV[off + aix[k]] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; } } } else { final double[] values = _data.getDenseBlockValues(); - for(int i = 0; i < nRow; i++) { - for(int j = 0; j < nCol; j++) { - final double v = values[off]; - retV[off++] = Util.eq(Double.NaN, v) ? -reference[j] : v; - } - } + Dictionary.replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, values); } - ret.recomputeNonZeros(); - ret.examSparsity(); - return MatrixBlockDictionary.create(ret); } else { - - final double[] retV = ret.getDenseBlockValues(); if(_data.isInSparseFormat()) { final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < nRow; i++) { @@ -2170,10 +2227,10 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference) final int apos = sb.pos(i); final int alen = sb.size(i) + apos; final double[] avals = sb.values(i); - final int[] aidx = sb.indexes(i); + final int[] aix = sb.indexes(i); for(int k = apos; k < alen; k++) { - final int c = aidx[k]; - final int outIdx = off + aidx[k]; + final int c = aix[k]; + final int outIdx = off + aix[k]; final double v = avals[k]; if(colsWithNan.contains(c)) retV[outIdx] = 0; @@ -2185,27 +2242,16 @@ else if(Util.eq(v, Double.NaN)) } } else { - int off = 0; final double[] values = _data.getDenseBlockValues(); - for(int i = 0; i < nRow; i++) { - for(int j = 0; j < nCol; j++) { - final double v = values[off]; - if(colsWithNan.contains(j)) - retV[off++] = 0; - else if(Util.eq(v, Double.NaN)) - retV[off++] = replace - reference[j]; - else - retV[off++] = v; - } - } + Dictionary.replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, values, + retV); } - - ret.recomputeNonZeros(); - ret.examSparsity(); - return MatrixBlockDictionary.create(ret); } + ret.recomputeNonZeros(); + ret.examSparsity(); + return MatrixBlockDictionary.create(ret); } @Override @@ -2266,9 +2312,18 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, final MathContext cont = MathContext.DECIMAL128; final int nCol = _data.getNumColumns(); final int nRow = _data.getNumRows(); + + final double[] values; // force dense ... if this ever is a bottleneck i will be surprised - _data.sparseToDense(); - final double[] values = _data.getDenseBlockValues(); + if(_data.isInSparseFormat()) { + MatrixBlock tmp = new MatrixBlock(); + tmp.copy(_data); + tmp.sparseToDense(); + values = tmp.getDenseBlockValues(); + } + else + values = _data.getDenseBlockValues(); + BigDecimal tmp = BigDecimal.ONE; int off = 0; for(int i = 0; i < nRow; i++) { @@ -2278,6 +2333,10 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, ret[0] = 0; return; } + else if(!Double.isFinite(v)) { + ret[0] = v; + return; + } tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont); } } @@ -2286,7 +2345,8 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, if(Math.abs(tmp.doubleValue()) == 0) ret[0] = 0; else if(!Double.isInfinite(ret[0])) - ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue(); + ret[0] = new BigDecimal(ret[0]).multiply(tmp, cont).doubleValue(); + } @Override @@ -2502,14 +2562,15 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef @Override public boolean equals(IDictionary o) { - if(o instanceof MatrixBlockDictionary) + if(o == null) + return false; + else if(o instanceof MatrixBlockDictionary) return _data.equals(((MatrixBlockDictionary) o)._data); - - else if(o instanceof IdentityDictionary) - return ((IdentityDictionary) o).equals(this); else if(o instanceof Dictionary) { double[] dVals = ((Dictionary) o)._values; if(_data.isEmpty()) { + if(_data.getNumRows() * _data.getNumColumns() != dVals.length) + return false; for(int i = 0; i < dVals.length; i++) { if(dVals[i] != 0) return false; @@ -2521,11 +2582,10 @@ else if(_data.isInSparseFormat()) final double[] dv = _data.getDenseBlockValues(); return Arrays.equals(dv, dVals); } - else if(o instanceof IdentityDictionary) { - return o.equals(this); - } - return false; + // fallback + return o.equals(this); + } @Override @@ -2551,30 +2611,25 @@ public IDictionary reorder(int[] reorder) { @Override public IDictionary append(double[] row) { - if(_data.isEmpty()) { - throw new NotImplementedException(); - } - else if(_data.isInSparseFormat()) { + if(_data.isInSparseFormat()) { final int nRow = _data.getNumRows(); - if(_data.getSparseBlock() instanceof SparseBlockMCSR) { - MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), true); - mb.allocateBlock(); - SparseBlock sb = mb.getSparseBlock(); - SparseBlockMCSR s = (SparseBlockMCSR) _data.getSparseBlock(); - - for(int i = 0; i < _data.getNumRows(); i++) - sb.set(i, s.get(i), false); - - for(int i = 0; i < row.length; i++) - sb.set(nRow, i, row[i]); - - mb.examSparsity(); - return new MatrixBlockDictionary(mb); - - } - else { - throw new NotImplementedException("Not implemented append for CSR"); - } + final int nCol = _data.getNumColumns(); + SparseRow sr = null; + for(int i = 0; i < row.length; i++) { + if(row[i] != 0) { + if(sr == null) + sr = new SparseRowScalar(i, row[i]); + else + sr = sr.append(i, row[i]); + } + } + MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), true); + mb.allocateBlock(); + SparseBlock sb = mb.getSparseBlock(); + mb.copy(0, nRow, 0, nCol, _data, false); + sb.set(nRow, sr, false); + mb.examSparsity(); + return new MatrixBlockDictionary(mb); } else { @@ -2588,4 +2643,85 @@ else if(_data.isInSparseFormat()) { return new MatrixBlockDictionary(mb); } } + + @Override + protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, + IColIndex aggregateColumns) { + + final int thisColsSize = thisCols.size(); + final int aggColSize = aggregateColumns.size(); + final double[] ret = new double[numVals * aggColSize]; + + for(int h = 0; h < thisColsSize; h++) { + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int j = 0; j < numVals; j++) { // rows left + final int offOut = j * aggColSize; + final double v = getValue(j, h, thisColsSize); + sparseAddSelected(sPos, sEnd, aggColSize, aggregateColumns, sIndexes, sValues, ret, offOut, v); + } + + } + return Dictionary.create(ret); + } + + private void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex aggregateColumns, int[] sIndexes, + double[] sValues, double[] ret, int offOut, double v) { + + int retIdx = 0; + for(int i = sPos; i < sEnd; i++) { + // skip through the retIdx. + while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i]) + retIdx++; + if(retIdx == aggColSize) + break; + ret[offOut + retIdx] += v * sValues[i]; + } + retIdx = 0; + } + + @Override + protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, + int nColRight) { + final int thisColsSize = thisCols.size(); + final double[] ret = new double[numVals * nColRight]; + + for(int h = 0; h < thisColsSize; h++) { // common dim + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int i = 0; i < numVals; i++) { // rows left + final int offOut = i * nColRight; + final double v = getValue(i, h, thisColsSize); + SparseAdd(sPos, sEnd, ret, offOut, sIndexes, sValues, v); + } + } + return Dictionary.create(ret); + } + + private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { + if(v != 0) { + for(int k = sPos; k < sEnd; k++) { // cols right with value + ret[offOut + sIdx[k]] += v * sVals[k]; + } + } + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index a8924060c3d..f5c140e5227 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -22,21 +22,10 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.io.Serializable; -import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; -import org.apache.sysds.runtime.data.DenseBlock; -import org.apache.sysds.runtime.data.SparseBlock; -import org.apache.sysds.runtime.functionobjects.Builtin; -import org.apache.sysds.runtime.functionobjects.ValueFunction; -import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.io.IOUtilFunctions; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.matrix.operators.UnaryOperator; -public class PlaceHolderDict implements IDictionary, Serializable { +public class PlaceHolderDict extends ADictionary { private static final long serialVersionUID = 9176356558592L; @@ -50,18 +39,8 @@ public PlaceHolderDict(int nVal) { } @Override - public double[] getValues() { - throw new RuntimeException(errMessage); - } - - @Override - public double getValue(int i) { - throw new RuntimeException(errMessage); - } - - @Override - public double getValue(int r, int col, int nCol) { - throw new RuntimeException(errMessage); + public long getExactSizeOnDisk() { + return 1 + 4; } @Override @@ -70,105 +49,12 @@ public long getInMemorySize() { } @Override - public double aggregate(double init, Builtin fn) { - throw new RuntimeException(errMessage); - } - - @Override - public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] aggregateRows(Builtin fn, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, - boolean def) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOp(ScalarOperator op) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOp(UnaryOperator op) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v) { - throw new RuntimeException(errMessage); + public int getNumberOfValues(int nCol) { + return nVal; } @Override - public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { + public MatrixBlockDictionary getMBDict() { throw new RuntimeException(errMessage); } @@ -185,375 +71,29 @@ public static PlaceHolderDict read(DataInput in) throws IOException { return new PlaceHolderDict(nVals); } - @Override - public long getExactSizeOnDisk() { - return 1 + 4; - } - - @Override - public DictType getDictType() { - throw new RuntimeException(errMessage); - } - - @Override - public int getNumberOfValues(int nCol) { - return nVal; - } - - @Override - public double[] sumAllRowsToDouble(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleWithReference(double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSq(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] productAllRowsToDouble(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] productAllRowsToDoubleWithReference(double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double sum(int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double sumSq(int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double sumSqWithReference(int[] counts, double[] reference) { - throw new RuntimeException(errMessage); - } - @Override public String getString(int colIndexes) { return ""; // get string empty } - @Override - public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public boolean containsValue(double pattern) { - throw new RuntimeException(errMessage); - } - - @Override - public boolean containsValueWithReference(double pattern, double[] reference) { - throw new RuntimeException(errMessage); - } - @Override public long getNumberNonZeros(int[] counts, int nCol) { return -1; } - @Override - public int[] countNNZZeroColumns(int[] counts) { - throw new RuntimeException(errMessage); - } - - @Override - public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol, int rep) { - throw new RuntimeException(errMessage); - } - - @Override - public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, - int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary subtractTuple(double[] tuple) { - throw new RuntimeException(errMessage); - } - - @Override - public MatrixBlockDictionary getMBDict(int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary scaleTuples(int[] scaling, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, - int cut) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary replace(double pattern, double replace, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void product(double[] ret, int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { - throw new RuntimeException(errMessage); - } - - @Override - public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { - throw new RuntimeException(errMessage); - } - - @Override - public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, - int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, - int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double getSparsity() { - throw new RuntimeException(errMessage); - } - - @Override - public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary cbind(IDictionary that, int nCol) { - throw new RuntimeException(errMessage); - } - @Override public boolean equals(IDictionary o) { return o instanceof PlaceHolderDict; } - @Override - public final boolean equals(double[] v) { - return false; - } - - @Override - public IDictionary reorder(int[] reorder) { - throw new RuntimeException(errMessage); - } - @Override public IDictionary clone() { return new PlaceHolderDict(nVal); } @Override - public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns, - int nColRight) { - throw new RuntimeException(errMessage); - } - - @Override - public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { - throw new RuntimeException(errMessage); - } - - @Override - public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary append(double[] row) { - throw new RuntimeException(errMessage); + public DictType getDictType() { + throw new RuntimeException("invalid to get dictionary type for PlaceHolderDict"); } - @Override - public double[] getRow(int i, int nCol) { - throw new RuntimeException(errMessage); - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 6d5f2aa6e04..35a08b8d14b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,16 +23,8 @@ import java.io.DataOutput; import java.io.IOException; -import org.apache.commons.lang3.NotImplementedException; -import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; -import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; -import org.apache.sysds.runtime.functionobjects.ValueFunction; -import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.utils.MemoryEstimates; /** @@ -40,23 +32,38 @@ * group. The primary reason for its introduction was to provide an entry point for specialization such as shared * dictionaries, which require additional information. */ -public class QDictionary extends ADictionary { +public class QDictionary extends ACachingMBDictionary { private static final long serialVersionUID = 2100501253343438897L; protected double _scale; protected byte[] _values; + protected int _nCol; - protected QDictionary(byte[] values, double scale) { + protected QDictionary(byte[] values, double scale, int nCol) { _values = values; _scale = scale; + _nCol = nCol; + } + + public static QDictionary create(byte[] values, double scale, int nCol, boolean check) { + if(scale == 0) + return null; + if(check) { + boolean containsOnlyZero = true; + for(int i = 0; i < values.length && containsOnlyZero; i++) { + if(values[i] != 0) + containsOnlyZero = false; + } + if(containsOnlyZero) + return null; + } + return new QDictionary(values, scale, nCol); } @Override public double[] getValues() { - if(_values == null) { - return new double[0]; - } + double[] res = new double[_values.length]; for(int i = 0; i < _values.length; i++) { res[i] = getValue(i); @@ -74,18 +81,6 @@ public final double getValue(int r, int c, int nCol) { return _values[r * nCol + c] * _scale; } - public byte getValueByte(int i) { - return _values[i]; - } - - public byte[] getValuesByte() { - return _values; - } - - public double getScale() { - return _scale; - } - @Override public long getInMemorySize() { // object + values array + double @@ -107,73 +102,13 @@ public double aggregate(double init, Builtin fn) { return ret; } - @Override - public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { - throw new NotImplementedException(); - } - - @Override - public double[] aggregateRows(Builtin fn, final int nCol) { - if(nCol == 1) - return getValues(); - final int nRows = _values.length / nCol; - double[] res = new double[nRows]; - for(int i = 0; i < nRows; i++) { - final int off = i * nCol; - res[i] = _values[off]; - for(int j = off + 1; j < off + nCol; j++) - res[i] = fn.execute(res[i], _values[j] * _scale); - } - return res; - } - - @Override - public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { - throw new NotImplementedException(); - } - - @Override - public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { - throw new NotImplementedException(); - } - - @Override - public QDictionary applyScalarOp(ScalarOperator op) { - throw new NotImplementedException(); - } - - @Override - public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary applyUnaryOp(UnaryOperator op) { - throw new NotImplementedException(); - } - - @Override - public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { - throw new NotImplementedException(); - } - - @Override - public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { - throw new NotImplementedException(); - } - private int size() { return _values.length; } @Override public QDictionary clone() { - return new QDictionary(_values.clone(), _scale); + return new QDictionary(_values.clone(), _scale, _nCol); } @Override @@ -183,6 +118,7 @@ public void write(DataOutput out) throws IOException { out.writeInt(_values.length); for(int i = 0; i < _values.length; i++) out.writeByte(_values[i]); + out.writeInt(_nCol); } public static QDictionary read(DataInput in) throws IOException { @@ -192,17 +128,18 @@ public static QDictionary read(DataInput in) throws IOException { for(int i = 0; i < numVals; i++) { values[i] = in.readByte(); } - return new QDictionary(values, scale); + int nCol = in.readInt(); + return new QDictionary(values, scale, nCol); } @Override public long getExactSizeOnDisk() { - return 1 + 8 + 4 + size(); + return 1 + 8 + 4 + size() + 4; } @Override public int getNumberOfValues(int nCol) { - return (_values == null) ? 0 : _values.length / nCol; + return _values.length / nCol; } @Override @@ -218,16 +155,6 @@ public double[] sumAllRowsToDouble(int nrColumns) { return ret; } - @Override - public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new NotImplementedException(); - } - - @Override - public double[] sumAllRowsToDoubleWithReference(double[] reference) { - throw new NotImplementedException(); - } - @Override public double[] sumAllRowsToDoubleSq(int nrColumns) { final int numVals = getNumberOfValues(nrColumns); @@ -237,36 +164,8 @@ public double[] sumAllRowsToDoubleSq(int nrColumns) { return ret; } - @Override - public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { - throw new NotImplementedException(); - } - - @Override - public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { - throw new NotImplementedException(); - } - - @Override - public double[] productAllRowsToDouble(int nCol) { - throw new NotImplementedException(); - } - - @Override - public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new NotImplementedException(); - } - - @Override - public double[] productAllRowsToDoubleWithReference(double[] reference) { - throw new NotImplementedException(); - } - private double sumRow(int k, int nrColumns) { - if(_values == null) - return 0; int valOff = k * nrColumns; - int res = 0; for(int i = 0; i < nrColumns; i++) { res += _values[valOff + i]; @@ -275,8 +174,6 @@ private double sumRow(int k, int nrColumns) { } private double sumRowSq(int k, int nrColumns) { - if(_values == null) - return 0; int valOff = k * nrColumns; double res = 0.0; for(int i = 0; i < nrColumns; i++) @@ -284,46 +181,6 @@ private double sumRowSq(int k, int nrColumns) { return res; } - @Override - public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { - throw new NotImplementedException(); - } - - @Override - public double sum(int[] counts, int ncol) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public double sumSq(int[] counts, int ncol) { - throw new NotImplementedException("Not Implemented"); - } - - @Override - public double sumSqWithReference(int[] counts, double[] reference) { - throw new NotImplementedException("Not Implemented"); - } - public String getString(int colIndexes) { StringBuilder sb = new StringBuilder(); for(int i = 0; i < size(); i++) { @@ -333,11 +190,6 @@ public String getString(int colIndexes) { return sb.toString(); } - public Dictionary makeDoubleDictionary() { - double[] doubleValues = getValues(); - return Dictionary.create(doubleValues); - } - public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { int numberTuples = getNumberOfValues(previousNumberOfColumns); int tupleLengthAfter = idxEnd - idxStart; @@ -350,19 +202,7 @@ public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNum } orgOffset += previousNumberOfColumns - idxEnd + idxStart; } - return new QDictionary(newDictValues, _scale); - } - - @Override - public boolean containsValue(double pattern) { - if(Double.isNaN(pattern) || Double.isInfinite(pattern)) - return false; - throw new NotImplementedException("Not contains value on Q Dictionary"); - } - - @Override - public boolean containsValueWithReference(double pattern, double[] reference) { - throw new NotImplementedException(); + return new QDictionary(newDictValues, _scale, _nCol); } @Override @@ -397,259 +237,39 @@ public int[] countNNZZeroColumns(int[] counts) { return ret; } - @Override - public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { - throw new NotImplementedException("not implemented yet"); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol) { - throw new NotImplementedException("Not implemented yet"); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol, int rep) { - throw new NotImplementedException("Not implemented yet"); - } - - @Override - public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, - int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - throw new NotImplementedException("Not implemented yet"); - } - @Override public DictType getDictType() { return DictType.UInt8; } - @Override - public IDictionary subtractTuple(double[] tuple) { - throw new NotImplementedException(); - } - - @Override - public MatrixBlockDictionary getMBDict(int nCol) { - throw new NotImplementedException(); - } - - @Override - public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { - throw new NotImplementedException(); - } - - @Override - public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, - boolean def) { - throw new NotImplementedException(); - } - - @Override - public IDictionary scaleTuples(int[] scaling, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, - int cut) { - throw new NotImplementedException(); - } - - @Override - public IDictionary replace(double pattern, double replace, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - throw new NotImplementedException(); - } - - @Override - public void product(double[] ret, int[] counts, int nCol) { - throw new NotImplementedException(); - } - - @Override - public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { - throw new NotImplementedException(); - } - - @Override - public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - throw new NotImplementedException(); - } - - @Override - public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - throw new NotImplementedException(); - } - - @Override - public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { - throw new NotImplementedException(); - } - - @Override - public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, - int nRows) { - throw new NotImplementedException(); - } - - @Override - public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, - int nRows) { - throw new NotImplementedException(); - } - - @Override - public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { - throw new NotImplementedException(); - // byte[] newDictValues = new byte[_values.length * max]; - // for(int i = 0, offset = 0; i < _values.length; i++, offset += max) { - // int val = _values[i] - 1; - // newDictValues[offset + val] = 1; - // } - - // return new QDictionary(newDictValues, 1.0); - } - - @Override - public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { - throw new NotImplementedException(); - } - @Override public double getSparsity() { - return 1; - } - - @Override - public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - throw new NotImplementedException(); - } - - @Override - public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { - throw new NotImplementedException(); - } - - @Override - public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new NotImplementedException(); - } - - @Override - public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new NotImplementedException(); + int nnz = 0; + for(int i = 0; i < _values.length; i++) { + nnz += _values[i] == 0 ? 0 : 1; + } + return (double) nnz / _values.length; } @Override public boolean equals(IDictionary o) { - throw new NotImplementedException(); - } - - @Override - public IDictionary cbind(IDictionary that, int nCol) { - throw new NotImplementedException(); - } - - @Override - public IDictionary reorder(int[] reorder) { - throw new NotImplementedException(); - } - - @Override - public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new NotImplementedException(); + return getMBDict().equals(o); } @Override - public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new NotImplementedException(); + public MatrixBlockDictionary getMBDict() { + return getMBDict(_nCol); } @Override - public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, - int[] scaling) { - throw new NotImplementedException(); + public MatrixBlockDictionary createMBDict(int nCol) { + MatrixBlock mb = new MatrixBlock(_values.length / nCol, nCol, false); + mb.allocateDenseBlock(); + double[] dbv = mb.getDenseBlockValues(); + for(int i = 0; i < _values.length; i++) + dbv[i] = _values[i] * _scale; + mb.recomputeNonZeros(); + return new MatrixBlockDictionary(mb); } - @Override - public IDictionary append(double[] row) { - throw new NotImplementedException(); - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index 5315d1a3b53..f16c88080a3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -220,6 +220,7 @@ private static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, doubl binaryMVRowMultiThread(oldColGroups, v, op, left, newColGroups, isRowSafe, k); ret.allocateColGroupList(newColGroups); + ret.setOverlapping(m1.isOverlapping()); ret.examSparsity(op.getNumThreads()); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 8cfe4639e45..99693635a9b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -551,16 +551,8 @@ private static void aggRow(AggregateUnaryOperator op, List groups, do } private static void fillStart(MatrixBlock in, MatrixBlock ret, AggregateUnaryOperator op) { - final ValueFunction fn = op.aggOp.increOp.fn; - if(fn instanceof Builtin) { - ret.getDenseBlock().set(op.aggOp.initialValue); - } - else if(fn instanceof Multiply && op.indexFn instanceof ReduceAll) { - long nnz = in.getNonZeros(); - long nc = (long) in.getNumRows() * in.getNumColumns(); - boolean containsZero = nnz != nc; - ret.getDenseBlock().set(0, 0, containsZero ? 0 : 1); - } + if(op.aggOp.initialValue != 0) + ret.reset(ret.getNumRows(), ret.getNumColumns(), op.aggOp.initialValue); } protected static MatrixBlock genTmpReduceAllOrRow(MatrixBlock ret, AggregateUnaryOperator op) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index e754ee6b1e3..3dc3c9a04f4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -295,7 +295,6 @@ private static void decompressSparseSingleThread(MatrixBlock ret, List filteredGroups, int rlen, int blklen, double[] constV, double eps, boolean overlapping) { - final DenseBlock db = ret.getDenseBlock(); final int nCol = ret.getNumColumns(); for(int i = 0; i < rlen; i += blklen) { @@ -303,7 +302,7 @@ private static void decompressDenseSingleThread(MatrixBlock ret, List final int ru = Math.min(i + blklen, rlen); for(AColGroup grp : filteredGroups) grp.decompressToDenseBlock(db, rl, ru); - if(constV != null && !ret.isInSparseFormat()) + if(constV != null) addVector(db, nCol, constV, eps, rl, ru); } } @@ -389,9 +388,9 @@ private static double getEps(double[] constV) { double max = -Double.MAX_VALUE; double min = Double.MAX_VALUE; for(double v : constV) { - if(v > max) + if(v > max && Double.isFinite(v)) max = v; - if(v < min) + if(v < min && Double.isFinite(v)) min = v; } final double eps = (max + 1e-4 - min) * 1e-10; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 92470886281..39594b03376 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -33,6 +33,7 @@ import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.APreAgg; +import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -776,7 +777,7 @@ private static void LMMWithPreAggDense(final List preAggCGs, final Matr // Multiply out the PreAggregate to the output matrix. for(int j = gl, p = 0; j < gu; j += skip, p++) { final APreAgg cg = preAggCGs.get(j); - if(cg.getDictionary() instanceof IdentityDictionary) + if(cg.getDictionary() instanceof AIdentityDictionary) continue; allocateOrResetTmpRes(ret, tmpRes, rowBlockSize); @@ -790,7 +791,7 @@ private static void LMMWithPreAggDense(final List preAggCGs, final Matr private static void preAllocate(List preAggCGs, int j, int rut, int rlt, double[][] preAgg, int p) { final APreAgg cg = preAggCGs.get(j); - if(cg.getDictionary() instanceof IdentityDictionary) + if(cg.getDictionary() instanceof AIdentityDictionary) return; final int preAggNCol = cg.getPreAggregateSize(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java index 10c4ee1ab36..34f22441112 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java @@ -31,7 +31,7 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; -import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -118,7 +118,7 @@ private static CompressedMatrixBlock createCompressedReturn(int[] map, int nColO boolean containsNull, int k) throws Exception { // create a single DDC Column group. final IColIndex i = ColIndexFactory.create(0, nColOut); - final ADictionary d = new IdentityDictionary(nColOut, containsNull); + final IDictionary d = IdentityDictionary.create(nColOut, containsNull); final AMapToData m = MapToFactory.create(seqHeight, map, nColOut + (containsNull ? 1 : 0), k); final AColGroup g = ColGroupDDC.create(i, d, m, null); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java index b397825d5af..fc910106145 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java @@ -148,7 +148,7 @@ protected static double[] filterGroupsAndSplitPreAggOneConst(List gro */ protected static boolean shouldPreFilterMorphOrRef(List groups) { for(AColGroup g : groups) - if(g instanceof AMorphingMMColGroup || g instanceof IFrameOfReferenceGroup) + if(g instanceof AMorphingMMColGroup || g instanceof IFrameOfReferenceGroup || g instanceof ColGroupConst) return true; return false; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index f88f737e6da..c3903cef006 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressedArray; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -248,7 +249,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { return ColGroupConst.create(colIndexes, new double[] {1}); } - ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); + IDictionary d = IdentityDictionary.create(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); @@ -377,7 +378,7 @@ private AColGroup binToDummy(ColumnEncoderComposite c) throws InterruptedExcepti b.build(in); // build first since we figure out if it contains null here. final boolean containsNull = b.containsNull; IColIndex colIndexes = ColIndexFactory.create(0, b._numBin); - ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); + IDictionary d = IdentityDictionary.create(colIndexes.size(), containsNull); final AMapToData m; m = binEncode(a, b, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); @@ -653,7 +654,7 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) { nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); } - ADictionary d = new IdentityDictionary(colIndexes.size(), nulls); + IDictionary d = IdentityDictionary.create(colIndexes.size(), nulls); AMapToData m = createHashMappingAMapToData(a, domain, nulls); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 26dff3a12bb..aa869a29e35 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -810,6 +810,10 @@ public static void compareMatrices(double[] expectedMatrix, double[] actualMatri new double[][]{actualMatrix}, 1, expectedMatrix.length, epsilon); } + public static void compareMatrices(double[] expectedMatrix, double[] actualMatrix, double epsilon, String message) { + compareMatrices(new double[][]{expectedMatrix}, + new double[][]{actualMatrix}, 1, expectedMatrix.length, epsilon, message); + } public static void compareMatrices(double[][] expectedMatrix, double[][] actualMatrix, int rows, int cols, double epsilon) { @@ -904,8 +908,9 @@ else if(ac.containsNull()) { } } - for(int i = 0; i < rows; i++) { - for(int j = 0; j < cols; j++) { + + for(int j = 0; j < cols; j++) { + for(int i = 0; i < rows; i++) { final Object a = expected.get(i, j); final Object b = actual.get(i, j); if(a == null) @@ -2132,6 +2137,17 @@ public static double[] generateTestVector(int cols, double min, double max, doub return vector; } + public static int[] generateTestIntVector(int cols, int min, int max, double sparsity, long seed) { + int[] vector = new int[cols]; + Random random = (seed == -1) ? TestUtils.random : new Random(seed); + for(int j = 0; j < cols; j++) { + if(random.nextDouble() > sparsity) + continue; + vector[j] = (random.nextInt(max - min) + min); + } + return vector; + } + /** * * Generates a test matrix with the specified parameters as a MatrixBlock. diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index 139350432eb..de968a483f4 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -1163,7 +1163,7 @@ else if(OverLapping.effectOnOutput(overlappingType)) TestUtils.compareMatricesBitAvgDistance(expected, result, (long) (27000 * toleranceMultiplier), (long) (1024 * toleranceMultiplier), bufferedToString); - if(result.getNonZeros() < expected.getNonZeros()) + if(result.getNonZeros() != -1 && expected.getNonZeros() != -1 && result.getNonZeros() < expected.getNonZeros()) fail("Nonzero is to low guarantee at least equal or higher " + result.getNonZeros() + " vs " + expected.getNonZeros()); diff --git a/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java b/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java index 3176586065d..ed27dfb68e1 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java @@ -200,6 +200,8 @@ public void testProd() { if(!(cmb instanceof CompressedMatrixBlock)) return; double ret1 = cmb.prod(); + LOG.error(ret1); + LOG.error(cmb); double ret2 = mb.prod(); boolean res; if(_cs != null && (_cs.lossy || overlappingType == OverLapping.SQUASH)) diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java index 6afb08335ae..e9f713610c1 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java @@ -730,7 +730,14 @@ public void UA_MIN_ROW() { @Test public void UA_PRODUCT_ROW() { - UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar*", 1)); + try { + UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar*", 1)); + } + catch(AssertionError e) { + LOG.error(base); + LOG.error(other); + throw e; + } } @Test @@ -1303,7 +1310,6 @@ public void denseSelection() { selection(mb, ret); } - @Test public void sparseSelectionEmptyRows() { MatrixBlock mb = CLALibSelectionMultTest.createSelectionMatrix(nRow, 50, true); @@ -1330,23 +1336,22 @@ public void selection(MatrixBlock selection, MatrixBlock ret) { MatrixBlock ret1 = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), ret.isInSparseFormat()); ret1.allocateBlock(); - MatrixBlock ret2 = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), ret.isInSparseFormat()); ret2.allocateBlock(); try { base.selectionMultiply(selection, points, ret1, 0, selection.getNumRows()); - other.selectionMultiply(selection, points, ret2, 0, selection.getNumRows()); + other.selectionMultiply(selection, points, ret2, 0, selection.getNumRows()); + + TestUtils.compareMatricesBitAvgDistance(ret1, ret2, 0, 0, + base.getClass().getSimpleName() + " vs " + other.getClass().getSimpleName()); - TestUtils.compareMatricesBitAvgDistance(ret1, ret2, 0, 0, base.getClass().getSimpleName() + " vs " + other.getClass().getSimpleName()); - - } catch(NotImplementedException e) { // okay } - catch(Exception e){ + catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java index bfacec50e16..edeadecbf10 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java @@ -19,9 +19,13 @@ package org.apache.sysds.test.component.compress.dictionary; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.util.Arrays; @@ -35,7 +39,11 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionarySlice; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; +import org.apache.sysds.runtime.compress.colgroup.dictionary.QDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.utils.DblArray; import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; @@ -472,4 +480,172 @@ public void createDoubleCountHashMap() { assertEquals(Dictionary.create(new double[] {// 1, 2, 4, 6,}), d); } + + public void IdentityDictionaryEquals() { + IDictionary a = IdentityDictionary.create(10); + IDictionary b = IdentityDictionary.create(10); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals() { + IDictionary a = IdentityDictionary.create(10); + IDictionary b = IdentityDictionary.create(11); + assertFalse(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals2() { + IDictionary a = IdentityDictionary.create(10); + IDictionary b = IdentityDictionary.create(11, false); + assertFalse(a.equals(b)); + } + + @Test + public void IdentityDictionaryEquals2() { + IDictionary a = IdentityDictionary.create(11, false); + IDictionary b = IdentityDictionary.create(11, false); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryEquals2v() { + IDictionary a = IdentityDictionary.create(11); + IDictionary b = IdentityDictionary.create(11, false); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals3() { + IDictionary a = IdentityDictionary.create(11, true); + IDictionary b = IdentityDictionary.create(11, false); + assertFalse(a.equals(b)); + } + + @Test(expected = Exception.class) + public void invalidIdentity() { + IdentityDictionary.create(-1); + } + + @Test(expected = Exception.class) + public void invalidIdentity2() { + IdentityDictionary.create(-1, true); + } + + @Test + public void withEmpty() { + assertTrue(((IdentityDictionary) IdentityDictionary.create(10, true)).withEmpty()); + assertFalse(((IdentityDictionary) IdentityDictionary.create(10, false)).withEmpty()); + } + + @Test + public void memorySizeIdentitySameAtDifferentSizes() { + assertTrue(IdentityDictionary.create(10, true).getInMemorySize()// + == IdentityDictionary.create(1000, true).getInMemorySize()); + } + + @Test + public void replaceNan() { + IDictionary a = Dictionary.create(new double[] {1, 2, Double.NaN, Double.NaN}); + IDictionary b = ((Dictionary) a).getMBDict(2); + a = a.replace(Double.NaN, -1, 2); + b = b.replace(Double.NaN, -1, 2); + DictionaryTests.compare(a, b, 2, 2); + } + + @Test + public void replaceNanWithReference() { + IDictionary a = Dictionary.create(new double[] {1, 2, Double.NaN, Double.NaN}); + IDictionary b = ((Dictionary) a).getMBDict(2); + double[] ref1 = new double[] {3, Double.NaN}; + a = a.replaceWithReference(Double.NaN, -1, ref1); + double[] ref2 = new double[] {3, Double.NaN}; + b = b.replaceWithReference(Double.NaN, -1, ref2); + DictionaryTests.compare(a, b, 2, 2); + assertArrayEquals(ref1, ref2, 0.01); + } + + @Test + public void replaceNanWithReference2() { + IDictionary a = Dictionary.create(new double[] {1, 2, Double.NaN, Double.NaN}); + IDictionary b = ((Dictionary) a).getMBDict(2); + double[] ref1 = new double[] {3, 52}; + a = a.replaceWithReference(Double.NaN, -1, ref1); + double[] ref2 = new double[] {3, 52}; + b = b.replaceWithReference(Double.NaN, -1, ref2); + DictionaryTests.compare(a, b, 2, 2); + double[] ref3 = new double[] {3, 52}; + IDictionary ab = a.replaceWithReference(Double.NaN, -1, ref3); + DictionaryTests.compare(a, ab, 2, 2); + assertArrayEquals(ref1, ref2, 0.01); + assertArrayEquals(ref1, ref3, 0.01); + } + + @Test + public void equalsNot() { + IDictionary a = Dictionary.create(new double[] {1}); + assertFalse(a.equals(new PlaceHolderDict(1))); + } + + @Test + public void equalsNotEmptyDict() { + IDictionary a = Dictionary.create(new double[] {1}); + IDictionary b = MatrixBlockDictionary.create(new MatrixBlock(1, 1, 0.0), false); + assertFalse(a.equals(b)); + } + + @Test + public void equalsNotEmptyDictDifferentSize() { + IDictionary a = Dictionary.createNoCheck(new double[] {0}); + IDictionary b = MatrixBlockDictionary.create(new MatrixBlock(100, 100, 0.0), false); + assertFalse(a.equals(b)); + } + + @Test + public void zeroScale() { + assertNull(QDictionary.create(null, 0, 10, true)); + assertNull(QDictionary.create(new byte[] {1, 2, 3}, 0, 10, true)); + assertNull(QDictionary.create(new byte[] {0, 0, 0}, 2.3, 1, true)); + assertNotNull(QDictionary.create(new byte[] {0, 0, 0}, 2.3, 1, false)); + } + + @Test + public void notEqualsSlice() { + + assertNotEquals(// + IdentityDictionarySlice.create(10, true, 1, 4), // + IdentityDictionarySlice.create(10, true, 2, 4)); + + assertNotEquals(// + IdentityDictionarySlice.create(10, true, 1, 4), // + IdentityDictionarySlice.create(10, true, 1, 5)); + + assertNotEquals(// + IdentityDictionarySlice.create(10, true, 1, 4), // + IdentityDictionarySlice.create(10, false, 1, 4)); + + assertNotEquals(// + IdentityDictionarySlice.create(10, true, 1, 4), // + IdentityDictionarySlice.create(9, true, 1, 4)); + + assertNotEquals(// + IdentityDictionarySlice.create(10, true, 1, 4), // + IdentityDictionarySlice.create(9, true, 0, 9)); + } + + @Test + public void createDictionary() { + assertTrue(IdentityDictionarySlice.create(1, true, 0, 1) instanceof Dictionary); + assertTrue(IdentityDictionarySlice.create(1, false, 0, 1) instanceof Dictionary); + assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(1, true, 1, 0)); + assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(1, true, 0, 0)); + assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(1, true, -13, 0)); + assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(10, true, 4, 11)); + } + + + @Test + public void notEqualsObject(){ + assertNotEquals(Dictionary.create(new double[]{1.1,2.2,3.3}), new Object()); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java index 71a04832ed8..8ef384e45ae 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java @@ -19,7 +19,9 @@ package org.apache.sysds.test.component.compress.dictionary; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -33,31 +35,42 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Random; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.QDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; +import org.apache.sysds.runtime.matrix.operators.ScalarOperator; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.test.TestUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import scala.util.Random; - @RunWith(value = Parameterized.class) public class DictionaryTests { @@ -67,12 +80,18 @@ public class DictionaryTests { private final int nCol; private final IDictionary a; private final IDictionary b; + private final double[] ref; public DictionaryTests(IDictionary a, IDictionary b, int nRow, int nCol) { this.nRow = nRow; this.nCol = nCol; this.a = a; this.b = b; + + ref = new double[nCol]; + for(int i = 0; i < ref.length; i++) { + ref[i] = 0.232415 * i; + } } @Parameters @@ -85,16 +104,25 @@ public static Collection data() { addAll(tests, new double[] {1, 2, 3, 4, 5}, 1); addAll(tests, new double[] {1, 2, 3, 4, 5, 6}, 2); addAll(tests, new double[] {1, 2.2, 3.3, 4.4, 5.5, 6.6}, 3); + addAll(tests, new double[] {0, 0, 1, 1, 0, 0}, 2); + + addQDict(tests, new byte[] {2, 4, 6, 8}, 2.0, 1); + addQDict(tests, new byte[] {44, 44, 110, 12, 32, 14, 25, 2}, 2.0, 2); + addQDict(tests, new byte[] {44, 44, 0, 12, 32, 0, 25, 2}, 2.0, 2); - tests.add(new Object[] {new IdentityDictionary(2), Dictionary.create(new double[] {1, 0, 0, 1}), 2, 2}); - tests.add(new Object[] {new IdentityDictionary(2, true), // + addSparse(tests, -10, 10, 10, 100, 0.1, 321); + addSparse(tests, -10, 10, 2, 100, 0.04, 321); + addSparseWithNan(tests, 1, 10, 100, 100, 0.1, 321); + + tests.add(new Object[] {IdentityDictionary.create(2), Dictionary.create(new double[] {1, 0, 0, 1}), 2, 2}); + tests.add(new Object[] {IdentityDictionary.create(2, true), // Dictionary.create(new double[] {1, 0, 0, 1, 0, 0}), 3, 2}); - tests.add(new Object[] {new IdentityDictionary(3), // + tests.add(new Object[] {IdentityDictionary.create(3), // Dictionary.create(new double[] {1, 0, 0, 0, 1, 0, 0, 0, 1}), 3, 3}); - tests.add(new Object[] {new IdentityDictionary(3, true), // + tests.add(new Object[] {IdentityDictionary.create(3, true), // Dictionary.create(new double[] {1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0}), 4, 3}); - tests.add(new Object[] {new IdentityDictionary(4), // + tests.add(new Object[] {IdentityDictionary.create(4), // Dictionary.create(new double[] {// 1, 0, 0, 0, // 0, 1, 0, 0, // @@ -102,14 +130,34 @@ public static Collection data() { 0, 0, 0, 1,// }), 4, 4}); - tests.add(new Object[] {new IdentityDictionary(4).sliceOutColumnRange(1, 4, 4), // + tests.add(new Object[] {IdentityDictionary.create(4)// + .sliceOutColumnRange(1, 4, 4), // Dictionary.create(new double[] {// 0, 0, 0, // 1, 0, 0, // 0, 1, 0, // 0, 0, 1,// }), 4, 3}); - tests.add(new Object[] {new IdentityDictionary(4, true), // + + tests.add(new Object[] {IdentityDictionary.create(4)// + .sliceOutColumnRange(1, 3, 4), // + Dictionary.create(new double[] {// + 0, 0, // + 1, 0, // + 0, 1, // + 0, 0,// + }), 4, 2}); + + tests.add(new Object[] {IdentityDictionary.create(4)// + .sliceOutColumnRange(1, 2, 4), // + Dictionary.create(new double[] {// + 0, // + 1, // + 0, // + 0, // + }), 4, 1}); + + tests.add(new Object[] {IdentityDictionary.create(4, true), // Dictionary.create(new double[] {// 1, 0, 0, 0, // 0, 1, 0, 0, // @@ -118,7 +166,18 @@ public static Collection data() { 0, 0, 0, 0}), 5, 4}); - tests.add(new Object[] {new IdentityDictionary(4, true).sliceOutColumnRange(1, 4, 4), // + tests.add(new Object[] {IdentityDictionary.create(4, true)// + .sliceOutColumnRange(0, 2, 4), + Dictionary.create(new double[] {// + 1, 0, // + 0, 1, // + 0, 0, // + 0, 0, // + 0, 0}), + 5, 2}); + + tests.add(new Object[] {IdentityDictionary.create(4, true)// + .sliceOutColumnRange(1, 4, 4), // Dictionary.create(new double[] {// 0, 0, 0, // 1, 0, 0, // @@ -127,6 +186,92 @@ public static Collection data() { 0, 0, 0}), 5, 3}); + tests.add(new Object[] {IdentityDictionary.create(4, true)// + .sliceOutColumnRange(1, 2, 4), // + Dictionary.create(new double[] {// + 0, // + 1, // + 0, // + 0, // + 0,}), + 5, 1}); + + tests.add(new Object[] {IdentityDictionary.create(4, true)// + .sliceOutColumnRange(1, 3, 4), // + Dictionary.create(new double[] {// + 0, 0, // + 1, 0, // + 0, 1, // + 0, 0, // + 0, 0,}), + 5, 2}); + tests.add(new Object[] {IdentityDictionary.create(4, true), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1, // + 0, 0, 0, 0}).getMBDict(4), + 5, 4}); + + tests.add(new Object[] {IdentityDictionary.create(20, false), // + MatrixBlockDictionary.create(// + new MatrixBlock(20, 20, 20L, // + SparseBlockFactory.createIdentityMatrix(20)), + false), + 20, 20}); + + tests.add(new Object[] {IdentityDictionary.create(20, false), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // + }), // + 20, 20}); + + tests.add(new Object[] {IdentityDictionary.create(20, true), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + }).getMBDict(20), // + 21, 20}); + create(tests, 30, 300, 0.2); } catch(Exception e) { @@ -164,6 +309,38 @@ private static void addAll(List tests, double[] vals, int cols) { vals.length / cols, cols}); } + private static void addQDict(List tests, byte[] is, double d, int i) { + ADictionary qd = QDictionary.create(is, d, i, true); + tests.add(new Object[] {qd, qd.getMBDict(i), is.length / i, i}); + } + + private static void addSparse(List tests, double min, double max, int rows, int cols, double sparsity, + int seed) { + + MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, min, max, sparsity, seed); + + MatrixBlock mb2 = new MatrixBlock(); + mb2.copy(mb); + mb2.sparseToDense(); + double[] dbv = mb2.getDenseBlockValues(); + + tests.add(new Object[] {MatrixBlockDictionary.create(mb), Dictionary.create(dbv), rows, cols}); + } + + private static void addSparseWithNan(List tests, double min, double max, int rows, int cols, + double sparsity, int seed) { + + MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, min, max, sparsity, seed); + + mb = TestUtils.floor(mb); + mb = mb.replaceOperations(null, min, Double.NaN); + MatrixBlock mb2 = new MatrixBlock(); + mb2.copy(mb); + mb2.sparseToDense(); + double[] dbv = mb2.getDenseBlockValues(); + tests.add(new Object[] {MatrixBlockDictionary.create(mb), Dictionary.create(dbv), rows, cols}); + } + @Test public void sum() { int[] counts = getCounts(nRow, 1324); @@ -172,6 +349,22 @@ public void sum() { assertEquals(as, bs, 0.0000001); } + @Test + public void sum2() { + int[] counts = getCounts(nRow, 124); + double as = a.sum(counts, nCol); + double bs = b.sum(counts, nCol); + assertEquals(as, bs, 0.0000001); + } + + @Test + public void sum3() { + int[] counts = getCounts(nRow, 124444); + double as = a.sum(counts, nCol); + double bs = b.sum(counts, nCol); + assertEquals(as, bs, 0.0000001); + } + @Test public void getValues() { try { @@ -250,19 +443,26 @@ public void productWithDoctoredReference2() { } public void productWithReference(double retV, double[] reference) { - // Shared - final int[] counts = getCounts(nRow, 1324); + try { - // A - final double[] aRet = new double[] {retV}; - a.productWithReference(aRet, counts, reference, nCol); + // Shared + final int[] counts = getCounts(nRow, 1324); - // B - final double[] bRet = new double[] {retV}; - b.productWithReference(bRet, counts, reference, nCol); + // A + final double[] aRet = new double[] {retV}; + a.productWithReference(aRet, counts, reference, nCol); - TestUtils.compareMatricesBitAvgDistance(// - aRet, bRet, 10, 10, "Not Equivalent values from product"); + // B + final double[] bRet = new double[] {retV}; + b.productWithReference(bRet, counts, reference, nCol); + + TestUtils.compareMatricesBitAvgDistance(// + aRet, bRet, 10, 10, "Not Equivalent values from product"); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -317,6 +517,119 @@ public void replaceWitReference() { assertNotEquals(before, aRep.getValue(r, c, nCol), 0.00001); } + @Test + public void replaceNaN() { + IDictionary ar = a.replace(Double.NaN, 0, nCol); + IDictionary br = b.replace(Double.NaN, 0, nCol); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRef() { + double[] ref1 = new double[nCol]; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRef12() { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRefNaN() { + double[] ref1 = new double[nCol]; + ref1[0] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + ref2[0] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRefNaN12() { + try { + + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + ref1[0] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + ref2[0] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaN12ColGT2() { + try { + if(nCol > 2) { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + ref1[1] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + ref2[1] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaNAllRefNaN() { + try { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, Double.NaN); + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, Double.NaN); + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaNAllRefNaNToZero() { + try { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, Double.NaN); + IDictionary ar = a.replaceWithReference(Double.NaN, 0, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, Double.NaN); + IDictionary br = b.replaceWithReference(Double.NaN, 0, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void rexpandCols() { if(nCol == 1) { @@ -381,14 +694,14 @@ public void rexpandColsWithReference(int reference) { @Test public void sumSq() { - try{ + try { int[] counts = getCounts(nRow, 2323); double as = a.sumSq(counts, nCol); double bs = b.sumSq(counts, nCol); assertEquals(as, bs, 0.0001); } - catch(Exception e){ + catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } @@ -478,11 +791,16 @@ public void equalsEl() { } } + @Test + public void equalsElOp() { + assertEquals(b, a); + } + @Test public void opRightMinus() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 132L); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test @@ -496,7 +814,7 @@ public void opRightMinusNoCol() { public void opRightMinusZero() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); double[] vals = new double[nCol]; - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test @@ -504,20 +822,52 @@ public void opRightDivOne() { BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); double[] vals = new double[nCol]; Arrays.fill(vals, 1); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test public void opRightDiv() { BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 232L); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } - private void opRight(BinaryOperator op, double[] vals, IColIndex cols) { - IDictionary aa = a.binOpRight(op, vals, cols); - IDictionary bb = b.binOpRight(op, vals, cols); - compare(aa, bb, nRow, nCol); + private void binOp(BinaryOperator op, double[] vals, IColIndex cols) { + try { + + IDictionary aa = a.binOpRight(op, vals, cols); + IDictionary bb = b.binOpRight(op, vals, cols); + compare(aa, bb, nRow, nCol); + + double[] ref = TestUtils.generateTestVector(nCol, 0, 10, 1.0, 33); + double[] newRef = TestUtils.generateTestVector(nCol, 0, 10, 1.0, 321); + aa = a.binOpRightWithReference(op, vals, cols, ref, newRef); + bb = b.binOpRightWithReference(op, vals, cols, ref, newRef); + compare(aa, bb, nRow, nCol); + + aa = a.binOpLeftWithReference(op, vals, cols, ref, newRef); + bb = b.binOpLeftWithReference(op, vals, cols, ref, newRef); + compare(aa, bb, nRow, nCol); + + aa = a.binOpLeft(op, vals, cols); + bb = b.binOpLeft(op, vals, cols); + compare(aa, bb, nRow, nCol); + + double[] app = TestUtils.generateTestVector(nCol, 0, 10, 1.0, 33); + + aa = a.binOpLeftAndAppend(op, app, cols); + bb = b.binOpLeftAndAppend(op, app, cols); + compare(aa, bb, nRow + 1, nCol); + + aa = a.binOpRightAndAppend(op, app, cols); + bb = b.binOpRightAndAppend(op, app, cols); + compare(aa, bb, nRow + 1, nCol); + + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } private void opRight(BinaryOperator op, double[] vals) { @@ -565,6 +915,44 @@ public void testAddToEntry4() { } } + @Test + public void testAddToEntryRep1() { + double[] ret1 = new double[nCol]; + a.addToEntry(ret1, 0, 0, nCol, 16); + double[] ret2 = new double[nCol]; + b.addToEntry(ret2, 0, 0, nCol, 16); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep2() { + double[] ret1 = new double[nCol * 2]; + a.addToEntry(ret1, 0, 1, nCol, 3214); + double[] ret2 = new double[nCol * 2]; + b.addToEntry(ret2, 0, 1, nCol, 3214); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep3() { + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 0, 2, nCol, 222); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 0, 2, nCol, 222); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep4() { + if(a.getNumberOfValues(nCol) > 2) { + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 2, 2, nCol, 321); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 2, 2, nCol, 321); + assertTrue(Arrays.equals(ret1, ret2)); + } + } + @Test public void testAddToEntryVectorized1() { try { @@ -580,6 +968,72 @@ public void testAddToEntryVectorized1() { } } + @Test + public void max() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.MAX)); + } + + @Test + public void min() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.MIN)); + } + + @Test(expected = NotImplementedException.class) + public void cMax() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)); + throw new NotImplementedException(); + } + + private void aggregate(Builtin fn) { + try { + double aa = a.aggregate(0, fn); + double bb = b.aggregate(0, fn); + assertEquals(aa, bb, 0.0); + } + catch(NotImplementedException ee) { + throw ee; + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void maxWithRef() { + aggregateWithRef(Builtin.getBuiltinFnObject(BuiltinCode.MAX)); + } + + @Test + public void minWithRef() { + aggregateWithRef(Builtin.getBuiltinFnObject(BuiltinCode.MIN)); + } + + @Test(expected = NotImplementedException.class) + public void cMaxWithRef() { + aggregateWithRef(Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)); + throw new NotImplementedException(); + } + + private void aggregateWithRef(Builtin fn) { + try { + + double aa = a.aggregateWithReference(0, fn, ref, true); + double bb = b.aggregateWithReference(0, fn, ref, true); + assertEquals(aa, bb, 0.0); + double aa2 = a.aggregateWithReference(0, fn, ref, false); + double bb2 = b.aggregateWithReference(0, fn, ref, false); + assertEquals(aa2, bb2, 0.0); + } + catch(NotImplementedException ee) { + throw ee; + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void testAddToEntryVectorized2() { try { @@ -643,7 +1097,21 @@ public void containsValueWithReference(double value, double[] reference) { b.containsValueWithReference(value, reference)); } - private static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { + private static void compare(IDictionary a, IDictionary b, int nCol) { + try { + + if(a == null && b == null) { + return; // all good. + } + assertEquals(a.getNumberOfValues(nCol), b.getNumberOfValues(nCol)); + compare(a, b, a.getNumberOfValues(nCol), nCol); + } + catch(NullPointerException e) { + fail("both outputs are not null: " + a + " vs " + b); + } + } + + protected static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { try { if(a == null && b == null) return; @@ -652,8 +1120,15 @@ else if(a == null || b == null) else { String errorM = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); for(int i = 0; i < nRow; i++) - for(int j = 0; j < nCol; j++) - assertEquals(errorM, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + for(int j = 0; j < nCol; j++) { + double aa = a.getValue(i, j, nCol); + double bb = b.getValue(i, j, nCol); + boolean eq = Math.abs(aa - bb) < 0.0001; + if(!eq) { + assertEquals(errorM + " cell:<" + i + "," + j + ">", a.getValue(i, j, nCol), + b.getValue(i, j, nCol), 0.0001); + } + } } } catch(Exception e) { @@ -682,6 +1157,388 @@ public void preaggValuesFromDense() { } } + @Test + public void rightMMPreAggSparse() { + final int nColsOut = 30; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, nColsOut, -10, 10, 0.1, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new RangeIndex(nColsOut); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + @Test + public void rightMMPreAggSparse2() { + final int nColsOut = 1000; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, nColsOut, -10, 10, 0.01, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new RangeIndex(nColsOut); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + @Test + public void rightMMPreAggSparseDifferentColumns() { + final int nColsOut = 3; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, 50, -10, 10, 0.1, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new ArrayIndex(new int[] {4, 10, 38}); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, 50); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, 50); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + @Test + public void MMDictScalingDense() { + double[] left = TestUtils.ceil(TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214)); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(0, nCol); + int[] scaling = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < a.getNumberOfValues(nCol); i++) + scaling[i] = i + 1; + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol, 0); + retA.allocateDenseBlock(); + a.MMDictScalingDense(left, rowsLeft, colsRight, retA, scaling); + + MatrixBlock retB = new MatrixBlock(5, nCol, 0); + retB.allocateDenseBlock(); + b.MMDictScalingDense(left, rowsLeft, colsRight, retB, scaling); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictScalingDenseOffset() { + double[] left = TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(3, nCol + 3); + int[] scaling = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < a.getNumberOfValues(nCol); i++) + scaling[i] = i; + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol + 3, 0); + retA.allocateDenseBlock(); + a.MMDictScalingDense(left, rowsLeft, colsRight, retA, scaling); + + MatrixBlock retB = new MatrixBlock(5, nCol + 3, 0); + retB.allocateDenseBlock(); + b.MMDictScalingDense(left, rowsLeft, colsRight, retB, scaling); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictDense() { + double[] left = TestUtils.ceil(TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214)); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(0, nCol); + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol, 0); + retA.allocateDenseBlock(); + a.MMDictDense(left, rowsLeft, colsRight, retA); + + MatrixBlock retB = new MatrixBlock(5, nCol, 0); + retB.allocateDenseBlock(); + b.MMDictDense(left, rowsLeft, colsRight, retB); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictDenseOffset() { + double[] left = TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(3, nCol + 3); + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol + 3, 0); + retA.allocateDenseBlock(); + a.MMDictDense(left, rowsLeft, colsRight, retA); + + MatrixBlock retB = new MatrixBlock(5, nCol + 3, 0); + retB.allocateDenseBlock(); + b.MMDictDense(left, rowsLeft, colsRight, retB); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void sumAllRowsToDouble() { + double[] aa = a.sumAllRowsToDouble(nCol); + double[] bb = b.sumAllRowsToDouble(nCol); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleWithDefault() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleWithDefault(def); + double[] bb = b.sumAllRowsToDoubleWithDefault(def); + + String errm = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); + TestUtils.compareMatrices(aa, bb, 0.001, errm); + } + + @Test + public void sumAllRowsToDoubleWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleWithReference(def); + double[] bb = b.sumAllRowsToDoubleWithReference(def); + TestUtils.compareMatrices(aa, bb, 0.001, "\n" + a + "\n" + b); + } + + @Test + public void sumAllRowsToDoubleSq() { + double[] aa = a.sumAllRowsToDoubleSq(nCol); + double[] bb = b.sumAllRowsToDoubleSq(nCol); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleSqWithDefault() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleSqWithDefault(def); + double[] bb = b.sumAllRowsToDoubleSqWithDefault(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleSqWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleSqWithReference(def); + double[] bb = b.sumAllRowsToDoubleSqWithReference(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllColsSqWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + final int[] counts = getCounts(nRow, 1324); + + double[] aa = new double[nCol]; + double[] bb = new double[nCol]; + + a.colSumSqWithReference(aa, counts, ColIndexFactory.create(nCol), def); + b.colSumSqWithReference(bb, counts, ColIndexFactory.create(nCol), def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void aggColsMin() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + Builtin m = Builtin.getBuiltinFnObject(BuiltinCode.MIN); + + double[] aa = new double[nCol + 3]; + a.aggregateCols(aa, m, cols); + double[] bb = new double[nCol + 3]; + b.aggregateCols(bb, m, cols); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void aggRows() { + Builtin m = Builtin.getBuiltinFnObject(BuiltinCode.MIN); + + double[] aa = a.aggregateRows(m, nCol); + double[] bb = b.aggregateRows(m, nCol); + + TestUtils.compareMatrices(aa, bb, 0.001); + + aa = a.aggregateRowsWithDefault(m, ref); + bb = b.aggregateRowsWithDefault(m, ref); + + TestUtils.compareMatrices(aa, bb, 0.001); + aa = a.aggregateRowsWithReference(m, ref); + bb = b.aggregateRowsWithReference(m, ref); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void getInMemorySize() { + a.getInMemorySize(); + b.getInMemorySize(); + } + + @Test + public void aggColsMax() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + Builtin m = Builtin.getBuiltinFnObject(BuiltinCode.MAX); + + double[] aa; + double[] bb; + + aa = new double[nCol + 3]; + bb = new double[nCol + 3]; + a.aggregateCols(aa, m, cols); + b.aggregateCols(bb, m, cols); + TestUtils.compareMatrices(aa, bb, 0.001); + + aa = new double[nCol + 3]; + bb = new double[nCol + 3]; + a.aggregateColsWithReference(aa, m, cols, ref, true); + b.aggregateColsWithReference(bb, m, cols, ref, true); + TestUtils.compareMatrices(aa, bb, 0.001); + + aa = new double[nCol + 3]; + bb = new double[nCol + 3]; + a.aggregateColsWithReference(aa, m, cols, ref, false); + b.aggregateColsWithReference(bb, m, cols, ref, false); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void getValue1() { + try { + int nCell = nCol * a.getNumberOfValues(nCol); + for(int i = 0; i < nCell; i++) + assertEquals(a.getValue(i), b.getValue(i), 0.0000); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void getValue2() { + try { + + String errm = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); + for(int i = 0; i < nRow; i++) { + for(int j = 0; j < nCol; j++) { + assertEquals(errm, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0000); + } + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void colSum() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + int[] counts = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < counts.length; i++) { + counts[i] = i + 1; + } + + double[] aa = new double[nCol + 3]; + a.colSum(aa, counts, cols); + double[] bb = new double[nCol + 3]; + b.colSum(bb, counts, cols); + + String errm = a.getClass().getSimpleName() + " vs " + b.getClass().getSimpleName(); + TestUtils.compareMatrices(aa, bb, 0.001, errm); + } + + @Test + public void colProduct() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + int[] counts = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < counts.length; i++) { + counts[i] = i + 1; + } + + double[] aa = new double[nCol + 3]; + a.colProduct(aa, counts, cols); + double[] bb = new double[nCol + 3]; + b.colProduct(bb, counts, cols); + TestUtils.compareMatrices(aa, bb, 0.001); + + double[] ref = TestUtils.generateTestVector(nCol, 0, 10, 1, 3215555); + aa = new double[nCol + 3]; + a.colProductWithReference(aa, counts, cols, ref); + bb = new double[nCol + 3]; + b.colProductWithReference(bb, counts, cols, ref); + TestUtils.compareMatrices(aa, bb, 0.001); + } + public void productWithDefault(double retV, double[] def) { // Shared final int[] counts = getCounts(nRow, 1324); @@ -696,6 +1553,7 @@ public void productWithDefault(double retV, double[] def) { TestUtils.compareMatricesBitAvgDistance(// aRet, bRet, 10, 10, "Not Equivalent values from product"); + } private static int[] getCounts(int nRows, int seed) { @@ -726,7 +1584,7 @@ public void testSerialization() { ByteArrayOutputStream bos = new ByteArrayOutputStream(); DataOutputStream fos = new DataOutputStream(bos); a.write(fos); - + assertEquals(a.getExactSizeOnDisk(), fos.size()); // Serialize in ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); DataInputStream fis = new DataInputStream(bis); @@ -752,6 +1610,7 @@ public void testSerializationB() { DataOutputStream fos = new DataOutputStream(bos); b.write(fos); + assertEquals(b.getExactSizeOnDisk(), fos.size()); // Serialize in ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); DataInputStream fis = new DataInputStream(bis); @@ -768,4 +1627,237 @@ public void testSerializationB() { throw e; } } + + @Test + public void replaceNan() { + compare(a.replace(Double.NaN, 0, nCol), b.replace(Double.NaN, 0, nCol), nRow, nCol); + } + + @Test + public void getNNzCounts() { + int counts[] = new int[nRow]; + Random r = new Random(321); + for(int i = 0; i < nRow; i++) { + counts[i] = r.nextInt(100); + } + long annz = a.getNumberNonZeros(counts, nCol); + long bnnz = b.getNumberNonZeros(counts, nCol); + + long annzR = a.getNumberNonZerosWithReference(counts, new double[nCol], nRow); + long bnnzR = a.getNumberNonZerosWithReference(counts, new double[nCol], nRow); + assertEquals(annz, bnnz); + assertEquals(annzR, bnnz); + assertEquals(annzR, bnnzR); + } + + @Test + public void getNNzCountsWithRef() { + int counts[] = getCounts(nRow, 231); + double[] ref = TestUtils.generateTestVector(nCol, -1, -1, 0.5, 23); + long annzR = a.getNumberNonZerosWithReference(counts, ref, nRow); + long bnnzR = a.getNumberNonZerosWithReference(counts, ref, nRow); + assertEquals(annzR, bnnzR); + } + + @Test + public void getNNzCountsColumns() { + try { + int counts[] = new int[nRow]; + Random r = new Random(3213); + for(int i = 0; i < nRow; i++) { + counts[i] = r.nextInt(100); + } + int[] annz = a.countNNZZeroColumns(counts); + int[] bnnz = b.countNNZZeroColumns(counts); + assertArrayEquals(annz, bnnz); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testGetString() { + // get strings only criteria is not to crash. + a.getString(nCol); + b.getString(nCol); + } + + @Test + public void testClone() { + IDictionary ca = a.clone(); + assertFalse(ca == a); + assertEquals(ca, a); + + IDictionary cb = b.clone(); + assertFalse(cb == b); + assertEquals(cb, b); + } + + @Test + public void round() { + unaryOp(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.ROUND))); + } + + @Test + public void mod() { + unaryOp(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.ABS))); + } + + @Test + public void floor() { + unaryOp(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.FLOOR))); + } + + @Test + public void sin() { + unaryOp(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.SIN))); + } + + @Test + public void cos() { + unaryOp(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.COS))); + } + + public void unaryOp(UnaryOperator op) { + IDictionary aa; + IDictionary bb; + + aa = a.applyUnaryOp(op); + bb = b.applyUnaryOp(op); + compare(aa, bb, nCol); + + aa = a.applyUnaryOpAndAppend(op, 32, nCol); + bb = b.applyUnaryOpAndAppend(op, 32, nCol); + compare(aa, bb, nCol); + + double[] ref1 = TestUtils.generateTestVector(nCol, 0, 10, 1, 333); + double[] ref2 = TestUtils.generateTestVector(nCol, 0, 10, 1, 32); + aa = a.applyUnaryOpWithReference(op, ref1, ref2); + bb = b.applyUnaryOpWithReference(op, ref1, ref2); + compare(aa, bb, nCol); + } + + @Test + public void plus() { + scalarOp(new RightScalarOperator(Plus.getPlusFnObject(), 1)); + } + + @Test + public void mult() { + scalarOp(new RightScalarOperator(Multiply.getMultiplyFnObject(), 1)); + } + + @Test + public void div() { + scalarOp(new RightScalarOperator(Divide.getDivideFnObject(), 1)); + } + + public void scalarOp(ScalarOperator op) { + IDictionary aa; + IDictionary bb; + + aa = a.applyScalarOp(op); + bb = b.applyScalarOp(op); + compare(aa, bb, nCol); + + aa = a.applyScalarOpAndAppend(op, 32, nCol); + bb = b.applyScalarOpAndAppend(op, 32, nCol); + compare(aa, bb, nCol); + + double[] ref1 = TestUtils.generateTestVector(nCol, 0, 10, 1, 3213); + double[] ref2 = TestUtils.generateTestVector(nCol, 0, 10, 1, 23232); + aa = a.applyScalarOpWithReference(op, ref1, ref2); + bb = b.applyScalarOpWithReference(op, ref1, ref2); + compare(aa, bb, nCol); + } + + @Test + public void scaleTuples() { + IDictionary aa; + IDictionary bb; + + int[] scale = TestUtils.generateTestIntVector(nRow, 1, 10, 1, 3213); + aa = a.scaleTuples(scale, nCol); + bb = b.scaleTuples(scale, nCol); + compare(aa, bb, nCol); + } + + // productAllRowsToDouble + @Test + public void productRows() { + double[] aa; + double[] bb; + String err = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); + // int[] scale = TestUtils.generateTestIntVector(nRow, 1, 10, 1, 3213); + aa = a.productAllRowsToDouble(nCol); + bb = b.productAllRowsToDouble(nCol); + assertArrayEquals(err, aa, bb, 0.0000001); + + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1, 3216245); + aa = a.productAllRowsToDoubleWithDefault(def); + bb = b.productAllRowsToDoubleWithDefault(def); + assertArrayEquals(err, aa, bb, 0.0000001); + + double[] ref = TestUtils.generateTestVector(nCol, 1, 10, 1, 3216245); + aa = a.productAllRowsToDoubleWithReference(ref); + bb = b.productAllRowsToDoubleWithReference(ref); + assertArrayEquals(err, aa, bb, 0.0000001); + } + + @Test + public void appendRow() { + double[] r = TestUtils.generateTestVector(nCol, 1, 10, 0.9, 2222); + IDictionary aa = a.append(r); + IDictionary bb = b.append(r); + + compare(aa, bb, nCol); + + for(int i = 0; i < nCol; i++) { + assertEquals(r[i], aa.getValue(nRow, i, nCol), 0.0); + assertEquals(r[i], bb.getValue(nRow, i, nCol), 0.0); + } + } + + @Test + public void colSumSq() { + double[] aa = new double[nCol + 2]; + double[] bb = new double[nCol + 2]; + int[] counts = getCounts(nRow, 321652); + a.colSumSq(aa, counts, ColIndexFactory.create(nCol)); + b.colSumSq(bb, counts, ColIndexFactory.create(nCol)); + assertArrayEquals(aa, bb, 0.0000001); + } + + @Test + public void multiplyScalar() { + double[] aa = new double[(nCol + 1) * 4]; + double[] bb = new double[(nCol + 1) * 4]; + Random r = new Random(3222); + for(int i = 0; i < 10; i++) { + int di = r.nextInt(nRow); + int ur = r.nextInt(4); + a.multiplyScalar(32, aa, ur, di, ColIndexFactory.create(nCol).shift(1)); + b.multiplyScalar(32, bb, ur, di, ColIndexFactory.create(nCol).shift(1)); + } + assertArrayEquals(aa, bb, 0.0000001); + + } + + @Test + public void subtractTuple() { + double[] r = TestUtils.generateTestVector(nCol, 1, 10, 0.9, 222); + IDictionary aa = a.subtractTuple(r); + IDictionary bb = b.subtractTuple(r); + + compare(aa, bb, nCol); + } + + @Test + public void cbind() { + IDictionary aa = a.cbind(b, nCol); + IDictionary bb = b.cbind(a, nCol); + compare(aa, bb, nCol * 2); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java index 3286a3eed61..9fa404ca77f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java @@ -19,6 +19,7 @@ package org.apache.sysds.test.component.compress.indexes; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; @@ -41,6 +42,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoRangesIndex; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.matrix.data.Pair; import org.junit.Test; import org.mockito.Mockito; @@ -1027,4 +1029,64 @@ public void containsAnyArray2() { IColIndex b = new RangeIndex(3, 11); assertTrue(a.containsAny(b)); } + + @Test + public void reordering1(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(2); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{0,2,3}, ra); + assertArrayEquals(new int[]{1}, rb); + } + + @Test + public void reordering2(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(2,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{0,2,4}, ra); + assertArrayEquals(new int[]{1,3}, rb); + } + + @Test + public void reordering3(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(0, 2,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{1,3,5}, ra); + assertArrayEquals(new int[]{0,2,4}, rb); + } + + @Test + public void reordering4(){ + IColIndex a = ColIndexFactory.createI(1,5); + IColIndex b = ColIndexFactory.createI(0,2,3,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{1,5}, ra); + assertArrayEquals(new int[]{0,2,3,4}, rb); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java index 871636ed477..1f5deccf779 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.CombinedIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; @@ -145,6 +146,7 @@ public static Collection data() { tests.add(createTwoRange(1, 10, 22, 30)); tests.add(createTwoRange(9, 11, 22, 30)); tests.add(createTwoRange(9, 11, 22, 60)); + tests.add(createCombined(9, 11, 22)); } catch(Exception e) { e.printStackTrace(); @@ -349,6 +351,19 @@ public void equalsSizeDiff_twoRanges2() { assertNotEquals(actual, c); } + @Test + public void equalsCombine(){ + RangeIndex a = new RangeIndex(9, 11); + SingleIndex b = new SingleIndex(22); + IColIndex c = a.combine(b); + if(eq(expected, c)){ + LOG.error(c.size()); + compare(expected, c); + compare(c, actual); + } + + } + @Test public void equalsItself() { assertEquals(actual, actual); @@ -395,10 +410,16 @@ public void combineTwoAbove() { @Test public void combineTwoAround() { - IColIndex b = new TwoIndex(expected[0] - 1, expected[expected.length - 1] + 1); - IColIndex c = actual.combine(b); - assertTrue(c.containsStrict(actual, b)); - assertTrue(c.containsStrict(b, actual)); + try { + IColIndex b = new TwoIndex(expected[0] - 1, expected[expected.length - 1] + 1); + IColIndex c = actual.combine(b); + assertTrue(c.containsStrict(actual, b)); + assertTrue(c.containsStrict(b, actual)); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -417,7 +438,10 @@ public void hashCodeEquals() { @Test public void estimateInMemorySizeIsNotToBig() { - assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 16); + if(actual instanceof CombinedIndex) + assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 64); + else + assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 16); } @Test @@ -594,6 +618,17 @@ private void shift(int i) { compare(expected, actual.shift(i), i); } + private static boolean eq(int[] expected, IColIndex actual) { + if(expected.length == actual.size()) { + for(int i = 0; i < expected.length; i++) + if(expected[i] != actual.get(i)) + return false; + return true; + } + else + return false; + } + public static void compare(int[] expected, IColIndex actual) { assertEquals(expected.length, actual.size()); for(int i = 0; i < expected.length; i++) @@ -673,4 +708,19 @@ private static Object[] createTwoRange(int l1, int u1, int l2, int u2) { exp[j] = i; return new Object[] {exp, c}; } + + private static Object[] createCombined(int l1, int u1, int o) { + RangeIndex a = new RangeIndex(l1, u1); + SingleIndex b = new SingleIndex(o); + IColIndex c = a.combine(b); + int[] exp = new int[u1 - l1 + 1]; + + for(int i = l1, j = 0; i < u1; i++, j++) + exp[j] = i; + + exp[exp.length - 1] = o; + + return new Object[] {exp, c}; + + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java index 9ec4aee5267..787d457f802 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java @@ -159,10 +159,11 @@ protected static void writeR(MatrixBlock src, String path, int rep) throws Excep } protected static void writeAndReadR(MatrixBlock mb, int blen, int rep) throws Exception { + String filename = getName(); try { - String filename = getName(); File f = new File(filename); - f.delete(); + if(f.isFile() || f.isDirectory()) + f.delete(); WriterCompressed.writeCompressedMatrixToHDFS(mb, filename, blen); File f2 = new File(filename); assertTrue(f2.isFile() || f2.isDirectory()); @@ -170,15 +171,21 @@ protected static void writeAndReadR(MatrixBlock mb, int blen, int rep) throws Ex IOCompressionTestUtils.verifyEquivalence(mb, mbr); } catch(Exception e) { - + File f = new File(filename); + if(f.isFile() || f.isDirectory()) + f.delete(); if(rep < 3) { Thread.sleep(1000); writeAndReadR(mb, blen, rep + 1); return; } - e.printStackTrace(); throw e; } + finally{ + File f = new File(filename); + if(f.isFile() || f.isDirectory()) + f.delete(); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java index 074e1153535..69305cf5b24 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java @@ -35,34 +35,11 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; -import org.apache.sysds.runtime.functionobjects.And; -import org.apache.sysds.runtime.functionobjects.BitwAnd; -import org.apache.sysds.runtime.functionobjects.BitwOr; -import org.apache.sysds.runtime.functionobjects.BitwShiftL; -import org.apache.sysds.runtime.functionobjects.BitwShiftR; -import org.apache.sysds.runtime.functionobjects.BitwXor; -import org.apache.sysds.runtime.functionobjects.Builtin; -import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; -import org.apache.sysds.runtime.functionobjects.Equals; -import org.apache.sysds.runtime.functionobjects.GreaterThan; -import org.apache.sysds.runtime.functionobjects.GreaterThanEquals; -import org.apache.sysds.runtime.functionobjects.IntegerDivide; -import org.apache.sysds.runtime.functionobjects.LessThan; -import org.apache.sysds.runtime.functionobjects.LessThanEquals; -import org.apache.sysds.runtime.functionobjects.Minus; -import org.apache.sysds.runtime.functionobjects.Minus1Multiply; -import org.apache.sysds.runtime.functionobjects.MinusMultiply; -import org.apache.sysds.runtime.functionobjects.MinusNz; -import org.apache.sysds.runtime.functionobjects.Modulus; import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.NotEquals; -import org.apache.sysds.runtime.functionobjects.Or; import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.functionobjects.PlusMultiply; import org.apache.sysds.runtime.functionobjects.Power; import org.apache.sysds.runtime.functionobjects.ValueFunction; -import org.apache.sysds.runtime.functionobjects.Xor; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -77,42 +54,42 @@ public class CLALibBinaryCellOpTest { protected static final Log LOG = LogFactory.getLog(CombineGroupsTest.class.getName()); public final static ValueFunction[] vf = {// - (Plus.getPlusFnObject()), // - (Minus.getMinusFnObject()), // + // (Plus.getPlusFnObject()), // + // (Minus.getMinusFnObject()), // Divide.getDivideFnObject(), // - (Or.getOrFnObject()), // - (LessThan.getLessThanFnObject()), // - (LessThanEquals.getLessThanEqualsFnObject()), // - (GreaterThan.getGreaterThanFnObject()), // - (GreaterThanEquals.getGreaterThanEqualsFnObject()), // - (Multiply.getMultiplyFnObject()), // - (Modulus.getFnObject()), // - (IntegerDivide.getFnObject()), // - (Equals.getEqualsFnObject()), // - (NotEquals.getNotEqualsFnObject()), // - (And.getAndFnObject()), // - (Xor.getXorFnObject()), // - (BitwAnd.getBitwAndFnObject()), // - (BitwOr.getBitwOrFnObject()), // - (BitwXor.getBitwXorFnObject()), // - (BitwShiftL.getBitwShiftLFnObject()), // - (BitwShiftR.getBitwShiftRFnObject()), // - (Power.getPowerFnObject()), // - (MinusNz.getMinusNzFnObject()), // - (new PlusMultiply(32)), // - (new PlusMultiply(2)), // - (new PlusMultiply(0)), // - (new MinusMultiply(32)), // - Minus1Multiply.getMinus1MultiplyFnObject(), - // // Builtin - (Builtin.getBuiltinFnObject(BuiltinCode.MIN)), // - (Builtin.getBuiltinFnObject(BuiltinCode.MAX)), // - (Builtin.getBuiltinFnObject(BuiltinCode.LOG)), // - (Builtin.getBuiltinFnObject(BuiltinCode.LOG_NZ)), // - (Builtin.getBuiltinFnObject(BuiltinCode.MAXINDEX)), // - (Builtin.getBuiltinFnObject(BuiltinCode.MININDEX)), // - (Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)), // - (Builtin.getBuiltinFnObject(BuiltinCode.CUMMIN)),// + // (Or.getOrFnObject()), // + // (LessThan.getLessThanFnObject()), // + // (LessThanEquals.getLessThanEqualsFnObject()), // + // (GreaterThan.getGreaterThanFnObject()), // + // (GreaterThanEquals.getGreaterThanEqualsFnObject()), // + // (Multiply.getMultiplyFnObject()), // + // (Modulus.getFnObject()), // + // (IntegerDivide.getFnObject()), // + // (Equals.getEqualsFnObject()), // + // (NotEquals.getNotEqualsFnObject()), // + // (And.getAndFnObject()), // + // (Xor.getXorFnObject()), // + // (BitwAnd.getBitwAndFnObject()), // + // (BitwOr.getBitwOrFnObject()), // + // (BitwXor.getBitwXorFnObject()), // + // (BitwShiftL.getBitwShiftLFnObject()), // + // (BitwShiftR.getBitwShiftRFnObject()), // + // (Power.getPowerFnObject()), // + // (MinusNz.getMinusNzFnObject()), // + // (new PlusMultiply(32)), // + // (new PlusMultiply(2)), // + // (new PlusMultiply(0)), // + // (new MinusMultiply(32)), // + // Minus1Multiply.getMinus1MultiplyFnObject(), + // // // Builtin + // (Builtin.getBuiltinFnObject(BuiltinCode.MIN)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.MAX)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.LOG)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.LOG_NZ)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.MAXINDEX)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.MININDEX)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)), // + // (Builtin.getBuiltinFnObject(BuiltinCode.CUMMIN)),// }; private final MatrixBlock mb; @@ -375,7 +352,6 @@ public void binRightMrV() { @Test public void binRightMrV_noCache() { try { - CompressedMatrixBlock spy = spy(cmb); when(spy.getCachedDecompressed()).thenReturn(null); exec(op, mb, spy, mrv2); @@ -499,9 +475,17 @@ public void binLeftMS() { private static void exec(BinaryOperator op, MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2) { if(mb2 != null) { - MatrixBlock cRet = CLALibBinaryCellOp.binaryOperationsRight(op, cmb1, mb2); - MatrixBlock uRet = LibMatrixBincell.bincellOp(mb1, CompressedMatrixBlock.getUncompressed(mb2), null, op); - compare(op, cRet, uRet); + MatrixBlock cRet = null; + MatrixBlock uRet = null; + try{ + cRet = CLALibBinaryCellOp.binaryOperationsRight(op, cmb1, mb2); + uRet = LibMatrixBincell.bincellOp(mb1, CompressedMatrixBlock.getUncompressed(mb2), null, op); + compare(op, cRet, uRet); + } + catch(AssertionError e ){ + fail(e.getMessage()); + throw new RuntimeException(e); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java index 383d31cab95..7a227481c0c 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java @@ -38,6 +38,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -146,7 +147,7 @@ public static Collection data() { cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft(); genTests(tests, mb, cmb, "Sparse2"); - IdentityDictionary id = new IdentityDictionary(10); + IDictionary id = IdentityDictionary.create(10); AMapToData d = MappingTestUtil.createRandomMap(100, 10, new Random(23)); AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(10), id, d, null); cmb = new CompressedMatrixBlock(100, 10); @@ -162,7 +163,7 @@ public static Collection data() { mb = CompressedMatrixBlock.getUncompressed(cmb); genTests(tests, mb, cmb, "Identity2"); - id = new IdentityDictionary(10, true); + id = IdentityDictionary.create(10, true); // continuous index range d = MappingTestUtil.createRandomMap(100, 11, new Random(33)); @@ -259,7 +260,7 @@ private static void genTests(List tests, MatrixBlock mb, MatrixBlock c private static MatrixBlock createSelectionMatrix(final int nRow, final int nRowLeft, boolean emptyRows) { MatrixBlock tcmb; - IdentityDictionary id = new IdentityDictionary(nRow, emptyRows); + IDictionary id = IdentityDictionary.create(nRow, emptyRows); AMapToData d = MappingTestUtil.createRandomMap(nRowLeft, nRow + (emptyRows ? 1 : 0), new Random(33)); AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(nRow), id, d, null); tcmb = new CompressedMatrixBlock(nRowLeft, nRow); diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java index 33a4ecbd8da..794864e6e0a 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java @@ -38,6 +38,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -125,7 +126,7 @@ public static Collection data() { gs = new ArrayList<>(); gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); - IdentityDictionary id = new IdentityDictionary(10); + IDictionary id = IdentityDictionary.create(10); AMapToData d = MappingTestUtil.createRandomMap(100, 10, new Random(23)); AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(10, 100), id, d, null); gs.add(idg); @@ -161,7 +162,7 @@ public static Collection data() { mb = CompressedMatrixBlock.getUncompressed(cmb); genTests(tests, mb, cmb, "Identity2"); - id = new IdentityDictionary(10, true); + id = IdentityDictionary.create(10, true); d = MappingTestUtil.createRandomMap(100, 11, new Random(33)); idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null); @@ -217,7 +218,7 @@ private static void genTests(List tests, MatrixBlock mb, MatrixBlock c public static MatrixBlock createSelectionMatrix(final int nRow, final int nRowLeft, boolean emptyRows) { MatrixBlock tcmb; - IdentityDictionary id = new IdentityDictionary(nRow, emptyRows); + IDictionary id = IdentityDictionary.create(nRow, emptyRows); AMapToData d = MappingTestUtil.createRandomMap(nRowLeft, nRow + (emptyRows ? 1 : 0), new Random(33)); AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(nRow), id, d, null); tcmb = new CompressedMatrixBlock(nRowLeft, nRow); diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java index 047b2da3b25..3af635a3189 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java @@ -19,6 +19,7 @@ package org.apache.sysds.test.component.frame; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.anyInt; @@ -30,6 +31,8 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.StringArray; import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend; import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -97,4 +100,23 @@ public void appendUniqueColNames(){ assertTrue(c.getColumnName(0).equals("Hi")); assertTrue(c.getColumnName(1).equals("There")); } + + + @Test + public void detectSchema(){ + FrameBlock f = new FrameBlock(new Array[]{new StringArray(new String[]{"00000001", "e013af63"})}); + assertEquals("HASH32", FrameLibDetectSchema.detectSchema(f, 1).get(0,0)); + } + + @Test + public void detectSchema2(){ + FrameBlock f = new FrameBlock(new Array[]{new StringArray(new String[]{"10000001", "e013af63"})}); + assertEquals("HASH32", FrameLibDetectSchema.detectSchema(f, 1).get(0,0)); + } + + @Test + public void detectSchema3(){ + FrameBlock f = new FrameBlock(new Array[]{new StringArray(new String[]{"e013af63","10000001"})}); + assertEquals("HASH32", FrameLibDetectSchema.detectSchema(f, 1).get(0,0)); + } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java index 2b51a77b705..832aac51f09 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java @@ -19,11 +19,15 @@ package org.apache.sysds.test.component.frame.transform; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Map; +import java.util.Map.Entry; import java.util.logging.Level; import java.util.logging.Logger; @@ -119,8 +123,8 @@ public void testDummyCode() { test("{dummycode:[C1,C2,C3]}"); } - @Test - public void testDummyCodeV2(){ + @Test + public void testDummyCodeV2() { test("{ids:true, dummycode:[1,2,3]}"); } @@ -169,16 +173,19 @@ public void test(String spec) { TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); - + meta = encoderNormal.getMetaData(meta); MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderNormal.getMetaData(null)); - + + FrameBlock metaBack = ec2.getMetaData(null); + compareMeta(metaBack, meta); MatrixBlock outMeta12 = ec2.apply(data, k); + TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderCompressed.getMetaData(null)); - + MatrixBlock outMeta1 = ec.apply(data, k); TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); @@ -188,4 +195,23 @@ public void test(String spec) { fail(e.getMessage()); } } + + private void compareMeta(FrameBlock e, FrameBlock a){ + try{ + assertEquals(e.getNumRows(), a.getNumRows()); + if(e.getNumRows()>0){ + for(int i = 0; i < e.getNumColumns(); i++){ + Map em = e.getColumn(i).getRecodeMap(); + Map am = a.getColumn(i).getRecodeMap(); + for(Entry eme : em.entrySet()){ + assertTrue(am.containsKey(eme.getKey())); + assertEquals(eme.getValue(), am.get(eme.getKey())); + } + } + } + } + catch(Exception ex){ + throw new RuntimeException(e.toString(), ex); + } + } } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java b/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java new file mode 100644 index 00000000000..40e7143fbb3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.matrix; + +import static org.junit.Assert.fail; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class MatrixBlockSerializationTest { + + private MatrixBlock mb; + + @Parameters + public static Collection data() { + List tests = new ArrayList<>(); + + try { + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 1.0, 3)}); + + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 0.1, 3)}); + + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 0.001, 3)}); + tests.add(new Object[] {new MatrixBlock()}); + + } + catch(Exception e) { + e.printStackTrace(); + fail("failed constructing tests"); + } + + return tests; + } + + public MatrixBlockSerializationTest(MatrixBlock mb) { + this.mb = mb; + } + + @Test + public void testSerialization() { + try { + // serialize compressed matrix block + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream fos = new DataOutputStream(bos); + mb.write(fos); + + // deserialize compressed matrix block + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + DataInputStream fis = new DataInputStream(bis); + MatrixBlock in = new MatrixBlock(); + in.readFields(fis); + TestUtils.compareMatrices(mb, in, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java index 1fe40002c29..14b6b5f787e 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java @@ -23,6 +23,8 @@ import java.util.Arrays; import java.util.Random; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -38,9 +40,9 @@ import org.junit.Assert; import org.junit.Test; - public class CompressByBinTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(CompressByBinTest.class.getName()); private final static String TEST_NAME = "compressByBins"; private final static String TEST_DIR = "functions/compress/matrixByBin/"; @@ -52,41 +54,48 @@ public class CompressByBinTest extends AutomatedTestBase { private final static int nbins = 10; - //private final static int[] dVector = new int[cols]; + // private final static int[] dVector = new int[cols]; @Override public void setUp() { - addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"X"})); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"X"})); } @Test - public void testCompressBinsMatrixWidthCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); } + public void testCompressBinsMatrixWidthCP() { + runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); + } @Test - public void testCompressBinsMatrixHeightCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); } + public void testCompressBinsMatrixHeightCP() { + runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); + } @Test - public void testCompressBinsFrameWidthCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); } + public void testCompressBinsFrameWidthCP() { + runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); + } @Test - public void testCompressBinsFrameHeightCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); } + public void testCompressBinsFrameHeightCP() { + runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); + } - private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) - { + private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) { Types.ExecMode platformOld = setExecMode(instType); - try - { + try { loadTestConfiguration(getTestConfiguration(TEST_NAME)); String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH),output("meta"), output("res")}; + programArgs = new String[] {"-stats","-args", input("X"), + Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH), output("meta"), output("res")}; double[][] X = generateMatrixData(binMethod); writeInputMatrixWithMTD("X", X, true); - runTest(true, false, null, -1); + runTest(null); checkMetaFile(DataConverter.convertToMatrixBlock(X), binMethod); @@ -99,24 +108,23 @@ private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod bin } } - private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) - { + private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) { Types.ExecMode platformOld = setExecMode(instType); - try - { + try { loadTestConfiguration(getTestConfiguration(TEST_NAME)); String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-explain", "-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) , output("meta"), output("res")}; + programArgs = new String[] {"-explain", "-args", input("X"), + Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH), output("meta"), output("res")}; Types.ValueType[] schema = new Types.ValueType[cols]; Arrays.fill(schema, Types.ValueType.FP32); FrameBlock Xf = generateFrameData(binMethod, schema); writeInputFrameWithMTD("X", Xf, false, schema, Types.FileFormat.CSV); - runTest(true, false, null, -1); + runTest(null); checkMetaFile(Xf, binMethod); @@ -132,14 +140,15 @@ private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMetho private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { double[][] X; if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) { - //generate actual dataset + // generate actual dataset X = getRandomMatrix(rows, cols, -100, 100, 1, 7); // make sure that bins in [-100, 100] for(int i = 0; i < cols; i++) { X[0][i] = -100; X[1][i] = 100; } - } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { + } + else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { X = new double[rows][cols]; for(int c = 0; c < cols; c++) { double[] vals = new Random().doubles(nbins).toArray(); @@ -150,7 +159,8 @@ private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { j++; } } - } else + } + else throw new RuntimeException("Invalid binning method."); return X; @@ -164,9 +174,10 @@ private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types for(int i = 0; i < cols; i++) { Xf.set(0, i, -100); - Xf.set(rows-1, i, 100); + Xf.set(rows - 1, i, 100); } - } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { + } + else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { Xf = new FrameBlock(); for(int c = 0; c < schema.length; c++) { double[] vals = new Random().doubles(nbins).toArray(); @@ -180,14 +191,16 @@ private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types Xf.appendColumn(f); } - } else + } + else throw new RuntimeException("Invalid binning method."); return Xf; } - private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningType) throws IOException{ + private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningType) throws IOException { FrameBlock outputMeta = readDMLFrameFromHDFS("meta", Types.FileFormat.CSV); + Assert.assertEquals(nbins, outputMeta.getNumRows()); double[] binStarts = new double[nbins]; @@ -201,9 +214,10 @@ private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningTy Assert.assertEquals(i, binStart, 0.0); j++; } - } else { + } + else { binStarts[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(0)).split("·")[0]); - binEnds[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(nbins-1)).split("·")[1]); + binEnds[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(nbins - 1)).split("·")[1]); } } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/reshape/CompressedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/compress/reshape/CompressedReshapeTest.java new file mode 100644 index 00000000000..4ba86392956 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/compress/reshape/CompressedReshapeTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.compress.reshape; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class CompressedReshapeTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(CompressedReshapeTest.class.getName()); + + private final static String TEST_DIR = "functions/compress/reshape/"; + + protected String getTestClassDir() { + return getTestDir() + this.getClass().getSimpleName() + "/"; + } + + protected String getTestName() { + return "reshape1"; + } + + protected String getTestDir() { + return TEST_DIR; + } + + @Test + public void testReshape_01_1to2_sparse() { + reshapeTest(1, 1000, 2, 500, 0.2, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_01_2to4_sparse() { + reshapeTest(2, 500, 4, 250, 0.2, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_01_1to10_sparse() { + reshapeTest(1, 10000, 10, 1000, 0.2, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_01_1to2_dense() { + reshapeTest(1, 1000, 2, 500, 1.0, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_01_2to4_dense() { + reshapeTest(2, 500, 4, 250, 1.0, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_01_1to10_dense() { + reshapeTest(1, 10000, 10, 1000, 1.0, ExecType.CP, 0, 5, "01"); + } + + @Test + public void testReshape_02_1to2_sparse() { + reshapeTest(1, 1000, 2, 500, 0.2, ExecType.CP, 0, 10, "02"); + } + + @Test + public void testReshape_02_1to2_dense() { + reshapeTest(1, 1000, 2, 500, 1.0, ExecType.CP, 0, 10, "02"); + } + + @Test + public void testReshape_03_1to2_sparse() { + reshapeTest(1, 1000, 2, 500, 0.2, ExecType.CP, 0, 10, "03"); + } + + @Test + public void testReshape_03_1to2_dense() { + reshapeTest(1, 1000, 2, 500, 1.0, ExecType.CP, 0, 10, "03"); + } + + public void reshapeTest(int cols, int rows, int reCol, int reRows, double sparsity, ExecType instType, int min, + int max, String name) { + + OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true; + Types.ExecMode platformOld = setExecMode(instType); + + CompressedMatrixBlock.debug = true; + CompressedMatrixBlock.allowCachingUncompressed = false; + try { + + super.setOutputBuffering(true); + loadTestConfiguration(getTestConfiguration(getTestName())); + + fullDMLScriptName = SCRIPT_DIR + "/" + getTestClassDir() + name + ".dml"; + + programArgs = new String[] {"-stats", "100", "-nvargs", "cols=" + cols, "rows=" + rows, "reCols=" + reCol, + "reRows=" + reRows, "sparsity=" + sparsity, "min=" + min, "max= " + max}; + String s = runTest(null).toString(); + + if(s.contains("Failed")) + fail(s); + else + LOG.debug(s); + + } + catch(Exception e) { + e.printStackTrace(); + assertTrue("Exception in execution: " + e.getMessage(), false); + } + finally { + rtplatform = platformOld; + } + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(getTestName(), new TestConfiguration(getTestClassDir(), getTestName())); + } + +} diff --git a/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/01.dml b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/01.dml new file mode 100644 index 00000000000..33b1baff130 --- /dev/null +++ b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/01.dml @@ -0,0 +1,54 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +cols=$cols +rows=$rows +reCols=$reCols +reRows=$reRows +sparsity=$sparsity +min=$min +max=$max + +X = rand(cols=cols, rows=rows, min=min, max=max, sparsity=$sparsity) +X = ceil(X) + +X_C = compress(X) + +while(FALSE){} # force a break + +X_r = matrix(X, rows = reRows, cols=reCols) +X_Cr = matrix(X_C, rows = reRows, cols=reCols) + +while(FALSE){} # force a second break + +same = X == X_C +same2 = X_r == X_Cr + +print(sum(same)) +print(sum(same2)) + +nCells = cols * rows + +if(nCells == sum(same) & sum(same) == sum(same2)) + print("Success, the output contained the same values after reshaping") +else + print("Failed, the output did not contain the same values after reshaping") diff --git a/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/02.dml b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/02.dml new file mode 100644 index 00000000000..f213a9b9e29 --- /dev/null +++ b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/02.dml @@ -0,0 +1,57 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +cols=$cols +rows=$rows +reCols=$reCols +reRows=$reRows +sparsity=$sparsity +min=$min +max=$max + +X = rand(cols=cols, rows=rows, min=min, max=max, sparsity=$sparsity) +X = sqrt(X) + +X = ceil(X) + +X_C = compress(X) + + +while(FALSE){} # force a break + +X_r = matrix(X, rows = reRows, cols=reCols) +X_Cr = matrix(X_C, rows = reRows, cols=reCols) + +while(FALSE){} # force a second break + +same = X == X_C +same2 = X_r == X_Cr + +print(sum(same)) +print(sum(same2)) + +nCells = cols * rows + +if(nCells == sum(same) & sum(same) == sum(same2)) + print("Success, the output contained the same values after reshaping") +else + print("Failed, the output did not contain the same values after reshaping") diff --git a/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/03.dml b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/03.dml new file mode 100644 index 00000000000..154f2069582 --- /dev/null +++ b/src/test/scripts/functions/compress/reshape/CompressedReshapeTest/03.dml @@ -0,0 +1,60 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +cols=$cols +rows=$rows +reCols=$reCols +reRows=$reRows +sparsity=$sparsity +min=$min +max=$max + +X = rand(cols=cols, rows=rows, min=min, max=max, sparsity=$sparsity) +X = sqrt(X) + +X = ceil(X) + + +X_C = compress(X) + +X = X + 1 +X_C = X_C + 1 + +while(FALSE){} # force a break + +X_r = matrix(X, rows = reRows, cols=reCols) +X_Cr = matrix(X_C, rows = reRows, cols=reCols) + +while(FALSE){} # force a second break + +same = X == X_C +same2 = X_r == X_Cr + +print(sum(same)) +print(sum(same2)) + +nCells = cols * rows + +if(nCells == sum(same) & sum(same) == sum(same2)) + print("Success, the output contained the same values after reshaping") +else + print("Failed, the output did not contain the same values after reshaping")