diff --git a/bin/systemds b/bin/systemds index efa47c1abcb..788c487e659 100755 --- a/bin/systemds +++ b/bin/systemds @@ -402,28 +402,22 @@ NATIVE_LIBS="$SYSTEMDS_ROOT${DIR_SEP}target${DIR_SEP}classes${DIR_SEP}lib" export PATH=${HADOOP_REL}${DIR_SEP}bin${PATH_SEP}${PATH}${PATH_SEP}$NATIVE_LIBS export LD_LIBRARY_PATH=${HADOOP_REL}${DIR_SEP}bin${PATH_SEP}${LD_LIBRARY_PATH} -# set java class path -CLASSPATH="${SYSTEMDS_JAR_FILE}${PATH_SEP} \ - ${SYSTEMDS_ROOT}${DIR_SEP}lib${DIR_SEP}*${PATH_SEP} \ - ${SYSTEMDS_ROOT}${DIR_SEP}target${DIR_SEP}lib${DIR_SEP}*" -# trim whitespace (introduced by the line breaks above) -CLASSPATH=$(echo "${CLASSPATH}" | tr -d '[:space:]') - if [ $PRINT_SYSDS_HELP == 1 ]; then echo "----------------------------------------------------------------------" echo "Further help on SystemDS arguments:" - java -cp "$CLASSPATH" org.apache.sysds.api.DMLScript -help + java -jar $SYSTEMDS_JAR_FILE -help exit 1 fi -print_out "###############################################################################" -print_out "# SYSTEMDS_ROOT= $SYSTEMDS_ROOT" -print_out "# SYSTEMDS_JAR_FILE= $SYSTEMDS_JAR_FILE" -print_out "# SYSDS_EXEC_MODE= $SYSDS_EXEC_MODE" -print_out "# CONFIG_FILE= $CONFIG_FILE" -print_out "# LOG4JPROP= $LOG4JPROP" -print_out "# CLASSPATH= $CLASSPATH" -print_out "# HADOOP_HOME= $HADOOP_HOME" +if [ $SYSDS_QUIET == 0 ]; then + print_out "###############################################################################" + print_out "# SYSTEMDS_ROOT= $SYSTEMDS_ROOT" + print_out "# SYSTEMDS_JAR_FILE= $SYSTEMDS_JAR_FILE" + print_out "# SYSDS_EXEC_MODE= $SYSDS_EXEC_MODE" + print_out "# CONFIG_FILE= $CONFIG_FILE" + print_out "# LOG4JPROP= $LOG4JPROP" + print_out "# HADOOP_HOME= $HADOOP_HOME" +fi #build the command to run if [ $WORKER == 1 ]; then @@ -432,7 +426,7 @@ if [ $WORKER == 1 ]; then print_out "###############################################################################" CMD=" \ java $SYSTEMDS_STANDALONE_OPTS \ - -cp $CLASSPATH \ + -jar $SYSTEMDS_JAR_FILE \ $LOG4JPROPFULL \ org.apache.sysds.api.DMLScript \ -w $PORT \ @@ -447,9 +441,8 @@ elif [ "$FEDMONITORING" == 1 ]; then print_out "###############################################################################" CMD=" \ java $SYSTEMDS_STANDALONE_OPTS \ - -cp $CLASSPATH \ $LOG4JPROPFULL \ - org.apache.sysds.api.DMLScript \ + -jar $SYSTEMDS_JAR_FILE \ -fedMonitoring $PORT \ $CONFIG_FILE \ $*" @@ -462,9 +455,8 @@ elif [ $SYSDS_DISTRIBUTED == 0 ]; then print_out "###############################################################################" CMD=" \ java $SYSTEMDS_STANDALONE_OPTS \ - -cp $CLASSPATH \ $LOG4JPROPFULL \ - org.apache.sysds.api.DMLScript \ + -jar $SYSTEMDS_JAR_FILE \ -f $SCRIPT_FILE \ -exec $SYSDS_EXEC_MODE \ $CONFIG_FILE \ diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 14ac2c5efc8..9f313d8cd17 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -446,8 +446,7 @@ private boolean isApplicableForTransitiveSparkExecType(boolean left) || (left && !isLeftTransposeRewriteApplicable(true))) && getInput(index).getParent().size()==1 //bagg is only parent && !getInput(index).areDimsBelowThreshold() - && (getInput(index).optFindExecType() == ExecType.SPARK - || (getInput(index) instanceof DataOp && ((DataOp)getInput(index)).hasOnlyRDD())) + && getInput(index).hasSparkOutput() && getInput(index).getOutputMemEstimate()>getOutputMemEstimate(); } diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 74740f10ce4..097e3d1fa91 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -751,8 +751,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { _etype = _etypeForced; @@ -801,18 +801,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -842,7 +852,7 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -1157,17 +1167,35 @@ && getInput().get(0) == that2.getInput().get(0) } public boolean supportsMatrixScalarOperations() { - return ( op==OpOp2.PLUS ||op==OpOp2.MINUS - ||op==OpOp2.MULT ||op==OpOp2.DIV - ||op==OpOp2.MODULUS ||op==OpOp2.INTDIV - ||op==OpOp2.LESS ||op==OpOp2.LESSEQUAL - ||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL - ||op==OpOp2.EQUAL ||op==OpOp2.NOTEQUAL - ||op==OpOp2.MIN ||op==OpOp2.MAX - ||op==OpOp2.LOG ||op==OpOp2.POW - ||op==OpOp2.AND ||op==OpOp2.OR ||op==OpOp2.XOR - ||op==OpOp2.BITWAND ||op==OpOp2.BITWOR ||op==OpOp2.BITWXOR - ||op==OpOp2.BITWSHIFTL ||op==OpOp2.BITWSHIFTR); + switch(op) { + case PLUS: + case MINUS: + case MULT: + case DIV: + case MODULUS: + case INTDIV: + case LESS: + case LESSEQUAL: + case GREATER: + case GREATEREQUAL: + case EQUAL: + case NOTEQUAL: + case MIN: + case MAX: + case LOG: + case POW: + case AND: + case OR: + case XOR: + case BITWAND: + case BITWOR: + case BITWXOR: + case BITWSHIFTL: + case BITWSHIFTR: + return true; + default: + return false; + } } public boolean isPPredOperation() { diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 42c51e452b7..bb4e5a9dedb 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -387,8 +387,8 @@ public boolean allowsAllExecTypes() protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) { double ret = 0; - - if ( getDataType() == DataType.SCALAR ) + final DataType dt = getDataType(); + if ( dt == DataType.SCALAR ) { switch( getValueType() ) { @@ -407,6 +407,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) ret = 0; } } + else if(dt == DataType.FRAME) { + if(_op == OpOpData.PERSISTENTREAD || _op == OpOpData.TRANSIENTREAD) { + ret = OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + } + } else //MATRIX / FRAME { if( _op == OpOpData.PERSISTENTREAD diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 265ba672e96..276c5dc6479 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1099,6 +1099,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1702,6 +1708,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index 8953cba3782..bb640711827 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -64,6 +64,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.utils.MemoryEstimates; public class OptimizerUtils { @@ -788,6 +789,15 @@ public static long estimateSizeExactSparsity(long nrows, long ncols, long nnz) double sp = getSparsity(nrows, ncols, nnz); return estimateSizeExactSparsity(nrows, ncols, sp); } + + + public static long estimateSizeExactFrame(long nRows, long nCols){ + if(nRows > Integer.MAX_VALUE) + return Long.MAX_VALUE; + + // assuming String arrays and on average 8 characters per value. + return (long)MemoryEstimates.stringArrayCost((int)nRows, 8) * nCols; + } /** * Estimates the footprint (in bytes) for an in-memory representation of a diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index f046ffe85c2..9833586275d 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -371,7 +371,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -468,6 +472,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -498,19 +509,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -524,7 +538,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index b2460e7697c..19a0785bda0 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -1206,6 +1206,11 @@ public static boolean isSumSq(Hop hop) { public static boolean isParameterBuiltinOp(Hop hop, ParamBuiltinOp type) { return hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp().equals(type); } + + public static boolean isParameterBuiltinOp(Hop hop, ParamBuiltinOp... types) { + return hop instanceof ParameterizedBuiltinOp && + ArrayUtils.contains(types, ((ParameterizedBuiltinOp) hop).getOp()); + } public static boolean isRemoveEmpty(Hop hop, boolean rows) { return isParameterBuiltinOp(hop, ParamBuiltinOp.RMEMPTY) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index e181c60a78f..cbf4de94a72 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -1986,35 +1986,35 @@ private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) { } } - //Pattern 7) (W*(U%*%t(V))) - if( !appliedPattern - && HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT - && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv - && hi.getDim2() > 1 //not applied for vector-vector mult - && hi.getInput().get(0).getDataType() == DataType.MATRIX - && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getBlocksize() - && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) - && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain - && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT - { - Hop W = hi.getInput().get(0); - Hop U = hi.getInput().get(1).getInput().get(0); - Hop V = hi.getInput().get(1).getInput().get(1); - - //for this basic pattern, we're more conservative and only apply wdivmm if - //W is sparse and U/V unknown or dense; or if U/V are dense - if( (HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) - || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V)) ) { - V = !HopRewriteUtils.isTransposeOperation(V) ? - HopRewriteUtils.createTranspose(V) : V.getInput().get(0); - hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, - OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false); - hnew.setBlocksize(W.getBlocksize()); - hnew.refreshSizeInformation(); - appliedPattern = true; - LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")"); - } - } + // //Pattern 7) (W*(U%*%t(V))) + // if( !appliedPattern + // && HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT + // && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv + // && hi.getDim2() > 1 //not applied for vector-vector mult + // && hi.getInput().get(0).getDataType() == DataType.MATRIX + // && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getBlocksize() + // && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) + // && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain + // && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT + // { + // Hop W = hi.getInput().get(0); + // Hop U = hi.getInput().get(1).getInput().get(0); + // Hop V = hi.getInput().get(1).getInput().get(1); + + // //for this basic pattern, we're more conservative and only apply wdivmm if + // //W is sparse and U/V unknown or dense; or if U/V are dense + // if( (HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) + // || ( HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V)) ) { + // V = !HopRewriteUtils.isTransposeOperation(V) ? + // HopRewriteUtils.createTranspose(V) : V.getInput().get(0); + // hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, + // OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false); + // hnew.setBlocksize(W.getBlocksize()); + // hnew.refreshSizeInformation(); + // appliedPattern = true; + // LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")"); + // } + // } //relink new hop into original position if( hnew != null ) { diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index fa1a163036b..57fe615f020 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1642,8 +1642,8 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV case DECOMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){ checkNumParameters(1); - checkMatrixParam(getFirstExpr()); - output.setDataType(DataType.MATRIX); + checkMatrixFrameParam(getFirstExpr()); + output.setDataType(getFirstExpr().getOutput().getDataType()); output.setDimensions(id.getDim1(), id.getDim2()); output.setBlocksize (id.getBlocksize()); output.setValueType(id.getValueType()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 92200d4384b..ccba4d107cf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -44,14 +44,15 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupIO; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; -import org.apache.sysds.runtime.compress.lib.CLALibAppend; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; +import org.apache.sysds.runtime.compress.lib.CLALibCBind; import org.apache.sysds.runtime.compress.lib.CLALibCMOps; import org.apache.sysds.runtime.compress.lib.CLALibCompAgg; import org.apache.sysds.runtime.compress.lib.CLALibDecompress; import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; +import org.apache.sysds.runtime.compress.lib.CLALibReorg; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; @@ -65,7 +66,6 @@ import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseRow; -import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.instructions.cp.ScalarObject; @@ -166,7 +166,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) { clen = uncompressedMatrixBlock.getNumColumns(); sparse = false; nonZeros = uncompressedMatrixBlock.getNonZeros(); - decompressedVersion = new SoftReference<>(uncompressedMatrixBlock); + if(!(uncompressedMatrixBlock instanceof CompressedMatrixBlock)) { + decompressedVersion = new SoftReference<>(uncompressedMatrixBlock); + } } /** @@ -267,6 +269,9 @@ public synchronized MatrixBlock decompress(int k) { ret = CLALibDecompress.decompress(this, k); + ret.recomputeNonZeros(k); + ret.examSparsity(k); + // Set soft reference to the decompressed version decompressedVersion = new SoftReference<>(ret); @@ -319,6 +324,11 @@ public long recomputeNonZeros() { return nonZeros; } + @Override + public long recomputeNonZeros(int k) { + return recomputeNonZeros(); + } + @Override public long recomputeNonZeros(int rl, int ru) { throw new NotImplementedException(); @@ -491,8 +501,8 @@ public MatrixBlock binaryOperationsLeft(BinaryOperator op, MatrixValue thatValue @Override public MatrixBlock append(MatrixBlock[] that, MatrixBlock ret, boolean cbind) { - if(cbind && that.length == 1) - return CLALibAppend.append(this, that[0], InfrastructureAnalyzer.getLocalParallelism()); + if(cbind) + return CLALibCBind.cbind(this, that, InfrastructureAnalyzer.getLocalParallelism()); else { MatrixBlock left = getUncompressed("append list or r-bind not supported in compressed"); MatrixBlock[] thatUC = new MatrixBlock[that.length]; @@ -511,8 +521,7 @@ public void append(MatrixValue v2, ArrayList outlist, int bl } @Override - public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, - int k) { + public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) { checkMMChain(ctype, v, w); // multi-threaded MMChain of single uncompressed ColGroup @@ -589,21 +598,7 @@ else if(isOverlapping()) { @Override public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) { - if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) { - MatrixBlock tmp = decompress(op.getNumThreads()); - long nz = tmp.setNonZeros(tmp.getNonZeros()); - tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); - tmp.setNonZeros(nz); - return tmp; - } - else { - // Allow transpose to be compressed output. In general we need to have a transposed flag on - // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 - String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName(); - MatrixBlock tmp = getUncompressed(message, op.getNumThreads()); - return tmp.reorgOperations(op, ret, startRow, startColumn, length); - } - + return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length); } public boolean isOverlapping() { @@ -1101,8 +1096,7 @@ public void appendRow(int r, SparseRow row, boolean deep) { } @Override - public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, - boolean deep) { + public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, boolean deep) { throw new DMLCompressionException("Can't append row to compressed Matrix"); } @@ -1157,12 +1151,12 @@ public void examSparsity(boolean allowCSR, int k) { } @Override - public void sparseToDense(int k) { - // do nothing + public MatrixBlock sparseToDense(int k) { + return this; // do nothing } @Override - public void denseToSparse(boolean allowCSR, int k){ + public void denseToSparse(boolean allowCSR, int k) { // do nothing } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index cc5f5465fb4..b737ed4a3af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -283,7 +283,7 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv return createEmpty(); res = new CompressedMatrixBlock(mb); // copy metadata and allocate soft reference - + logInit(); classifyPhase(); if(compressionGroups == null) return abortCompression(); @@ -396,7 +396,7 @@ private void transposeHeuristics() { compSettings.transposed = false; break; default: - compSettings.transposed = transposeHeuristics(compressionGroups.getNumberColGroups() , mb); + compSettings.transposed = transposeHeuristics(compressionGroups.getNumberColGroups(), mb); } } @@ -465,6 +465,16 @@ private Pair abortCompression() { return new ImmutablePair<>(mb, _stats); } + private void logInit() { + if(LOG.isDebugEnabled()) { + LOG.debug("--Seed used for comp : " + compSettings.seed); + LOG.debug(String.format("--number columns to compress: %10d", mb.getNumColumns())); + LOG.debug(String.format("--number rows to compress : %10d", mb.getNumRows())); + LOG.debug(String.format("--sparsity : %10.5f", mb.getSparsity())); + LOG.debug(String.format("--nonZeros : %10d", mb.getNonZeros())); + } + } + private void logPhase() { setNextTimePhase(time.stop()); DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase); @@ -476,7 +486,6 @@ private void logPhase() { else { switch(phase) { case 0: - LOG.debug("--Seed used for comp : " + compSettings.seed); LOG.debug("--compression phase " + phase + " Classify : " + getLastTimePhase()); LOG.debug("--Individual Columns Estimated Compression: " + _stats.estimatedSizeCols); if(mb instanceof CompressedMatrixBlock) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java index ec5512266e8..dc0908dc9bf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java @@ -35,10 +35,7 @@ */ public class CompressionSettingsBuilder { private double samplingRatio; - // private double samplePower = 0.6; private double samplePower = 0.65; - // private double samplePower = 0.68; - // private double samplePower = 0.7; private boolean allowSharedDictionary = false; private String transposeInput; private int seed = -1; diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index f8fe0287542..449d42e3789 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -92,16 +92,22 @@ private List coCodeBruteForce(List changeInCost) + if(-Math.min(costC1, costC2) > changeInCost // change in cost cannot possibly be better. + || (maxCombined < 0) // int overflow + || (maxCombined > c1i.getNumRows() * 2)) // higher combined number of rows. continue; // Combine the two column groups. @@ -202,10 +208,20 @@ protected CombineTask(ColIndexes c1, ColIndexes c2) { } @Override - public Object call() { - final IColIndex c = _c1._indexes.combine(_c2._indexes); - final ColIndexes cI = new ColIndexes(c); - mem.getOrCreate(cI, _c1, _c2); + public Object call() throws Exception { + final CompressedSizeInfoColGroup c1i = mem.get(_c1); + final CompressedSizeInfoColGroup c2i = mem.get(_c2); + if(c1i != null && c2i != null) { + final int maxCombined = c1i.getNumVals() * c2i.getNumVals(); + + if(maxCombined < 0 // int overflow + || maxCombined > c1i.getNumRows() * 2) // higher combined than number of rows. + return null; + + final IColIndex c = _c1._indexes.combine(_c2._indexes); + final ColIndexes cI = new ColIndexes(c); + mem.getOrCreate(cI, _c1, _c2); + } return null; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java index 6dc53739d24..27a64a0c5b2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java @@ -40,13 +40,13 @@ protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) { if(startSize == 1) return colInfos; // nothing to join when there only is one column else if(startSize <= 16) {// Greedy all compare all if small number of columns - LOG.debug("Hybrid chose to do greedy cocode because of few columns"); + LOG.debug("Hybrid chose to do greedy CoCode because of few columns"); CoCodeGreedy gd = new CoCodeGreedy(_sest, _cest, _cs); return colInfos.setInfo(gd.combine(colInfos.getInfo(), k)); } else if(startSize > 1000) return colInfos.setInfo(CoCodePriorityQue.join(colInfos.getInfo(), _sest, _cest, 1, k)); - LOG.debug("Using Hybrid Cocode Strategy: "); + LOG.debug("Using Hybrid CoCode Strategy: "); final int PriorityQueGoal = startSize / 5; if(PriorityQueGoal > 30) { // hybrid if there is a large number of columns to begin with @@ -62,8 +62,11 @@ else if(startSize > 1000) } return colInfos; } - else // If somewhere in between use the que based approach only. - return colInfos.setInfo(CoCodePriorityQue.join(colInfos.getInfo(), _sest, _cest, 1, k)); - + else { + LOG.debug("Using only Greedy based since Nr Column groups: " + startSize + " is not large enough"); + CoCodeGreedy gd = new CoCodeGreedy(_sest, _cest, _cs); + colInfos.setInfo(gd.combine(colInfos.getInfo(), k)); + return colInfos; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java index abd12d3f6a8..12fee3c50b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java @@ -91,7 +91,7 @@ else if(g.isConst()) // overwrite groups. colInfos.compressionInfo = groups; - + // cocode remaining groups if(!groups.isEmpty()) { colInfos = co.coCodeColumns(colInfos, k); @@ -135,7 +135,7 @@ private static AColumnCoCoder createColumnGroupPartitioner(PartitionerType type, case PRIORITY_QUE: return new CoCodePriorityQue(est, costEstimator, cs); default: - throw new RuntimeException("Unsupported column group partitioner: " + type.toString()); + throw new RuntimeException("Unsupported column group partition technique: " + type.toString()); } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/Memorizer.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/Memorizer.java index db77a32bf68..b81694a54db 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/Memorizer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/Memorizer.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Map.Entry; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.estim.AComEst; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; @@ -60,7 +61,7 @@ public void remove(ColIndexes c1, ColIndexes c2) { } } - public CompressedSizeInfoColGroup getOrCreate(ColIndexes cI, ColIndexes c1, ColIndexes c2){ + public CompressedSizeInfoColGroup getOrCreate(ColIndexes cI, ColIndexes c1, ColIndexes c2) { CompressedSizeInfoColGroup g = mem.get(cI); st2++; if(g == null) { @@ -69,7 +70,11 @@ public CompressedSizeInfoColGroup getOrCreate(ColIndexes cI, ColIndexes c1, ColI if(left != null && right != null) { st3++; g = _sEst.combine(cI._indexes, left, right); - + if(g != null) { + if(g.getNumVals() < 0) + throw new DMLCompressionException( + "Combination returned less distinct values on: \n" + left + "\nand\n" + right + "\nEq\n" + g); + } synchronized(this) { mem.put(cI, g); } @@ -88,7 +93,7 @@ public void incst4() { } public String stats() { - return " possible: " + st1 + " requests: " + st2 + " combined: " + st3 + " outSecond: "+ st4; + return " possible: " + st1 + " requests: " + st2 + " combined: " + st3 + " outSecond: " + st4; } public void resetStats() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index a4030d95612..26233d3fff2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -163,7 +163,7 @@ public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru) { /** * Decompress a range of rows into a dense block * - * @param db Sparse Target block + * @param db Dense target block * @param rl Row to start at * @param ru Row to end at */ @@ -171,6 +171,15 @@ public final void decompressToDenseBlock(DenseBlock db, int rl, int ru) { decompressToDenseBlock(db, rl, ru, 0, 0); } + /** + * Decopress a range of rows into a dense transposed block. + * + * @param db Dense target block + * @param rl Row in this column group to start at. + * @param ru Row in this column group to end at. + */ + public abstract void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru); + /** * Serializes column group to data output. * @@ -606,8 +615,8 @@ public AColGroup addVector(double[] v) { public abstract boolean isEmpty(); /** - * Append the other column group to this column group. This method tries to combine them to return a new column - * group containing both. In some cases it is possible in reasonable time, in others it is not. + * Append the other column group to this column group. This method tries to combine them to return a new column group + * containing both. In some cases it is possible in reasonable time, in others it is not. * * The result is first this column group followed by the other column group in higher row values. * @@ -670,11 +679,18 @@ public void clear() { /** * Recompress this column group into a new column group of the given type. * - * @param ct The compressionType that the column group should morph into + * @param ct The compressionType that the column group should morph into + * @param nRow The number of rows in this columngroup. * @return A new column group */ - public AColGroup morph(CompressionType ct) { - throw new NotImplementedException(); + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + else if (ct == CompressionType.DDCFOR) + return this; // it does not make sense to change to FOR. + else{ + throw new NotImplementedException("Morphing from : " + getCompType() + " to " + ct + " is not implemented"); + } } /** @@ -716,6 +732,34 @@ public AColGroup sortColumnIndexes() { protected abstract AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering); + /** + * Get an approximate sparsity of this column group + * + * @return the approximate sparsity of this columngroup + */ + public abstract double getSparsity(); + + /** + * Sparse selection (left matrix multiply) + * + * @param selection A sparse matrix with "max" a single one in each row all other values are zero. + * @param ret The Sparse MatrixBlock to decompress the selected rows into + * @param rl The row to start at in the selection matrix + * @param ru the row to end at in the selection matrix (not inclusive) + */ + public abstract void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru); + + /** + * Method to determine if the columnGroup have the same index structure as another. Note that the column indexes and + * dictionaries are allowed to be different. + * + * @param that the other column group + * @return if the index is the same. + */ + public boolean sameIndexStructure(AColGroup that) { + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java index 97f0d8058a8..81433a20450 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.compress.colgroup; +import java.util.List; + +import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.DMLScriptException; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -86,6 +89,14 @@ protected AColGroupCompressed(IColIndex colIndices) { protected abstract double[] preAggBuiltinRows(Builtin builtin); + @Override + public boolean sameIndexStructure(AColGroup that) { + if(that instanceof AColGroupCompressed) + return sameIndexStructure((AColGroupCompressed) that); + else + return false; + } + public abstract boolean sameIndexStructure(AColGroupCompressed that); public double[] preAggRows(ValueFunction fn) { @@ -215,7 +226,8 @@ protected static void tsmm(double[] result, int numColumns, int[] counts, IDicti } - protected static void tsmmDense(double[] result, int numColumns, double[] values, int[] counts, IColIndex colIndexes) { + protected static void tsmmDense(double[] result, int numColumns, double[] values, int[] counts, + IColIndex colIndexes) { final int nCol = colIndexes.size(); final int nRow = counts.length; for(int k = 0; k < nRow; k++) { @@ -231,7 +243,8 @@ protected static void tsmmDense(double[] result, int numColumns, double[] values } } - protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts, IColIndex colIndexes) { + protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts, + IColIndex colIndexes) { for(int row = 0; row < counts.length; row++) { if(sb.isEmpty(row)) continue; @@ -253,4 +266,20 @@ protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb public boolean isEmpty() { return false; } + + /** + * C bind the list of column groups with this column group. the list of elements provided in the index of each list + * is guaranteed to have the same index structures + * + * @param index The index to look up in each list of the right argument to find the correct column groups to combine. + * @param nCol The number of columns to shift the right hand side column groups over when combining, this should + * only effect the column indexes + * @param right The right hand side column groups to combine. NOTE only the index offset of the second nested list + * should be used. The reason for providing this nested list is to avoid redundant allocations in + * calling methods. + * @return A combined compressed column group of the same type as this!. + */ + public AColGroupCompressed combineWithSameIndex(int index, int nCol, List> right) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 753bef2619f..4116a0ae82d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -24,7 +24,6 @@ import java.util.HashSet; import java.util.Set; -import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; @@ -58,9 +57,38 @@ public IDictionary getDictionary() { } @Override - public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) { + public final void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { if(_dict instanceof IdentityDictionary) { + final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. + if(mb.isInSparseFormat()) + decompressToDenseBlockTransposedSparseDictionary(db, rl, ru, mb.getSparseBlock()); + else + decompressToDenseBlockTransposedDenseDictionary(db, rl, ru, mb.getDenseBlockValues()); + } + else if(_dict instanceof MatrixBlockDictionary) { + final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. + if(mb.isInSparseFormat()) + decompressToDenseBlockTransposedSparseDictionary(db, rl, ru, mb.getSparseBlock()); + else + decompressToDenseBlockTransposedDenseDictionary(db, rl, ru, mb.getDenseBlockValues()); + } + else + decompressToDenseBlockTransposedDenseDictionary(db, rl, ru, _dict.getValues()); + } + + protected abstract void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, + SparseBlock dict); + + protected abstract void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, + double[] dict); + @Override + public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) { + if(_dict instanceof IdentityDictionary) { final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict(); final MatrixBlock mb = md.getMatrixBlock(); // The dictionary is never empty. @@ -198,7 +226,7 @@ public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols) { final int nVals = getNumValues(); final IDictionary preAgg = (right.isInSparseFormat()) ? // Chose Sparse or Dense - rightMMPreAggSparse(nVals, right.getSparseBlock(), agCols, 0, nCol) : // sparse + rightMMPreAggSparse(nVals, right.getSparseBlock(), agCols, nCol) : // sparse _dict.preaggValuesFromDense(nVals, _colIndexes, agCols, right.getDenseBlockValues(), nCol); // dense return allocateRightMultiplication(right, agCols, preAgg); } @@ -269,30 +297,8 @@ protected IColIndex rightMMGetColsSparse(SparseBlock b, int retCols, IColIndex a return ColIndexFactory.create(aggregateColumns); } - private IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex aggregateColumns, int cl, int cu) { - final double[] ret = new double[numVals * aggregateColumns.size()]; - for(int h = 0; h < _colIndexes.size(); h++) { - final int colIdx = _colIndexes.get(h); - if(b.isEmpty(colIdx)) - continue; - - final double[] sValues = b.values(colIdx); - final int[] sIndexes = b.indexes(colIdx); - int retIdx = 0; - for(int i = b.pos(colIdx); i < b.size(colIdx) + b.pos(colIdx); i++) { - while(aggregateColumns.get(retIdx) < sIndexes[i]) - retIdx++; - // It is known in this case that the sIndex always correspond to the aggregateColumns. - // if(sIndexes[i] == aggregateColumns[retIdx]) - for(int j = 0, offOrg = h; - j < numVals * aggregateColumns.size(); - j += aggregateColumns.size(), offOrg += _colIndexes.size()) { - ret[j + retIdx] += _dict.getValue(offOrg) * sValues[i]; - } - } - - } - return Dictionary.create(ret); + private IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex aggregateColumns, int nColRight) { + return _dict.rightMMPreAggSparse(numVals, b, this._colIndexes, aggregateColumns, nColRight); } @Override @@ -315,4 +321,8 @@ public final AColGroup copyAndSet(IDictionary newDictionary) { protected abstract AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary); + @Override + public double getSparsity() { + return _dict.getSparsity(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java index fc2c3642015..ff8e334c73b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java @@ -68,6 +68,24 @@ protected final void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl decompressToDenseBlockCommonVector(db, rl, ru, offR, offC, cv); } + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + LOG.warn("Should never call decompress on morphing group instead extract common values and combine all commons"); + double[] cv = new double[db.getDim(1)]; + AColGroup b = extractCommon(cv); + b.decompressToDenseBlockTransposed(db, rl, ru); + decompressToDenseBlockTransposedCommonVector(db, rl, ru, cv); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + LOG.warn("Should never call decompress on morphing group instead extract common values and combine all commons"); + double[] cv = new double[db.getDim(1)]; + AColGroup b = extractCommon(cv); + b.decompressToDenseBlockTransposed(db, rl, ru); + decompressToDenseBlockTransposedCommonVector(db, rl, ru, cv); + } + private final void decompressToDenseBlockCommonVector(DenseBlock db, int rl, int ru, int offR, int offC, double[] common) { for(int i = rl, offT = rl + offR; i < ru; i++, offT++) { @@ -78,6 +96,18 @@ private final void decompressToDenseBlockCommonVector(DenseBlock db, int rl, int } } + private final void decompressToDenseBlockTransposedCommonVector(DenseBlock db, int rl, int ru, double[] common) { + for(int j = 0; j < _colIndexes.size(); j++){ + final int rowOut = _colIndexes.get(j); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); + double v = common[j]; + for(int i = rl; i < ru; i++) { + c[off + i] += v; + } + } + } + @Override protected final void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, SparseBlock sb) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 9fe286ddada..d13008239e7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -313,7 +313,7 @@ public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock final MatrixBlock preAggCopy = new MatrixBlock(); preAggCopy.copy(preAgg); final MatrixBlock tmpResCopy = new MatrixBlock(); - tmpResCopy.copy(tmpRes); + tmpResCopy.copyShallow(tmpRes); // Get dictionary matrixBlock final MatrixBlock dict = getDictionary().getMBDict(_colIndexes.size()).getMatrixBlock(); if(dict != null) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java index 633adb3d01b..daa535e383c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java @@ -69,4 +69,14 @@ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { public ICLAScheme getCompressionScheme() { return SDCScheme.create(this); } + + @Override + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + else if (ct == CompressionType.SDCFOR) + return this; // it does not make sense to change to FOR. + else + return super.morph(ct, nRow); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 77fb11e77ea..e99460027d7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -65,9 +65,12 @@ else if(rl == ru - 1) private final void leftMultByMatrixNoPreAggSingleRow(MatrixBlock mb, MatrixBlock result, int r, int cl, int cu, AIterator it) { - final double[] resV = result.getDenseBlockValues(); - final int nCols = result.getNumColumns(); - final int offRet = nCols * r; + if(mb.isEmpty()) // early abort. + return; + + final DenseBlock res = result.getDenseBlock(); + final double[] resV = res.values(r); + final int offRet = res.pos(r); if(mb.isInSparseFormat()) { final SparseBlock sb = mb.getSparseBlock(); leftMultByMatrixNoPreAggSingleRowSparse(sb, resV, offRet, r, cu, it); @@ -102,50 +105,62 @@ private final void leftMultByMatrixNoPreAggSingleRowSparse(final SparseBlock sb, final int alen = sb.size(r) + apos; final int[] aix = sb.indexes(r); final double[] aval = sb.values(r); - int v = it.value(); + final int v = it.value(); while(apos < alen && aix[apos] < v) apos++; // go though sparse block until offset start. - if(cu < last) { - while(v < cu && apos < alen) { - if(aix[apos] == v) { - multiplyScalar(aval[apos++], resV, offRet, it); - v = it.next(); - } - else if(aix[apos] < v) - apos++; - else - v = it.next(); + if(cu < last) + leftMultByMatrixNoPreAggSingleRowSparseInside(v, it, apos, alen, aix, aval, resV, offRet, cu); + else if(aix[alen - 1] < last) + leftMultByMatrixNoPreAggSingleRowSparseLessThan(v, it, apos, alen, aix, aval, resV, offRet); + else + leftMultByMatrixNoPreAggSingleRowSparseTail(v, it, apos, alen, aix, aval, resV, offRet, cu, last); + } + + private final void leftMultByMatrixNoPreAggSingleRowSparseInside(int v, AIterator it, int apos, int alen, int[] aix, + double[] aval, double[] resV, int offRet, int cu) { + while(v < cu && apos < alen) { + if(aix[apos] == v) { + multiplyScalar(aval[apos++], resV, offRet, it); + v = it.next(); } + else if(aix[apos] < v) + apos++; + else + v = it.next(); } - else if(aix[alen - 1] < last) { - while(apos < alen) { - if(aix[apos] == v) { - multiplyScalar(aval[apos++], resV, offRet, it); - v = it.next(); - } - else if(aix[apos] < v) - apos++; - else - v = it.next(); + } + + private final void leftMultByMatrixNoPreAggSingleRowSparseLessThan(int v, AIterator it, int apos, int alen, + int[] aix, double[] aval, double[] resV, int offRet) { + while(apos < alen) { + if(aix[apos] == v) { + multiplyScalar(aval[apos++], resV, offRet, it); + v = it.next(); } + else if(aix[apos] < v) + apos++; + else + v = it.next(); } - else { - while(v < last) { - if(aix[apos] == v) { - multiplyScalar(aval[apos++], resV, offRet, it); - v = it.next(); - } - else if(aix[apos] < v) - apos++; - else - v = it.next(); + } + + private final void leftMultByMatrixNoPreAggSingleRowSparseTail(int v, AIterator it, int apos, int alen, int[] aix, + double[] aval, double[] resV, int offRet, int cu, int last) { + while(v < last) { + if(aix[apos] == v) { + multiplyScalar(aval[apos++], resV, offRet, it); + v = it.next(); } - while(aix[apos] < last && apos < alen) + else if(aix[apos] < v) apos++; - - if(last == aix[apos]) - multiplyScalar(aval[apos], resV, offRet, it); + else + v = it.next(); } + while(aix[apos] < last && apos < alen) + apos++; + + if(last == aix[apos]) + multiplyScalar(aval[apos], resV, offRet, it); } private final void leftMultByMatrixNoPreAggRows(MatrixBlock mb, MatrixBlock result, int rl, int ru, int cl, int cu, @@ -230,11 +245,21 @@ public double[] getDefaultTuple() { @Override public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(), getEncoding()); } - @Override + @Override public ICLAScheme getCompressionScheme() { return SDCScheme.create(this); } + + @Override + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + else if (ct == CompressionType.SDCFOR) + return this; // it does not make sense to change to FOR. + else + return super.morph(ct, nRow); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 0ef7a423503..95802725e9e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -647,4 +647,44 @@ public AMapToData getMapToData() { return MapToFactory.create(0, 0); } + @Override + public double getSparsity(){ + return 1.0; + } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru){ + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + // guaranteed to be containing some values therefore no check for empty. + final int apos = sb.pos(0); + final int alen = sb.size(0); + final int[] aix = sb.indexes(0); + final double[] avals = sb.values(0); + + for(int j = apos; j < alen; j++){ + final int rowOut = _colIndexes.get(aix[j]); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); // row offset out. + final double v = avals[j]; + for(int i = rl; i < ru; i++) + c[off + i] += v; + } + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + for(int j = 0; j < _colIndexes.size(); j++){ + final int rowOut = _colIndexes.get(j); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); + double v = dict[j]; + for(int i = rl; i < ru; i++) { + c[off + i] += v; + } + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index 6340affede3..929a7fc509b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -22,9 +22,11 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; -import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -39,6 +41,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.colgroup.scheme.DDCScheme; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -51,7 +54,9 @@ import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; @@ -110,35 +115,31 @@ protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int continue; final double[] c = db.values(offT); final int off = db.pos(offT) + offC; - final int apos = sb.pos(vr); - final int alen = sb.size(vr) + apos; - final int[] aix = sb.indexes(vr); - final double[] aval = sb.values(vr); - for(int j = apos; j < alen; j++) - c[off + _colIndexes.get(aix[j])] += aval[j]; + _colIndexes.decompressToDenseFromSparse(sb, vr, off, c); } } @Override protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) { + final int idxSize = _colIndexes.size(); if(db.isContiguous()) { - final int nCol = db.getDim(1); - if(_colIndexes.size() == 1 && nCol == 1) + final int nColOut = db.getDim(1); + if(idxSize == 1 && nColOut == 1) decompressToDenseBlockDenseDictSingleColOutContiguous(db, rl, ru, offR, offC, values); - else if(_colIndexes.size() == 1) + else if(idxSize == 1) decompressToDenseBlockDenseDictSingleColContiguous(db, rl, ru, offR, offC, values); - else if(_colIndexes.size() == nCol) // offC == 0 implied - decompressToDenseBlockDenseDictAllColumnsContiguous(db, rl, ru, offR, values); + else if(idxSize == nColOut) // offC == 0 implied + decompressToDenseBlockDenseDictAllColumnsContiguous(db, rl, ru, offR, values, idxSize); else if(offC == 0 && offR == 0) decompressToDenseBlockDenseDictNoOff(db, rl, ru, values); else if(offC == 0) - decompressToDenseBlockDenseDictNoColOffset(db, rl, ru, offR, values); + decompressToDenseBlockDenseDictNoColOffset(db, rl, ru, offR, values, idxSize, nColOut); else - decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values); + decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values, idxSize); } else - decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values); + decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values, idxSize); } private final void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock db, int rl, int ru, int offR, @@ -171,7 +172,6 @@ else if(data instanceof MapToChar) decompressToDenseBlockDenseDictSingleColOutContiguousCharM(c, rl, ru, offR, values, (MapToChar) data); else decompressToDenseBlockDenseDictSingleColOutContiguousGenM(c, rl, ru, offR, values, data); - } private final static void decompressToDenseBlockDenseDictSingleColOutContiguousByteM(double[] c, int rl, int ru, @@ -193,28 +193,22 @@ private final static void decompressToDenseBlockDenseDictSingleColOutContiguousG } private final void decompressToDenseBlockDenseDictAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR, - double[] values) { + double[] values, int nCol) { final double[] c = db.values(0); - final int nCol = _colIndexes.size(); for(int r = rl; r < ru; r++) { final int start = _data.getIndex(r) * nCol; - final int end = start + nCol; final int offStart = (offR + r) * nCol; - for(int vOff = start, off = offStart; vOff < end; vOff++, off++) - c[off] += values[vOff]; + LibMatrixMult.vectAdd(values, c, start, offStart, nCol); } } private final void decompressToDenseBlockDenseDictNoColOffset(DenseBlock db, int rl, int ru, int offR, - double[] values) { - final int nCol = _colIndexes.size(); - final int colOut = db.getDim(1); + double[] values, int nCol, int colOut) { int off = (rl + offR) * colOut; for(int i = rl, offT = rl + offR; i < ru; i++, off += colOut) { final double[] c = db.values(offT); final int rowIndex = _data.getIndex(i) * nCol; - for(int j = 0; j < nCol; j++) - c[off + _colIndexes.get(j)] += values[rowIndex + j]; + _colIndexes.decompressVec(nCol, c, off, values, rowIndex); } } @@ -225,20 +219,17 @@ private final void decompressToDenseBlockDenseDictNoOff(DenseBlock db, int rl, i for(int i = rl; i < ru; i++) { final int off = i * nColU; final int rowIndex = _data.getIndex(i) * nCol; - for(int j = 0; j < nCol; j++) - c[off + _colIndexes.get(j)] += values[rowIndex + j]; + _colIndexes.decompressVec(nCol, c, off, values, rowIndex); } } private final void decompressToDenseBlockDenseDictGeneric(DenseBlock db, int rl, int ru, int offR, int offC, - double[] values) { - final int nCol = _colIndexes.size(); + double[] values, int nCol) { for(int i = rl, offT = rl + offR; i < ru; i++, offT++) { final double[] c = db.values(offT); final int off = db.pos(offT) + offC; final int rowIndex = _data.getIndex(i) * nCol; - for(int j = 0; j < nCol; j++) - c[off + _colIndexes.get(j)] += values[rowIndex + j]; + _colIndexes.decompressVec(nCol, c, off, values, rowIndex); } } @@ -261,7 +252,11 @@ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, @Override protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, double[] values) { - final int nCol = _colIndexes.size(); + decompressToSparseBlockDenseDictionary(ret, rl, ru, offR, offC, values, _colIndexes.size()); + } + + protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, + double[] values, int nCol) { for(int i = rl, offT = rl + offR; i < ru; i++, offT++) { final int rowIndex = _data.getIndex(i) * nCol; for(int j = 0; j < nCol; j++) @@ -269,6 +264,26 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i } } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + final int nCol = _colIndexes.size(); + for(int j = 0; j < nCol; j++){ + final int rowOut = _colIndexes.get(j); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); + for(int i = rl; i < ru; i++) { + final double v = dict[_data.getIndex(i) * nCol + j]; + c[off + i] += v; + } + } + } + @Override public double getIdx(int r, int colIdx) { return _dict.getValue(_data.getIndex(r), colIdx, _colIndexes.size()); @@ -307,22 +322,34 @@ public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int private void leftMultByMatrixNoPreAggSingleCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { - final double[] retV = result.getDenseBlockValues(); + final DenseBlock retV = result.getDenseBlock(); final int nColM = matrix.getNumColumns(); final int nColRet = result.getNumColumns(); final double[] dictVals = _dict.getValues(); // guaranteed dense double since we only have one column. - if(matrix.isInSparseFormat()) { + if(matrix.isEmpty()) + return; + else if(matrix.isInSparseFormat()) { if(cl != 0 || cu != _data.size()) - throw new NotImplementedException(); - lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru); + lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu); + else + lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru); } else lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu); } - private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] retV, int nColRet, double[] vals, + private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, DenseBlock retV, int nColRet, double[] vals, int rl, int ru) { + + if(retV.isContiguous()) + lmSparseMatrixNoPreAggSingleColContiguous(sb, nColM, retV.valuesAt(0), nColRet, vals, rl, ru); + else + lmSparseMatrixNoPreAggSingleColGeneric(sb, nColM, retV, nColRet, vals, rl, ru); + } + + private void lmSparseMatrixNoPreAggSingleColGeneric(SparseBlock sb, int nColM, DenseBlock ret, int nColRet, + double[] vals, int rl, int ru) { final int colOut = _colIndexes.get(0); for(int r = rl; r < ru; r++) { @@ -332,52 +359,164 @@ private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] final int alen = sb.size(r) + apos; final int[] aix = sb.indexes(r); final double[] aval = sb.values(r); - final int offR = r * nColRet; + final int offR = ret.pos(r); + final double[] retV = ret.values(r); + for(int i = apos; i < alen; i++) retV[offR + colOut] += aval[i] * vals[_data.getIndex(aix[i])]; } } - private void lmDenseMatrixNoPreAggSingleCol(double[] mV, int nColM, double[] retV, int nColRet, double[] vals, - int rl, int ru, int cl, int cu) { + private void lmSparseMatrixNoPreAggSingleColContiguous(SparseBlock sb, int nColM, double[] retV, int nColRet, + double[] vals, int rl, int ru) { final int colOut = _colIndexes.get(0); + for(int r = rl; r < ru; r++) { - final int offL = r * nColM; + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final double[] aval = sb.values(r); final int offR = r * nColRet; - for(int c = cl; c < cu; c++) - retV[offR + colOut] += mV[offL + c] * vals[_data.getIndex(c)]; + for(int i = apos; i < alen; i++) + retV[offR + colOut] += aval[i] * vals[_data.getIndex(aix[i])]; } } - private void lmMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { - if(matrix.isInSparseFormat()) { - if(cl != 0 || cu != _data.size()) - throw new NotImplementedException( - "Not implemented left multiplication on sparse without it being entire input"); - lmSparseMatrixNoPreAggMultiCol(matrix, result, rl, ru); - } + private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, DenseBlock retV, int nColRet, double[] vals, + int rl, int ru, int cl, int cu) { + if(retV.isContiguous()) + lmSparseMatrixNoPreAggSingleColContiguous(sb, nColM, retV.valuesAt(0), nColRet, vals, rl, ru, cl, cu); else - lmDenseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu); + lmSparseMatrixNoPreAggSingleColGeneric(sb, nColM, retV, nColRet, vals, rl, ru, cl, cu); } - private void lmSparseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru) { - final double[] retV = result.getDenseBlockValues(); - final int nColRet = result.getNumColumns(); - final SparseBlock sb = matrix.getSparseBlock(); + private void lmSparseMatrixNoPreAggSingleColGeneric(SparseBlock sb, int nColM, DenseBlock ret, int nColRet, + double[] vals, int rl, int ru, int cl, int cu) { + final int colOut = _colIndexes.get(0); for(int r = rl; r < ru; r++) { if(sb.isEmpty(r)) continue; final int apos = sb.pos(r); + final int aposSkip = sb.posFIndexGTE(r, cl); + final int[] aix = sb.indexes(r); + if(aposSkip <= -1 || aix[apos + aposSkip] >= cu) + continue; final int alen = sb.size(r) + apos; + final double[] aval = sb.values(r); + final int offR = ret.pos(r); + final double[] retV = ret.values(r); + // final int offR = r * nColRet; + for(int i = apos + aposSkip; i < alen && aix[i] < cu; i++) + retV[offR + colOut] += aval[i] * vals[_data.getIndex(aix[i])]; + } + } + + private void lmSparseMatrixNoPreAggSingleColContiguous(SparseBlock sb, int nColM, double[] retV, int nColRet, + double[] vals, int rl, int ru, int cl, int cu) { + final int colOut = _colIndexes.get(0); + + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int aposSkip = sb.posFIndexGTE(r, cl); final int[] aix = sb.indexes(r); + if(aposSkip <= -1 || aix[apos + aposSkip] >= cu) + continue; + final int alen = sb.size(r) + apos; final double[] aval = sb.values(r); final int offR = r * nColRet; - for(int i = apos; i < alen; i++) - _dict.multiplyScalar(aval[i], retV, offR, _data.getIndex(aix[i]), _colIndexes); + for(int i = apos + aposSkip; i < alen && aix[i] < cu; i++) + retV[offR + colOut] += aval[i] * vals[_data.getIndex(aix[i])]; } } + private void lmDenseMatrixNoPreAggSingleCol(double[] mV, int nColM, DenseBlock retV, int nColRet, double[] vals, + int rl, int ru, int cl, int cu) { + if(retV.isContiguous()) + lmDenseMatrixNoPreAggSingleColContiguous(mV, nColM, retV.valuesAt(0), nColRet, vals, rl, ru, cl, cu); + else + lmDenseMatrixNoPreAggSingleColGeneric(mV, nColM, retV, nColRet, vals, rl, ru, cl, cu); + } + + private void lmDenseMatrixNoPreAggSingleColGeneric(double[] mV, int nColM, DenseBlock ret, int nColRet, + double[] vals, int rl, int ru, int cl, int cu) { + final int colOut = _colIndexes.get(0); + for(int r = rl; r < ru; r++) { + final int offL = r * nColM; + final int offR = ret.pos(r); + final double[] retV = ret.values(r); + for(int c = cl; c < cu; c++) + retV[offR + colOut] += mV[offL + c] * vals[_data.getIndex(c)]; + } + } + + private void lmDenseMatrixNoPreAggSingleColContiguous(double[] mV, int nColM, double[] retV, int nColRet, + double[] vals, int rl, int ru, int cl, int cu) { + final int colOut = _colIndexes.get(0); + for(int r = rl; r < ru; r++) { + final int offL = r * nColM; + final int offR = r * nColRet; + for(int c = cl; c < cu; c++) + retV[offR + colOut] += mV[offL + c] * vals[_data.getIndex(c)]; + } + } + + private void lmMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { + if(matrix.isInSparseFormat()) + lmSparseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu); + else + lmDenseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu); + } + + private void lmSparseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { + final DenseBlock db = result.getDenseBlock(); + final SparseBlock sb = matrix.getSparseBlock(); + + if(cl != 0 || cu != _data.size()) { + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + final double[] retV = db.values(r); + final int pos = db.pos(r); + lmSparseMatrixRowColRange(sb, r, pos, retV, cl, cu); + } + } + else { + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + final double[] retV = db.values(r); + final int pos = db.pos(r); + lmSparseMatrixRow(sb, r, pos, retV); + } + } + } + + private final void lmSparseMatrixRowColRange(SparseBlock sb, int r, int offR, double[] retV, int cl, int cu) { + final int apos = sb.pos(r); + final int aposSkip = sb.posFIndexGTE(r, cl); + final int[] aix = sb.indexes(r); + if(aposSkip <= -1 || aix[apos + aposSkip] >= cu) + return; + final int alen = sb.size(r) + apos; + final double[] aval = sb.values(r); + for(int i = apos + aposSkip; i < alen && aix[i] < cu; i++) + _dict.multiplyScalar(aval[i], retV, offR, _data.getIndex(aix[i]), _colIndexes); + } + + private final void lmSparseMatrixRow(SparseBlock sb, int r, int offR, double[] retV) { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final double[] aval = sb.values(r); + + _data.lmSparseMatrixRow(apos, alen, aix, aval, r, offR, retV, _colIndexes, _dict); + } + private void lmDenseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { final double[] retV = result.getDenseBlockValues(); final int nColM = matrix.getNumColumns(); @@ -608,9 +747,16 @@ public AColGroup recompress() { @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { - IEncode enc = getEncoding(); - EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc); + try { + + IEncode enc = getEncoding(); + EstimationFactors ef = new EstimationFactors(_data.getUnique(), _data.size(), _data.size(), + _dict.getSparsity()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc); + } + catch(Exception e) { + throw new DMLCompressionException(this.toString(), e); + } } @Override @@ -623,6 +769,90 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { return ColGroupDDC.create(newColIndex, _dict.reorder(reordering), _data, getCachedCounts()); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + + @Override + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + else if(ct == CompressionType.SDC) { + int[] counts = getCounts(); + int maxId = maxIndex(counts); + double[] def = _dict.getRow(maxId, _colIndexes.size()); + + int offsetSize = nRow - counts[maxId]; + int[] offsets = new int[offsetSize]; + AMapToData reducedData = MapToFactory.create(offsetSize, _data.getUnique()); + int o = 0; + for(int i = 0; i < nRow; i++) { + int v = _data.getIndex(i); + if(v != maxId) { + offsets[o] = i; + reducedData.set(o, v); + o++; + } + } + + return ColGroupSDC.create(_colIndexes, _data.size(), _dict, def, OffsetFactory.createOffset(offsets), + reducedData, null); + } + else if(ct == CompressionType.CONST) { + // if(1 < getNumValues()) { + String thisS = this.toString(); + if(thisS.length() > 10000) + thisS = thisS.substring(0, 10000) + "..."; + LOG.warn("Tried to morph to const from DDC but impossible: " + thisS); + return this; + // } + } + else if (ct == CompressionType.DDCFOR) + return this; // it does not make sense to change to FOR. + else + return super.morph(ct, nRow); + } + + private static int maxIndex(int[] counts) { + int id = 0; + for(int i = 1; i < counts.length; i++) { + if(counts[i] > counts[id]) { + id = i; + } + } + return id; + } + + @Override + public AColGroupCompressed combineWithSameIndex(int index, int nCol, List> right) { + List> dicts = new ArrayList<>(right.size() +1); + dicts.add(new Pair<>(_colIndexes.size(), getDictionary())); + for(int i = 0; i < right.size(); i++){ + ColGroupDDC a = ((ColGroupDDC)right.get(i).get(index)); + dicts.add(new Pair<>(a._colIndexes.size(),a.getDictionary())); + } + IDictionary combined = DictionaryFactory.cBindDictionaries(dicts); + + + IColIndex combinedColIndex = _colIndexes; + for(int i = 0; i < right.size(); i++){ + int off = nCol * i + nCol; + combinedColIndex = combinedColIndex.combine(right.get(i).get(index).getColIndices().shift(off)); + } + + return new ColGroupDDC(combinedColIndex, combined, _data, getCachedCounts()); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index d09ba4e624d..964ba4e1c67 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -252,7 +253,7 @@ public AColGroup replace(double pattern, double replace) { if(patternInReference) { double[] nRef = new double[_reference.length]; for(int i = 0; i < _reference.length; i++) - if(Util.eq(pattern ,_reference[i])) + if(Util.eq(pattern, _reference[i])) nRef[i] = replace; else nRef[i] = _reference[i]; @@ -489,6 +490,20 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index a8d8e6840e3..75d4e04cc6e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; @@ -53,7 +54,7 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; public class ColGroupEmpty extends AColGroupCompressed - implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup ,IMapToDataGroup{ + implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup { private static final long serialVersionUID = -2307677253622099958L; /** @@ -89,6 +90,11 @@ public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, in // do nothing. } + @Override + public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { + // do nothing. + } + @Override public double getIdx(int r, int colIdx) { return 0; @@ -403,4 +409,13 @@ public AMapToData getMapToData() { return MapToFactory.create(0, 0); } + @Override + public double getSparsity() { + return 0.0; + } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index 23ba7d6fc4c..4b04568da2d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -258,7 +258,9 @@ private AColGroup compress(CompressedSizeInfoColGroup cg) throws Exception { final boolean t = cs.transposed; // Fast path compressions - if(ct == CompressionType.EMPTY && !t) + if((ct == CompressionType.EMPTY && !t) || // + (t && colIndexes.size() == 1 && in.isInSparseFormat() // Empty Column + && in.getSparseBlock().isEmpty(colIndexes.get(0)))) return new ColGroupEmpty(colIndexes); else if(ct == CompressionType.UNCOMPRESSED) // don't construct mapping if uncompressed return ColGroupUncompressed.create(colIndexes, in, t); @@ -470,9 +472,13 @@ private AColGroup directCompressDDCSingleCol(IColIndex colIndexes, CompressedSiz if(map.size() == 0) return new ColGroupEmpty(colIndexes); IDictionary dict = DictionaryFactory.create(map); + final int nUnique = map.size(); final AMapToData resData = MapToFactory.resize(d, nUnique); - return ColGroupDDC.create(colIndexes, dict, resData, null); + AColGroup g = ColGroupDDC.create(colIndexes, dict, resData, null); + if(g instanceof ColGroupConst) + throw new DMLCompressionException("Invalid ddc should not have been const" + g + "\n\n" + map + " " + dict); + return g; } private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception { @@ -569,15 +575,14 @@ private void readToMapDDCTransposed(int col, DoubleCountHashMap map, AMapToData if(in.isInSparseFormat()) { final SparseBlock sb = in.getSparseBlock(); if(sb.isEmpty(col)) - // It should never be empty here. - return; + throw new DMLCompressionException("Empty column in DDC compression"); final int apos = sb.pos(col); final int alen = sb.size(col) + apos; final int[] aix = sb.indexes(col); final double[] aval = sb.values(col); // count zeros - if(nRow - apos - alen > 0) + if(nRow > alen - apos) map.increment(0.0, nRow - apos - alen); // insert all other counts for(int j = apos; j < alen; j++) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 708d3512f53..858269cd7b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -703,4 +703,18 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } + @Override + public double getSparsity() { + return 1.0; + } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index 8af0f959e0c..090449b5e49 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -689,4 +689,19 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + throw new NotImplementedException(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 23596c1e190..0072098dbe2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -756,10 +756,6 @@ private String pair(char[] d, int off, int sum) { public void preAggregateDense(MatrixBlock m, double[] preAgg, final int rl, final int ru, final int cl, final int cu) { final DenseBlock db = m.getDenseBlock(); - if(!db.isContiguous()) - throw new NotImplementedException("Not implemented support for preAggregate non contiguous dense matrix"); - final double[] mV = m.getDenseBlockValues(); - final int nCol = m.getNumColumns(); final int nv = getNumValues(); for(int k = 0; k < nv; k++) { // for each run in RLE @@ -774,8 +770,9 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, final int rl, fina if(re >= cu) { for(int r = rl; r < ru; r++) { + final double[] mV = db.values(r); + final int offI = db.pos(r); final int off = (r - rl) * nv + k; - final int offI = nCol * r; for(int rix = rsc + offI; rix < cu + offI; rix++) { preAgg[off] += mV[rix]; } @@ -784,8 +781,9 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, final int rl, fina } else { for(int r = rl; r < ru; r++) { + final double[] mV = db.values(r); + final int offI = db.pos(r); final int off = (r - rl) * nv + k; - final int offI = nCol * r; for(int rix = rsc + offI; rix < re + offI; rix++) preAgg[off] += mV[rix]; } @@ -1146,4 +1144,20 @@ public static char[] genRLEBitmap(int[] offsets, int len) { return ret; } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + throw new NotImplementedException(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index a905e401e42..8c696dc64fa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -24,8 +24,11 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -43,6 +46,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -71,15 +75,24 @@ protected ColGroupSDC(IColIndex colIndices, int numRows, IDictionary dict, doubl super(colIndices, numRows, dict, offsets, cachedCounts); _data = data; _defaultTuple = defaultTuple; - if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) { - if(data.getUnique() != data.getMax()) - throw new DMLCompressionException( - "Invalid unique count compared to actual: " + data.getUnique() + " " + data.getMax()); - throw new DMLCompressionException("Invalid construction of SDC group: number uniques: " + data.getUnique() - + " vs." + dict.getNumberOfValues(colIndices.size())); + if(CompressedMatrixBlock.debug) { + + if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) { + if(data.getUnique() != data.getMax()) + throw new DMLCompressionException( + "Invalid unique count compared to actual: " + data.getUnique() + " " + data.getMax()); + throw new DMLCompressionException("Invalid construction of SDC group: number uniques: " + data.getUnique() + + " vs." + dict.getNumberOfValues(colIndices.size())); + } + if(_indexes.getSize() == numRows) { + throw new DMLCompressionException("Invalid SDC group that contains index with size == numRows"); + } + if(defaultTuple.length != colIndices.size()) + throw new DMLCompressionException("Invalid construction of SDC group"); + + _data.verify(); + _indexes.verify(_data.size()); } - if(defaultTuple.length != colIndices.size()) - throw new DMLCompressionException("Invalid construction of SDC group"); } @@ -459,7 +472,7 @@ public AColGroup replace(double pattern, double replace) { IDictionary replaced = _dict.replace(pattern, replace, _colIndexes.size()); double[] newDefaultTuple = new double[_defaultTuple.length]; for(int i = 0; i < _defaultTuple.length; i++) - newDefaultTuple[i] = Util.eq(_defaultTuple[i],pattern) ? replace : _defaultTuple[i]; + newDefaultTuple[i] = Util.eq(_defaultTuple[i], pattern) ? replace : _defaultTuple[i]; return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, _data, getCachedCounts()); } @@ -662,6 +675,86 @@ public int getNumberOffsets() { return _data.size(); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock sr = ret.getSparseBlock(); + final int nCol = _colIndexes.size(); + final AIterator it = _indexes.getIterator(rl); + if(it == null) + throw new NotImplementedException("Not Implemented fill with default"); + + P[] points = ColGroupUtils.getSortedSelection(sb, rl, ru); + + final int last = Math.min(_indexes.getOffsetToLast(), ru); + int c = 0; + + while(it.value() < last && c < points.length) { + while(it.value() < last && it.value() < points[c].o) { + it.next(); + } + if(it.value() == last) { + break; + } + final int of = it.value(); + if(points[c].o < of) { + for(int i = 0; i < nCol; i++) + sr.add(points[c].r, _colIndexes.get(i), _defaultTuple[i]); + } + else { + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + it.next(); + } + c++; + } + if(it.value() == ru) { + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + c++; + } + + // set default in tail. + for(; c < points.length; c++) { + for(int i = 0; i < nCol; i++) + sr.add(points[c].r, _colIndexes.get(i), _defaultTuple[i]); + } + + } + + @Override + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + else if(ct == CompressionType.DDC) { + + AMapToData nMap = MapToFactory.create(nRow, _data.getUnique() + 1); + IDictionary nDict = _dict.append(_defaultTuple); + + final AIterator it = _indexes.getIterator(); + final int last = _indexes.getOffsetToLast(); + int r = 0; + int def = _data.getUnique(); + while(it.value() < last) { + final int iv = it.value(); + while(r < iv) { + nMap.set(r++, def); + } + nMap.set(r++, _data.getIndex(it.getDataIndex())); + it.next(); + } + nMap.set(r++, _data.getIndex(it.getDataIndex())); + while(r < nRow) { + nMap.set(r++, def); + } + + return ColGroupDDC.create(_colIndexes, nDict, nMap, null); + } + + + else { + return super.morph(ct, nRow); + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index dfb9a605118..f7205bc8bda 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -25,10 +25,11 @@ import java.util.Arrays; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -48,6 +49,7 @@ import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -77,12 +79,19 @@ private ColGroupSDCFOR(IColIndex colIndices, int numRows, IDictionary dict, AOff int[] cachedCounts, double[] reference) { super(colIndices, numRows, dict, indexes, cachedCounts); // allow for now 1 data unique. - if(data.getUnique() == 1) - LOG.warn("SDCFor unique is 1, indicate it should have been SDCSingle please add support"); - else if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) - throw new DMLCompressionException("Invalid construction of SDCZero group"); _data = data; _reference = reference; + if(CompressedMatrixBlock.debug) { + + if(data.getUnique() == 1) + LOG.warn("SDCFor unique is 1, indicate it should have been SDCSingle please add support"); + else if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) + throw new DMLCompressionException("Invalid construction of SDCZero group"); + + _data.verify(); + _indexes.verify(_data.size()); + } + } public static AColGroup create(IColIndex colIndexes, int numRows, IDictionary dict, AOffset offsets, AMapToData data, @@ -520,6 +529,11 @@ public ICLAScheme getCompressionScheme() { throw new NotImplementedException(); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index a13150c12c6..b0532825d97 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -24,9 +24,11 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; @@ -44,6 +46,7 @@ import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -68,6 +71,9 @@ private ColGroupSDCSingle(IColIndex colIndices, int numRows, IDictionary dict, d super(colIndices, numRows, dict == null ? Dictionary.createNoCheck(new double[colIndices.size()]) : dict, offsets, cachedCounts); _defaultTuple = defaultTuple; + if(CompressedMatrixBlock.debug) { + _indexes.verify(_indexes.getSize()); + } } public static AColGroup create(IColIndex colIndexes, int numRows, IDictionary dict, double[] defaultTuple, @@ -620,6 +626,11 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { ColGroupUtils.reorderDefault(_defaultTuple, reordering), _indexes, getCachedCounts()); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 7a3309aafa0..c0cfc450eef 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -65,10 +65,12 @@ public class ColGroupSDCSingleZeros extends ASDCZero { private ColGroupSDCSingleZeros(IColIndex colIndices, int numRows, IDictionary dict, AOffset offsets, int[] cachedCounts) { super(colIndices, numRows, dict, offsets, cachedCounts); - if(CompressedMatrixBlock.debug) + if(CompressedMatrixBlock.debug) { if(offsets.getSize() * 2 > numRows + 2 && !(dict instanceof PlaceHolderDict)) throw new DMLCompressionException("Wrong direction of SDCSingleZero compression should be other way " + numRows + " vs " + _indexes + "\n" + this); + _indexes.verify(_indexes.getSize()); + } } public static AColGroup create(IColIndex colIndices, int numRows, IDictionary dict, AOffset offsets, @@ -357,8 +359,62 @@ protected void multiplyScalar(double v, double[] resV, int offRet, AIterator it) @Override public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) { - if(!m.getDenseBlock().isContiguous()) - throw new NotImplementedException("Not implemented support for preAggregate non contiguous dense matrix"); + if(m.getDenseBlock().isContiguous()) + preAggregateDenseContiguous(m, preAgg, rl, ru, cl, cu); + else + preAggregateDenseGeneric(m, preAgg, rl, ru, cl, cu); + } + + private void preAggregateDenseGeneric(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) { + final AIterator it = _indexes.getIterator(cl); + final DenseBlock db = m.getDenseBlock(); + final int nCol = m.getNumColumns(); + if(it == null) + return; + else if(it.value() > cu) + _indexes.cacheIterator(it, cu); + else if(cu < _indexes.getOffsetToLast() + 1) { + if(db.isContiguous(rl, ru)) { + while(it.value() < cu) { + final double[] vals = db.values(rl); + final int start = it.value() + db.pos(rl); + final int end = it.value() + db.pos(ru); + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) + preAgg[offOut] += vals[off]; + it.next(); + } + } + else { + throw new NotImplementedException(); + } + _indexes.cacheIterator(it, cu); + } + else { + if(db.isContiguous(rl, ru)) { + final double[] vals = db.values(rl); + final int rlPos = db.pos(rl); + final int ruPos = db.pos(ru); + int of = it.value(); + int start = of + rlPos; + int end = of + ruPos; + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) + preAgg[offOut] += vals[off]; + while(of < _indexes.getOffsetToLast()) { + it.next(); + of = it.value(); + start = of + rlPos; + end = of + ruPos; + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) + preAgg[offOut] += vals[off]; + } + } + else { + throw new NotImplementedException(); + } + } + } + + private void preAggregateDenseContiguous(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) { final AIterator it = _indexes.getIterator(cl); final double[] vals = m.getDenseBlockValues(); final int nCol = m.getNumColumns(); @@ -826,8 +882,8 @@ public AColGroup sliceRows(int rl, int ru) { OffsetSliceInfo off = _indexes.slice(rl, ru); if(off.lIndex == -1) return null; - if(CompressedMatrixBlock.debug){ - if(off.offsetSlice.getOffsetToFirst() < 0 || off.offsetSlice.getOffsetToLast() > ru-rl) + if(CompressedMatrixBlock.debug) { + if(off.offsetSlice.getOffsetToFirst() < 0 || off.offsetSlice.getOffsetToLast() > ru - rl) throw new DMLCompressionException("Failed to slice : " + rl + " " + ru + " in: " + this); } return create(_colIndexes, ru - rl, _dict, off.offsetSlice, null); @@ -853,12 +909,12 @@ public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { return null; } - if(!(gs instanceof AOffsetsGroup )) { + if(!(gs instanceof AOffsetsGroup)) { LOG.warn("Not SDCFOR but " + gs.getClass().getSimpleName()); return null; } - if( gs instanceof ColGroupSDCSingleZeros){ + if(gs instanceof ColGroupSDCSingleZeros) { final ColGroupSDCSingleZeros gc = (ColGroupSDCSingleZeros) gs; if(!gc._dict.equals(_dict)) { LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); @@ -885,6 +941,23 @@ public int getNumberOffsets() { return getCounts()[0]; } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + throw new NotImplementedException(); + } + + + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index f4c7c6f615b..21b16023067 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -24,11 +24,14 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -70,10 +73,14 @@ public class ColGroupSDCZeros extends ASDCZero implements IMapToDataGroup { private ColGroupSDCZeros(IColIndex colIndices, int numRows, IDictionary dict, AOffset indexes, AMapToData data, int[] cachedCounts) { super(colIndices, numRows, dict, indexes, cachedCounts); - if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) - throw new DMLCompressionException("Invalid construction of SDCZero group: number uniques: " + data.getUnique() - + " vs." + dict.getNumberOfValues(colIndices.size())); _data = data; + if(CompressedMatrixBlock.debug) { + if(data.getUnique() != dict.getNumberOfValues(colIndices.size())) + throw new DMLCompressionException("Invalid construction of SDCZero group: number uniques: " + + data.getUnique() + " vs." + dict.getNumberOfValues(colIndices.size())); + _data.verify(); + _indexes.verify(_data.size()); + } } public static AColGroup create(IColIndex colIndices, int numRows, IDictionary dict, AOffset offsets, AMapToData data, @@ -227,21 +234,25 @@ private void decompressToDenseBlockDenseDictionaryPreSingleColContiguous(DenseBl it.setOff(it.value() - offR); } - private void decompressToDenseBlockDenseDictionaryPreGeneric(DenseBlock db, int rl, int ru, int offR, int offC, + private final void decompressToDenseBlockDenseDictionaryPreGeneric(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it) { final int nCol = _colIndexes.size(); while(it.isNotOver(ru)) { - final int idx = offR + it.value(); - final double[] c = db.values(idx); - final int off = db.pos(idx) + offC; - final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + _colIndexes.get(j)] += values[offDict + j]; - + decompressRowDenseDictionaryPreGeneric(db, nCol, offR, offC, values, it); it.next(); } } + private final void decompressRowDenseDictionaryPreGeneric(DenseBlock db, int nCol, int offR, int offC, + double[] values, AIterator it) { + final int idx = offR + it.value(); + final double[] c = db.values(idx); + final int off = db.pos(idx) + offC; + final int offDict = _data.getIndex(it.getDataIndex()) * nCol; + for(int j = 0; j < nCol; j++) + c[off + _colIndexes.get(j)] += values[offDict + j]; + } + private void decompressToDenseBlockDenseDictionaryPreAllCols(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it) { final int nCol = _colIndexes.size(); @@ -767,7 +778,7 @@ public AColGroup append(AColGroup g) { @Override public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { - + for(int i = 1; i < g.length; i++) { final AColGroup gs = g[i]; if(!_colIndexes.equals(gs._colIndexes)) { @@ -775,12 +786,12 @@ public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { return null; } - if(!(gs instanceof AOffsetsGroup )) { + if(!(gs instanceof AOffsetsGroup)) { LOG.warn("Not valid OffsetGroup but " + gs.getClass().getSimpleName()); return null; } - if( gs instanceof ColGroupSDCZeros){ + if(gs instanceof ColGroupSDCZeros) { final ColGroupSDCZeros gc = (ColGroupSDCZeros) gs; if(!gc._dict.equals(_dict)) { LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); @@ -815,6 +826,59 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { getCachedCounts()); } + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock sr = ret.getSparseBlock(); + final int nCol = _colIndexes.size(); + final AIterator it = _indexes.getIterator(rl); + if(it == null) + throw new NotImplementedException("Not Implemented fill with default"); + + P[] points = ColGroupUtils.getSortedSelection(sb, rl, ru); + + _data.verify(); + + // LOG.error(this); + final int last = Math.min(_indexes.getOffsetToLast(), ru); + int c = 0; + while(it.value() < last && c < points.length) { + while(it.value() < last && it.value() < points[c].o) { + it.next(); + } + if(it.value() >= last) { + break; + } + final int of = it.value(); + if(points[c].o == of) { + try { + + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + it.next(); + } + catch(Exception e) { + throw new DMLCompressionException(it + " " + points[c] + " fail", e); + } + } + c++; + } + if(it.value() == ru) { + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + c++; + } + + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { + throw new NotImplementedException(); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + throw new NotImplementedException(); + } + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index d5553deb41f..9d9e7ffa830 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -22,13 +22,12 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.compress.CompressedMatrixBlock; -import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -40,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.colgroup.scheme.SchemeFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -208,6 +208,27 @@ private void decompressToDenseBlockDenseData(DenseBlock db, int rl, int ru, int } private void decompressToDenseBlockDenseDataAllColumns(DenseBlock db, int rl, int ru, int offR) { + if(db.isContiguous() && _data.getDenseBlock().isContiguous()) + decompressToDenseBlockDenseDataAllColumnsContiguous(db, rl, ru, offR); + else + decompressToDenseBlockDenseDataAllColumnsGeneric(db, rl, ru, offR); + } + + private void decompressToDenseBlockDenseDataAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR) { + + final int nCol = _data.getNumColumns(); + final double[] a = _data.getDenseBlockValues(); + final double[] c = db.values(0); + final int as = rl * nCol; + final int cs = (rl + offR) * nCol; + final int sz = ru * nCol - rl * nCol; + for(int i = 0; i < sz; i += 64) { + LibMatrixMult.vectAdd(a, c, as + i, cs + i, Math.min(64, sz - i)); + } + + } + + private void decompressToDenseBlockDenseDataAllColumnsGeneric(DenseBlock db, int rl, int ru, int offR) { int offT = rl + offR; final int nCol = _colIndexes.size(); DenseBlock tb = _data.getDenseBlock(); @@ -532,7 +553,7 @@ public final void tsmm(MatrixBlock ret, int nRows) { // tsmm but only upper triangle. LibMatrixMult.matrixMultTransposeSelf(_data, tmp, true, false); - if(tmp.isInSparseFormat()){ + if(tmp.isInSparseFormat()) { final int numColumns = ret.getNumColumns(); final double[] result = ret.getDenseBlockValues(); final SparseBlock sb = tmp.getSparseBlock(); @@ -546,10 +567,10 @@ public final void tsmm(MatrixBlock ret, int nRows) { double[] aval = sb.values(row); for(int j = apos; j < alen; j++) result[offRet + _colIndexes.get(aix[j])] += aval[j]; - + } } - else{ + else { // copy that upper triangle part to ret final int numColumns = ret.getNumColumns(); final double[] result = ret.getDenseBlockValues(); @@ -629,8 +650,8 @@ private void leftMultByAPreAggColGroup(APreAgg paCG, MatrixBlock result) { private void leftMultByAColGroupUncompressed(ColGroupUncompressed lhs, MatrixBlock result) { final MatrixBlock tmpRet = new MatrixBlock(lhs.getNumCols(), _colIndexes.size(), 0); final int k = InfrastructureAnalyzer.getLocalParallelism(); - - if(lhs._data.getNumColumns() != 1){ + + if(lhs._data.getNumColumns() != 1) { LOG.warn("Inefficient Left Matrix Multiplication with transpose of left hand side : t(l) %*% r"); } // multiply to temp @@ -866,30 +887,43 @@ public ICLAScheme getCompressionScheme() { @Override public AColGroup recompress() { - MatrixBlock mb = CompressedMatrixBlockFactory.compress(_data).getLeft(); - if(mb instanceof CompressedMatrixBlock) { - CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; - List gs = cmb.getColGroups(); - if(gs.size() > 1) { - LOG.error("The uncompressed column group did compress into multiple groups"); - return this; - } - else { - return gs.get(0).copyAndSet(_colIndexes); - } - } - else - return this; + + final List es = new ArrayList<>(); + final CompressionSettings cs = new CompressionSettingsBuilder().create(); + final EstimationFactors f = new EstimationFactors(_data.getNumRows(), _data.getNumRows(), _data.getSparsity()); + es.add(new CompressedSizeInfoColGroup( // + ColIndexFactory.create(_data.getNumColumns()), f, 312152, CompressionType.DDC)); + final CompressedSizeInfo csi = new CompressedSizeInfo(es); + final List comp = ColGroupFactory.compressColGroups(_data, csi, cs); + + return comp.get(0).copyAndSet(_colIndexes); + + // MatrixBlock mb = CompressedMatrixBlockFactory.compress(_data).getLeft(); + // if(mb instanceof CompressedMatrixBlock) { + // CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; + // List gs = cmb.getColGroups(); + // if(gs.size() > 1) { + // LOG.error("The uncompressed column group did compress into multiple groups"); + // return this; + // } + // else { + // return gs.get(0).copyAndSet(_colIndexes); + // } + // } + // else + // return this; } @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { - final IEncode map = EncodingFactory.createFromMatrixBlock(_data, false, - ColIndexFactory.create(_data.getNumColumns())); + // final IEncode map = EncodingFactory.createFromMatrixBlock(_data, false, + // ColIndexFactory.create(_data.getNumColumns())); final int _numRows = _data.getNumRows(); final CompressionSettings _cs = new CompressionSettingsBuilder().create();// default settings - final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); - return new CompressedSizeInfoColGroup(_colIndexes, em, _cs.validCompressions, map); + final EstimationFactors em = + new EstimationFactors(_numRows, _numRows, 1, null, _numRows, _numRows, _numRows, false, false, (double) _numRows / _data.getNonZeros(), (double) _numRows / _data.getNonZeros()); + // map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); + return new CompressedSizeInfoColGroup(_colIndexes, em, _cs.validCompressions, null); } @Override @@ -907,6 +941,95 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { return create(newColIndex, ret, false); } + @Override + public double getSparsity() { + return _data.getSparsity(); + } + + @Override + public AColGroup morph(CompressionType ct, int nRow) { + if(ct == getCompType()) + return this; + + final List es = new ArrayList<>(); + final CompressionSettings cs = new CompressionSettingsBuilder().create(); + final EstimationFactors f = new EstimationFactors(_data.getNumRows(), _data.getNumRows(), _data.getSparsity()); + es.add(new CompressedSizeInfoColGroup(ColIndexFactory.create(_data.getNumColumns()), f, 312152, ct)); + final CompressedSizeInfo csi = new CompressedSizeInfo(es); + final List comp = ColGroupFactory.compressColGroups(_data, csi, cs); + + return comp.get(0).copyAndSet(_colIndexes); + } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + if(_data.isEmpty()) + return; + else if(_data.isInSparseFormat()) + sparseSelectionSparseColumnGroup(selection, ret, rl, ru); + else + sparseSelectionDenseColumnGroup(selection, ret, rl, ru); + } + + private void sparseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + final SparseBlock tb = _data.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + if(tb.isEmpty(rowCompressed)) + continue; + final int tPos = tb.pos(rowCompressed); + final int tEnd = tb.size(rowCompressed) + tPos; + final int[] tIx = tb.indexes(rowCompressed); + final double[] tVal = tb.values(rowCompressed); + for(int j = tPos; j < tEnd; j++) + retB.append(r, _colIndexes.get(tIx[j]), tVal[j]); + } + + } + + private void sparseSelectionDenseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + final DenseBlock tb = _data.getDenseBlock(); + final int nCol = _colIndexes.size(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + + double[] tVal = tb.values(rowCompressed); + int tPos = tb.pos(rowCompressed); + for(int j = 0; j < nCol; j++) + retB.append(r, _colIndexes.get(j), tVal[tPos + j]); + } + } + + @Override + public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { + if(_data.isInSparseFormat()) + decompressToDenseBlockTransposedSparse(db, rl, ru); + else + decompressToDenseBlockTransposedDense(db, rl, ru); + + } + + private void decompressToDenseBlockTransposedSparse(DenseBlock db, int rl, int ru) { + throw new NotImplementedException(); + } + + private void decompressToDenseBlockTransposedDense(DenseBlock db, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java index c67a40b34c1..ca42856a62d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.compress.colgroup; +import java.util.Arrays; + import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; import org.apache.sysds.runtime.data.SparseBlock; @@ -305,11 +307,51 @@ public static void addMatrixToResult(MatrixBlock tmp, MatrixBlock result, IColIn } } - public static double[] reorderDefault(double[] vals, int[] reordering){ + public static double[] reorderDefault(double[] vals, int[] reordering) { double[] ret = new double[vals.length]; for(int i = 0; i < vals.length; i++) ret[i] = vals[reordering[i]]; - return ret; + return ret; + } + + public static P[] getSortedSelection(SparseBlock sb, int rl, int ru) { + + int c = 0; + // count loop + for(int i = rl; i < ru; i++) { + if(sb.isEmpty(i)) + continue; + c++; + } + + P[] points = new P[c]; + c = 0; + for(int i = rl; i < ru; i++) { + if(sb.isEmpty(i)) + continue; + final int sPos = sb.pos(i); + points[c++] = new P(i, sb.indexes(i)[sPos]); + } + + Arrays.sort(points, (a, b) -> { + return a.o < b.o ? -1 : 1; + }); + return points; + } + + public static class P { + public final int r; + public final int o; + + private P(int r, int o) { + this.r = r; + this.o = o; + } + + @Override + public String toString() { + return "P(" + r + "," + o + ")"; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index 67f546c6ac5..d25868c2b2e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -22,6 +22,7 @@ import java.io.Serializable; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; @@ -84,4 +85,107 @@ public static void correctNan(double[] res, IColIndex colIndexes) { res[cix] = Double.isNaN(res[cix]) ? 0 : res[cix]; } } + + @Override + public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns, + int nColRight) { + if(aggregateColumns.size() < nColRight) + return rightMMPreAggSparseSelectedCols(numVals, b, thisCols, aggregateColumns); + else + return rightMMPreAggSparseAllColsRight(numVals, b, thisCols, nColRight); + } + + protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, + IColIndex aggregateColumns) { + + final int thisColsSize = thisCols.size(); + final int aggColSize = aggregateColumns.size(); + final double[] ret = new double[numVals * aggColSize]; + + for(int h = 0; h < thisColsSize; h++) { + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int j = 0; j < numVals; j++) { // rows left + final int offOut = j * aggColSize; + final double v = getValue(j, h, thisColsSize); + sparseAddSelected(sPos, sEnd, aggColSize, aggregateColumns, sIndexes, sValues, ret, offOut, v); + } + + } + return Dictionary.create(ret); + } + + private final void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex aggregateColumns, int[] sIndexes, + double[] sValues, double[] ret, int offOut, double v) { + + int retIdx = 0; + for(int i = sPos; i < sEnd; i++) { + // skip through the retIdx. + while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i]) + retIdx++; + if(retIdx == aggColSize) + break; + ret[offOut + retIdx] += v * sValues[i]; + } + retIdx = 0; + } + + protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, + int nColRight) { + final int thisColsSize = thisCols.size(); + final double[] ret = new double[numVals * nColRight]; + + for(int h = 0; h < thisColsSize; h++) { // common dim + // chose row in right side matrix via column index of the dictionary + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + // extract the row values on the right side. + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + + for(int i = 0; i < numVals; i++) { // rows left + final int offOut = i * nColRight; + final double v = getValue(i, h, thisColsSize); + SparseAdd(sPos, sEnd, ret, offOut, sIndexes, sValues, v); + } + } + return Dictionary.create(ret); + } + + private final void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { + if(v != 0) { + for(int k = sPos; k < sEnd; k++) { // cols right with value + ret[offOut + sIdx[k]] += v * sVals[k]; + } + } + } + + @Override + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + for(int i = 0; i < nCol; i++) { + sb.append(rowOut, columns.get(i), getValue(idx, i, nCol)); + } + } + + @Override + public double[] getRow(int i, int nCol) { + double[] ret = new double[nCol]; + for(int c = 0; c < nCol; c++) { + ret[c] = getValue(i, c, nCol); + } + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java index 9aba711a30e..6305063ab2a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java @@ -56,11 +56,11 @@ else if(row > col) // swap because in lower triangle /** * Matrix multiply with scaling (left side transposed) * - * @param left Left side dictionary - * @param right Right side dictionary + * @param left Left side dictionary that is not physically transposed but should be treated if it is. + * @param right Right side dictionary that is not transposed and should be used as is. * @param leftRows Left side row offsets * @param rightColumns Right side column offsets - * @param result The result matrix + * @param result The result matrix, normal allocation. * @param counts The scaling factors */ public static void MMDictsWithScaling(IDictionary left, IDictionary right, IColIndex leftRows, @@ -221,7 +221,6 @@ protected static void MMDictsScalingDenseDense(double[] left, double[] right, IC final int commonDim = Math.min(left.length / leftSide, right.length / rightSide); final int resCols = result.getNumColumns(); final double[] resV = result.getDenseBlockValues(); - for(int k = 0; k < commonDim; k++) { final int offL = k * leftSide; final int offR = k * rightSide; @@ -305,8 +304,8 @@ protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColI } } - protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result, int[] scaling) { + protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, + IColIndex colsRight, MatrixBlock result, int[] scaling) { final double[] resV = result.getDenseBlockValues(); final int leftSize = rowsLeft.size(); final int commonDim = Math.min(left.length / leftSize, right.numRows()); @@ -538,19 +537,27 @@ else if(loc > 0) protected static void MMToUpperTriangleDenseDenseAllUpperTriangle(double[] left, double[] right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - final int commonDim = Math.min(left.length / rowsLeft.size(), right.length / colsRight.size()); + final int lSize = rowsLeft.size(); + final int rSize = colsRight.size(); + final int commonDim = Math.min(left.length / lSize, right.length / rSize); final int resCols = result.getNumColumns(); final double[] resV = result.getDenseBlockValues(); + for(int i = 0; i < lSize; i++) { + MMToUpperTriangleDenseDenseAllUpperTriangleRow(left, right, rowsLeft.get(i), colsRight, commonDim, lSize, + rSize, i, resV, resCols); + } + } + + protected static void MMToUpperTriangleDenseDenseAllUpperTriangleRow(final double[] left, final double[] right, + final int rowOut, final IColIndex colsRight, final int commonDim, final int lSize, final int rSize, final int i, + final double[] resV, final int resCols) { for(int k = 0; k < commonDim; k++) { - final int offL = k * rowsLeft.size(); - final int offR = k * colsRight.size(); - for(int i = 0; i < rowsLeft.size(); i++) { - final int rowOut = rowsLeft.get(i); - final double vl = left[offL + i]; - if(vl != 0) { - for(int j = 0; j < colsRight.size(); j++) - resV[colsRight.get(j) * resCols + rowOut] += vl * right[offR + j]; - } + final int offL = k * lSize; + final double vl = left[offL + i]; + if(vl != 0) { + final int offR = k * rSize; + for(int j = 0; j < rSize; j++) + resV[colsRight.get(j) * resCols + rowOut] += vl * right[offR + j]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 4f0bbfbee14..9fea6952890 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -34,9 +34,11 @@ import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; @@ -181,13 +183,26 @@ public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { } @Override - public Dictionary applyScalarOp(ScalarOperator op) { + public IDictionary applyScalarOp(ScalarOperator op) { + if(op.fn instanceof Multiply) + return applyScalarMultOp(op.getConstant()); + else + return applyScalarGeneric(op); + } + + private IDictionary applyScalarGeneric(ScalarOperator op) { final double[] retV = new double[_values.length]; for(int i = 0; i < _values.length; i++) retV[i] = op.executeScalar(_values[i]); return create(retV); } + private IDictionary applyScalarMultOp(double v) { + final double[] retV = new double[_values.length]; + LibMatrixMult.vectMultiplyAdd(v, _values, retV, 0, 0, _values.length); + return create(retV); + } + @Override public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { final double[] retV = new double[_values.length + nCol]; @@ -649,10 +664,12 @@ public String toString() { private static void stringArray(StringBuilder sb, double[] val) { sb.append("["); - sb.append(doubleToString(val[0])); - for(int i = 1; i < val.length; i++) { - sb.append(", "); - sb.append(doubleToString(val[i])); + if(val.length > 0) { + sb.append(doubleToString(val[0])); + for(int i = 1; i < val.length; i++) { + sb.append(", "); + sb.append(doubleToString(val[i])); + } } sb.append("]"); } @@ -1120,12 +1137,20 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef @Override public boolean equals(IDictionary o) { - if(o instanceof Dictionary) { + if(o instanceof Dictionary) return Arrays.equals(_values, ((Dictionary) o)._values); - } + else if(o instanceof IdentityDictionary) + return ((IdentityDictionary) o).equals(this); else if(o instanceof MatrixBlockDictionary) { final MatrixBlock mb = ((MatrixBlockDictionary) o).getMatrixBlock(); - if(mb.isInSparseFormat()) + if(mb.isEmpty()) { + for(int i = 0; i < _values.length; i++) { + if(_values[i] != 0) + return false; + } + return true; + } + else if(mb.isInSparseFormat()) return mb.getSparseBlock().equals(_values, mb.getNumColumns()); final double[] dv = mb.getDenseBlockValues(); return Arrays.equals(_values, dv); @@ -1154,4 +1179,12 @@ public IDictionary reorder(int[] reorder) { } return ret; } + + @Override + public IDictionary append(double[] row) { + double[] retV = new double[_values.length + row.length]; + System.arraycopy(_values, 0, retV, 0, _values.length); + System.arraycopy(row, 0, retV, _values.length, row.length); + return new Dictionary(retV); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java index a75ca3f865e..8f7c8e85091 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java @@ -21,6 +21,7 @@ import java.io.DataInput; import java.io.IOException; +import java.util.List; import java.util.Map; import org.apache.commons.lang3.NotImplementedException; @@ -35,6 +36,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.IContainADictionary; import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; import org.apache.sysds.runtime.compress.utils.ACount; import org.apache.sysds.runtime.compress.utils.DblArray; @@ -43,6 +45,7 @@ import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.Pair; public interface DictionaryFactory { static final Log LOG = LogFactory.getLog(DictionaryFactory.class.getName()); @@ -149,8 +152,7 @@ else if(ubm instanceof MultiColBitmap) { return Dictionary.create(resValues); } - throw new NotImplementedException( - "Not implemented creation of bitmap type : " + ubm.getClass().getSimpleName()); + throw new NotImplementedException("Not implemented creation of bitmap type : " + ubm.getClass().getSimpleName()); } public static IDictionary create(ABitmap ubm, int defaultIndex, double[] defaultTuple, double sparsity, @@ -257,34 +259,37 @@ public static IDictionary combineDictionaries(AColGroupCompressed a, AColGroupCo if(ae && be) { - IDictionary ad = ((IContainADictionary) a).getDictionary(); - IDictionary bd = ((IContainADictionary) b).getDictionary(); + final IDictionary ad = ((IContainADictionary) a).getDictionary(); + final IDictionary bd = ((IContainADictionary) b).getDictionary(); if(ac.isConst()) { if(bc.isConst()) { return Dictionary.create(CLALibCombineGroups.constructDefaultTuple(a, b)); } else if(bc.isDense()) { final double[] at = ((IContainDefaultTuple) a).getDefaultTuple(); - return combineConstSparseSparseRet(at, bd, b.getNumCols(), filter); + Pair r = IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices()); + return combineConstLeft(at, bd, b.getNumCols(), r.getKey(), r.getValue(), filter); } } else if(ac.isDense()) { if(bc.isConst()) { + final Pair r = IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices()); final double[] bt = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSparseConstSparseRet(ad, a.getNumCols(), bt, filter); + return combineSparseConstSparseRet(ad, a.getNumCols(), bt, r.getKey(), r.getValue(), filter); + } + else if(bc.isDense()) { + return combineFullDictionaries(ad, a.getColIndices(), bd, b.getColIndices(), filter); } - else if(bc.isDense()) - return combineFullDictionaries(ad, a.getNumCols(), bd, b.getNumCols(), filter); else if(bc.isSDC()) { double[] tuple = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSDCRight(ad, a.getNumCols(), bd, tuple, filter); + return combineSDCRight(ad, a.getColIndices(), bd, tuple, b.getColIndices(), filter); } } else if(ac.isSDC()) { if(bc.isSDC()) { final double[] at = ((IContainDefaultTuple) a).getDefaultTuple(); final double[] bt = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSDC(ad, at, bd, bt, filter); + return combineSDCFilter(ad, at, a.getColIndices(), bd, bt, b.getColIndices(), filter); } } } @@ -300,34 +305,83 @@ else if(ac.isSDC()) { * @return The combined dictionary */ public static IDictionary combineDictionariesSparse(AColGroupCompressed a, AColGroupCompressed b) { + return combineDictionariesSparse(a, b, null); + } + + /** + * Combine the dictionaries assuming a sparse combination where each dictionary can be a SDC containing a default + * element that have to be introduced into the combined dictionary. + * + * @param a A Dictionary can be SDC or const + * @param b A Dictionary can be Const or SDC + * @param filter A filter to remove elements in the combined dictionary + * @return The combined dictionary + */ + public static IDictionary combineDictionariesSparse(AColGroupCompressed a, AColGroupCompressed b, + Map filter) { CompressionType ac = a.getCompType(); CompressionType bc = b.getCompType(); + if(filter != null) + throw new NotImplementedException("Not supported filter for sparse join yet!"); + if(ac.isSDC()) { - IDictionary ad = ((IContainADictionary) a).getDictionary(); + final IDictionary ad = ((IContainADictionary) a).getDictionary(); if(bc.isConst()) { + final Pair r = IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices()); double[] bt = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSparseConstSparseRet(ad, a.getNumCols(), bt); + return combineSparseConstSparseRet(ad, a.getNumCols(), bt, r.getKey(), r.getValue()); } else if(bc.isSDC()) { - IDictionary bd = ((IContainADictionary) b).getDictionary(); + final IDictionary bd = ((IContainADictionary) b).getDictionary(); if(a.sameIndexStructure(b)) { - return ad.cbind(bd, b.getNumCols()); + // in order or other order.. + if(IColIndex.inOrder(a.getColIndices(), b.getColIndices())) + return ad.cbind(bd, b.getNumCols()); + else if(IColIndex.inOrder(b.getColIndices(), a.getColIndices())) + return bd.cbind(ad, b.getNumCols()); + else { + final Pair r = IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices()); + return cbindReorder(ad, bd, r.getKey(), r.getValue()); + } } - // real combine extract default and combine like dense but with default before. } } else if(ac.isConst()) { - double[] at = ((IContainDefaultTuple) a).getDefaultTuple(); + final double[] at = ((IContainDefaultTuple) a).getDefaultTuple(); if(bc.isSDC()) { - IDictionary bd = ((IContainADictionary) b).getDictionary(); - return combineConstSparseSparseRet(at, bd, b.getNumCols()); + final IDictionary bd = ((IContainADictionary) b).getDictionary(); + final Pair r = IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices()); + return combineConstLeftAll(at, bd, b.getNumCols(), r.getKey(), r.getValue()); } } throw new NotImplementedException("Not supporting combining dense: " + a + " " + b); } + private static IDictionary cbindReorder(IDictionary a, IDictionary b, int[] ai, int[] bi) { + final int nca = ai.length; + final int ncb = bi.length; + final int ra = a.getNumberOfValues(nca); + final int rb = b.getNumberOfValues(ncb); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + if(ra != rb) + throw new DMLCompressionException("Invalid cbind reorder, different sizes of dictionaries"); + final MatrixBlock out = new MatrixBlock(ra, nca + ncb, false); + + for(int r = 0; r < ra; r++) {// each row + // + for(int c = 0; c < nca; c++) + out.quickSetValue(r, ai[c], ma.quickGetValue(r, c)); + + for(int c = 0; c < ncb; c++) + out.quickSetValue(r, bi[c], mb.quickGetValue(r, c)); + } + + return new MatrixBlockDictionary(out); + } + /** * Combine the dictionaries as if the dictionaries contain the full spectrum of the combined data. * @@ -341,6 +395,13 @@ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDicti return combineFullDictionaries(a, nca, b, ncb, null); } + public static IDictionary combineFullDictionaries(IDictionary a, IColIndex ai, IDictionary b, IColIndex bi, + Map filter) { + final int nca = ai.size(); + final int ncb = bi.size(); + return combineFullDictionaries(a, ai, nca, b, bi, ncb, filter); + } + /** * Combine the dictionaries as if the dictionaries only contain the values in the specified filter. * @@ -348,64 +409,112 @@ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDicti * @param nca Number of columns left dictionary * @param b Right side dictionary * @param ncb Number of columns right dictionary - * @param filter The mapping filter to not include all possible combinations in the output, this filter is allowed - * to be null, that means the output is defaulting back to a full combine + * @param filter The mapping filter to not include all possible combinations in the output, this filter is allowed to + * be null, that means the output is defaulting back to a full combine * @return A combined dictionary */ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDictionary b, int ncb, Map filter) { + return combineFullDictionaries(a, null, nca, b, null, ncb, filter); + } + + public static IDictionary combineFullDictionaries(IDictionary a, IColIndex ai, int nca, IDictionary b, IColIndex bi, + int ncb, Map filter) { final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock out = new MatrixBlock(filter != null ? filter.size() : ra * rb, nca + ncb, false); + out.allocateBlock(); - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); - - if(ra == 1 && rb == 1) - return new MatrixBlockDictionary(ma.append(mb)); + if(ai != null && bi != null && !IColIndex.inOrder(ai, bi)) { - MatrixBlock out = new MatrixBlock(filter != null ? filter.size() : ra * rb, nca + ncb, false); + Pair reordering = IColIndex.reorderingIndexes(ai, bi); + if(filter != null) + // throw new NotImplementedException(); + combineFullDictionariesOOOFilter(out, filter, ra, rb, nca, ncb, reordering.getKey(), reordering.getValue(), + ma, mb); + else + combineFullDictionariesOOONoFilter(out, ra, rb, nca, ncb, reordering.getKey(), reordering.getValue(), ma, + mb); - out.allocateBlock(); + } + else { + if(filter != null) + combineFullDictionariesFilter(out, filter, ra, rb, nca, ncb, ma, mb); + else + combineFullDictionariesNoFilter(out, ra, rb, nca, ncb, ma, mb); + } - if(filter != null) { - for(int r : filter.keySet()) { - int o = filter.get(r); - int ia = r % ra; - int ib = r / ra; - for(int c = 0; c < nca; c++) - out.quickSetValue(o, c, ma.quickGetValue(ia, c)); + out.examSparsity(true); - for(int c = 0; c < ncb; c++) - out.quickSetValue(o, c + nca, mb.quickGetValue(ib, c)); + return new MatrixBlockDictionary(out); + } - } + private static void combineFullDictionariesFilter(MatrixBlock out, Map filter, int ra, int rb, + int nca, int ncb, MatrixBlock ma, MatrixBlock mb) { + for(int r : filter.keySet()) { + int o = filter.get(r); + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(o, c, ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, c + nca, mb.quickGetValue(ib, c)); } - else { + } - for(int r = 0; r < out.getNumRows(); r++) { - int ia = r % ra; - int ib = r / ra; - for(int c = 0; c < nca; c++) - out.quickSetValue(r, c, ma.quickGetValue(ia, c)); + private static void combineFullDictionariesOOOFilter(MatrixBlock out, Map filter, int ra, int rb, + int nca, int ncb, int[] ai, int[] bi, MatrixBlock ma, MatrixBlock mb) { + for(int r : filter.keySet()) { + int o = filter.get(r); + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, bi[c], mb.quickGetValue(ib, c)); + } + } - for(int c = 0; c < ncb; c++) - out.quickSetValue(r, c + nca, mb.quickGetValue(ib, c)); + private static void combineFullDictionariesOOONoFilter(MatrixBlock out, int ra, int rb, int nca, int ncb, int[] ai, + int[] bi, MatrixBlock ma, MatrixBlock mb) { + for(int r = 0; r < out.getNumRows(); r++) { + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(r, ai[c], ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(r, bi[c], mb.quickGetValue(ib, c)); + } + } - } + private static void combineFullDictionariesNoFilter(MatrixBlock out, int ra, int rb, int nca, int ncb, + MatrixBlock ma, MatrixBlock mb) { + for(int r = 0; r < out.getNumRows(); r++) { + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(r, c, ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(r, c + nca, mb.quickGetValue(ib, c)); } - return new MatrixBlockDictionary(out); } - public static IDictionary combineSDCRight(IDictionary a, int nca, IDictionary b, double[] tub) { + public static IDictionary combineSDCRightNoFilter(IDictionary a, int nca, IDictionary b, double[] tub) { + return combineSDCRightNoFilter(a, null, nca, b, tub, null); + } + public static IDictionary combineSDCRightNoFilter(IDictionary a, IColIndex ai, int nca, IDictionary b, double[] tub, + IColIndex bi) { + if(ai != null || bi != null) + throw new NotImplementedException(); final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); - - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); - - MatrixBlock out = new MatrixBlock(ra * (rb + 1), nca + ncb, false); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock out = new MatrixBlock(ra * (rb + 1), nca + ncb, false); out.allocateBlock(); @@ -430,65 +539,116 @@ public static IDictionary combineSDCRight(IDictionary a, int nca, IDictionary b, return new MatrixBlockDictionary(out); } + public static IDictionary combineSDCRight(IDictionary a, IColIndex ai, IDictionary b, double[] tub, IColIndex bi, + Map filter) { + return combineSDCRight(a, ai, ai.size(), b, tub, bi, filter); + } + public static IDictionary combineSDCRight(IDictionary a, int nca, IDictionary b, double[] tub, Map filter) { + return combineSDCRight(a, null, nca, b, tub, null, filter); + } + + public static IDictionary combineSDCRight(IDictionary a, IColIndex ai, int nca, IDictionary b, double[] tub, + IColIndex bi, Map filter) { if(filter == null) - return combineSDCRight(a, nca, b, tub); + return combineSDCRightNoFilter(a, ai, nca, b, tub, bi); + final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); - - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); - - MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb, false); - + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb, false); out.allocateBlock(); + if(ai != null && bi != null) { + Pair re = IColIndex.reorderingIndexes(ai, bi); + combineSDCRightOOOFilter(out, nca, ncb, tub, ra, rb, ma, mb, re.getKey(), re.getValue(), filter); + } + else { + combineSDCRightFilter(out, nca, ncb, tub, ra, rb, ma, mb, filter); + } + return new MatrixBlockDictionary(out); + } + + private static void combineSDCRightFilter(MatrixBlock out, int nca, int ncb, double[] tub, int ra, int rb, + MatrixBlock ma, MatrixBlock mb, Map filter) { for(int r = 0; r < ra; r++) { if(filter.containsKey(r)) { - int o = filter.get(r); for(int c = 0; c < nca; c++) out.quickSetValue(o, c, ma.quickGetValue(r, c)); for(int c = 0; c < ncb; c++) out.quickSetValue(o, c + nca, tub[c]); } - } - - for(int r = ra; r < ra * rb; r++) { + for(int r = ra; r < ra * rb + ra; r++) { if(filter.containsKey(r)) { int o = filter.get(r); - int ia = r % ra; int ib = r / ra - 1; for(int c = 0; c < nca; c++) // all good. out.quickSetValue(o, c, ma.quickGetValue(ia, c)); - for(int c = 0; c < ncb; c++) out.quickSetValue(o, c + nca, mb.quickGetValue(ib, c)); + } + } + } + private static void combineSDCRightOOOFilter(MatrixBlock out, int nca, int ncb, double[] tub, int ra, int rb, + MatrixBlock ma, MatrixBlock mb, int[] ai, int[] bi, Map filter) { + for(int r = 0; r < ra; r++) { + if(filter.containsKey(r)) { + int o = filter.get(r); + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], ma.quickGetValue(r, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, bi[c], tub[c]); } } - return new MatrixBlockDictionary(out); + for(int r = ra; r < ra * rb + ra; r++) { + if(filter.containsKey(r)) { + int o = filter.get(r); + int ia = r % ra; + int ib = r / ra - 1; + for(int c = 0; c < nca; c++) // all good. + out.quickSetValue(o, ai[c], ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, bi[c], mb.quickGetValue(ib, c)); + } + } + } + + public static IDictionary combineSDCNoFilter(IDictionary a, double[] tua, IDictionary b, double[] tub) { + return combineSDCNoFilter(a, tua, null, b, tub, null); } - public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, double[] tub) { + public static IDictionary combineSDCNoFilter(IDictionary a, double[] tua, IColIndex ai, IDictionary b, double[] tub, + IColIndex bi) { final int nca = tua.length; final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock out = new MatrixBlock((ra + 1) * (rb + 1), nca + ncb, false); - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + out.allocateBlock(); - MatrixBlock out = new MatrixBlock((ra + 1) * (rb + 1), nca + ncb, false); + if(ai != null || bi != null) { + final Pair re = IColIndex.reorderingIndexes(ai, bi); + combineSDCNoFilterOOO(nca, ncb, tua, tub, out, ma, mb, ra, rb, re.getKey(), re.getValue()); + } + else + combineSDCNoFilter(nca, ncb, tua, tub, out, ma, mb, ra, rb); + return new MatrixBlockDictionary(out); + } - out.allocateBlock(); + private static void combineSDCNoFilter(int nca, int ncb, double[] tua, double[] tub, MatrixBlock out, MatrixBlock ma, + MatrixBlock mb, int ra, int rb) { // 0 row both default tuples - for(int c = 0; c < nca; c++) out.quickSetValue(0, c, tua[c]); @@ -504,8 +664,8 @@ public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, } for(int r = ra + 1; r < out.getNumRows(); r++) { - int ia = r % (ra + 1) - 1; - int ib = r / (ra + 1) - 1; + final int ia = r % (ra + 1) - 1; + final int ib = r / (ra + 1) - 1; if(ia == -1) for(int c = 0; c < nca; c++) @@ -516,27 +676,74 @@ public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, for(int c = 0; c < ncb; c++) // all good here. out.quickSetValue(r, c + nca, mb.quickGetValue(ib, c)); + } + } + + private static void combineSDCNoFilterOOO(int nca, int ncb, double[] tua, double[] tub, MatrixBlock out, + MatrixBlock ma, MatrixBlock mb, int ra, int rb, int[] ai, int[] bi) { + + // 0 row both default tuples + for(int c = 0; c < nca; c++) + out.quickSetValue(0, ai[c], tua[c]); + for(int c = 0; c < ncb; c++) + out.quickSetValue(0, bi[c], tub[c]); + + // default case for b and all cases for a. + for(int r = 1; r < ra + 1; r++) { + for(int c = 0; c < nca; c++) + out.quickSetValue(r, ai[c], ma.quickGetValue(r - 1, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(r, bi[c], tub[c]); } - return new MatrixBlockDictionary(out); + for(int r = ra + 1; r < out.getNumRows(); r++) { + final int ia = r % (ra + 1) - 1; + final int ib = r / (ra + 1) - 1; + + if(ia == -1) + for(int c = 0; c < nca; c++) + out.quickSetValue(r, ai[c], tua[c]); + else + for(int c = 0; c < nca; c++) + out.quickSetValue(r, ai[c], ma.quickGetValue(ia, c)); + + for(int c = 0; c < ncb; c++) // all good here. + out.quickSetValue(r, bi[c], mb.quickGetValue(ib, c)); + } } - public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, double[] tub, + public static IDictionary combineSDCFilter(IDictionary a, double[] tua, IDictionary b, double[] tub, Map filter) { + return combineSDCFilter(a, tua, null, b, tub, null, filter); + } + + public static IDictionary combineSDCFilter(IDictionary a, double[] tua, IColIndex ai, IDictionary b, double[] tub, + IColIndex bi, Map filter) { if(filter == null) - return combineSDC(a, tua, b, tub); + return combineSDCNoFilter(a, tua, ai, b, tub, bi); + final int nca = tua.length; final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); - final int rb = b.getNumberOfValues(nca); + final int rb = b.getNumberOfValues(ncb); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb, false); + out.allocateBlock(); - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + if(ai != null && bi != null) { + Pair re = IColIndex.reorderingIndexes(ai, bi); + combineSDCFilterOOO(filter, nca, ncb, tua, tub, out, ma, mb, ra, rb, re.getKey(), re.getValue()); + } + else + combineSDCFilter(filter, nca, ncb, tua, tub, out, ma, mb, ra, rb); - MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb, false); + return new MatrixBlockDictionary(out); + } - out.allocateBlock(); + private static void combineSDCFilter(Map filter, int nca, int ncb, double[] tua, double[] tub, + MatrixBlock out, MatrixBlock ma, MatrixBlock mb, int ra, int rb) { // 0 row both default tuples if(filter.containsKey(0)) { @@ -559,13 +766,12 @@ public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, } } - for(int r = ra + 1; r < ra * rb; r++) { + for(int r = ra + 1; r < ra * rb + ra + rb + 1; r++) { if(filter.containsKey(r)) { - int o = filter.get(r); - - int ia = r % (ra + 1) - 1; - int ib = r / (ra + 1) - 1; + final int o = filter.get(r); + final int ia = r % (ra + 1) - 1; + final int ib = r / (ra + 1) - 1; if(ia == -1) for(int c = 0; c < nca; c++) @@ -578,12 +784,50 @@ public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, out.quickSetValue(o, c + nca, mb.quickGetValue(ib, c)); } } + } - return new MatrixBlockDictionary(out); + private static void combineSDCFilterOOO(Map filter, int nca, int ncb, double[] tua, double[] tub, + MatrixBlock out, MatrixBlock ma, MatrixBlock mb, int ra, int rb, int[] ai, int[] bi) { + // 0 row both default tuples + if(filter.containsKey(0)) { + final int o = filter.get(0); + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], tua[c]); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, bi[c], tub[c]); + } + + // default case for b and all cases for a. + for(int r = 1; r < ra + 1; r++) { + if(filter.containsKey(r)) { + final int o = filter.get(r); + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], ma.quickGetValue(r - 1, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, bi[c], tub[c]); + } + } + + for(int r = ra + 1; r < ra * rb + ra + rb + 1; r++) { + if(filter.containsKey(r)) { + final int o = filter.get(r); + final int ia = r % (ra + 1) - 1; + final int ib = r / (ra + 1) - 1; + + if(ia == -1) + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], tua[c]); + else + for(int c = 0; c < nca; c++) + out.quickSetValue(o, ai[c], ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) // all good here. + out.quickSetValue(o, bi[c], mb.quickGetValue(ib, c)); + } + } } - public static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub) { + private static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub, int[] ai, int[] bi) { final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); @@ -596,19 +840,19 @@ public static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, do // default case for b and all cases for a. for(int r = 0; r < ra; r++) { for(int c = 0; c < nca; c++) - out.quickSetValue(r, c, ma.quickGetValue(r, c)); + out.quickSetValue(r, ai[c], ma.quickGetValue(r, c)); for(int c = 0; c < ncb; c++) - out.quickSetValue(r, c + nca, tub[c]); + out.quickSetValue(r, bi[c], tub[c]); } return new MatrixBlockDictionary(out); } - private static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub, + private static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub, int[] ai, int[] bi, Map filter) { if(filter == null) - return combineSparseConstSparseRet(a, nca, tub); + return combineSparseConstSparseRet(a, nca, tub, ai, bi); else throw new NotImplementedException(); // final int ncb = tub.length; @@ -632,7 +876,7 @@ private static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, d } - public static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary b, int ncb) { + private static IDictionary combineConstLeftAll(double[] tua, IDictionary b, int ncb, int[] ai, int[] bi) { final int nca = tua.length; final int rb = b.getNumberOfValues(ncb); @@ -645,19 +889,19 @@ public static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary // default case for b and all cases for a. for(int r = 0; r < rb; r++) { for(int c = 0; c < nca; c++) - out.quickSetValue(r, c, tua[c]); + out.quickSetValue(r, ai[c], tua[c]); for(int c = 0; c < ncb; c++) - out.quickSetValue(r, c + nca, mb.quickGetValue(r, c)); + out.quickSetValue(r, bi[c], mb.quickGetValue(r, c)); } return new MatrixBlockDictionary(out); } - private static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary b, int ncb, + private static IDictionary combineConstLeft(double[] tua, IDictionary b, int ncb, int[] ai, int[] bi, Map filter) { if(filter == null) - return combineConstSparseSparseRet(tua, b, ncb); + return combineConstLeftAll(tua, b, ncb, ai, bi); else throw new NotImplementedException(); // final int nca = tua.length; @@ -680,4 +924,15 @@ private static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary // return new MatrixBlockDictionary(out); } + + public static IDictionary cBindDictionaries(List> dicts) { + MatrixBlock base = dicts.get(0).getValue().getMBDict(dicts.get(0).getKey()).getMatrixBlock(); + MatrixBlock[] others = new MatrixBlock[dicts.size() - 1]; + for(int i = 1; i < dicts.size(); i++) { + Pair p = dicts.get(i); + others[i - 1] = p.getValue().getMBDict(p.getKey()).getMatrixBlock(); + } + MatrixBlock ret = base.append(others, null, true); + return new MatrixBlockDictionary(ret); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index 1047692f509..2e543eeaefa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -809,9 +809,8 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret); /** - * Matrix multiplication of dictionaries - * - * Note the left is this, and it is transposed + * Matrix multiplication of dictionaries note the left is this, and it is supposed to be treated as if it is + * transposed while it is not physically transposed. * * @param right Right hand side of multiplication * @param rowsLeft Offset rows on the left @@ -821,11 +820,10 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); /** - * Matrix multiplication of dictionaries - * - * Note the left is this, and it is transposed + * Matrix multiplication of dictionaries, note the left is this, and it is supposed to be treated as if it is + * transposed while it is not physically transposed. * - * @param right Right hand side of multiplication + * @param right Right hand side of multiplication (not transposed) * @param rowsLeft Offset rows on the left * @param colsRight Offset cols on the right * @param result The output matrix block @@ -835,9 +833,9 @@ public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsR int[] scaling); /** - * Matrix multiplication of dictionaries left side dense and transposed right side is this. + * Matrix multiplication of dictionaries left side dense and transposed. The right side is this. * - * @param left Dense left side + * @param left Dense left side (treat as if it is transposed but it is physically not) * @param rowsLeft Offset rows on the left * @param colsRight Offset cols on the right * @param result The output matrix block @@ -845,13 +843,14 @@ public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsR public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); /** - * Matrix multiplication of dictionaries left side dense and transposed right side is this. + * Matrix multiplication of dictionaries left side dense and transposed. The Right side is this, the scaling factor + * is used to multiply each element with. * - * @param left Dense left side - * @param rowsLeft Offset rows on the left - * @param colsRight Offset cols on the right - * @param result The output matrix block - * @param scaling The scaling + * @param left Dense left side (Dense dictionary) + * @param rowsLeft Offset rows on the left (That dictionaries column indexes) + * @param colsRight Offset cols on the right (This dictionaries column indexes) + * @param result The output matrix block, guaranteed to be allocated as dense. + * @param scaling The scaling factor to multiply each entry with. */ public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling); @@ -865,8 +864,8 @@ public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex cols * @param result The output matrix block */ public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); - -/** + + /** * Matrix multiplication of dictionaries left side sparse and transposed right side is this. * * @param left Sparse left side @@ -985,4 +984,48 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef */ public IDictionary reorder(int[] reorder); + /** + * Pre-aggregate the given sparse block for right multiplication. The returned dictionary is the new column groups + * dictionary. + * + * @param numVals The number of values in this dictionary and in the output dictionary. + * @param b The sparse block to pre aggregate, note this contains the entire right side matrix, not a + * reduced or sliced version. + * @param thisCols The column indexes of this dictionary, these correspond to the rows to extract from the + * right side matrix + * @param aggregateColumns The reduced column indexes of the right side, these are the number of columns in the + * output dictionary, and the columns to multiply this dictionary with. + * @param nColRight The number of columns in the b sparse matrix. + * @return The pre-aggregate dictionary that can be used as the output dictionary for the right matrix multiplication + */ + public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns, + int nColRight); + + /** + * Put the row specified into the sparse block, via append calls. + * + * @param sb The sparse block to put into + * @param idx The dictionary index to put in. + * @param rowOut The row in the sparse block to put it into + * @param nCol The number of columns in the dictionary + * @param columns The columns to output into. + */ + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns); + + /** + * Return a new dictionary with the given row appended. If possible reuse as much of the old dictionary as possible. + * + * @param row The new row to append. + * @return A new dictionary with the appended row + */ + public IDictionary append(double[] row); + + /** + * Extract the values on a given row as a dense double array. + * + * @param i The row index to extract + * @param nCol The number of columns in this columngroup + * @return The row extracted + */ + public double[] getRow(int i, int nCol); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 74f5e5b0991..78f646fdb7c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; @@ -85,6 +86,14 @@ public IdentityDictionary(int nRowCol, boolean withEmpty) { @Override public double[] getValues() { + if(nRowCol < 3) { + // lets live with it if we call it on 3 columns. + double[] ret = new double[nRowCol * nRowCol]; + for(int i = 0; i < nRowCol; i++) { + ret[(i * nRowCol) + i] = 1; + } + return ret; + } throw new DMLCompressionException("Invalid to materialize identity Matrix Please Implement alternative"); // LOG.warn("Should not call getValues on Identity Dictionary"); // double[] ret = new double[nRowCol * nRowCol]; @@ -218,21 +227,15 @@ public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexe boolean same = false; if(op.fn instanceof Plus || op.fn instanceof Minus) { same = true; - for(int i = 0; i < colIndexes.size(); i++) { - if(v[colIndexes.get(i)] != 0.0) { - same = false; - break; - } - } + for(int i = 0; i < colIndexes.size() && same; i++) + same = v[colIndexes.get(i)] == 0.0; + } if(op.fn instanceof Divide) { same = true; - for(int i = 0; i < colIndexes.size(); i++) { - if(v[colIndexes.get(i)] != 1.0) { - same = false; - break; - } - } + for(int i = 0; i < colIndexes.size() && same; i++) + same = v[colIndexes.get(i)] == 1.0; + } if(same) return this; @@ -290,20 +293,28 @@ public double[] sumAllRowsToDouble(int nrColumns) { @Override public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - double[] ret = new double[defaultTuple.length]; + double[] ret = new double[getNumberOfValues(defaultTuple.length) + 1]; + for(int i = 0; i < nRowCol; i++) + ret[i] = 1; + for(int i = 0; i < defaultTuple.length; i++) - ret[i] += 1 + defaultTuple[i]; - if(withEmpty) - ret[ret.length - 1] += -1; + ret[ret.length - 1] += defaultTuple[i]; return ret; } @Override public double[] sumAllRowsToDoubleWithReference(double[] reference) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, 1); + double[] ret = new double[getNumberOfValues(reference.length)]; + double refSum = 0; for(int i = 0; i < reference.length; i++) - ret[i] += reference[i] * nRowCol; + refSum += reference[i]; + Arrays.fill(ret, 1); + for(int i = 0; i < ret.length; i++) + ret[i] += refSum; + + if(withEmpty) + ret[ret.length - 1] += -1; + return ret; } @@ -341,11 +352,9 @@ public double[] productAllRowsToDoubleWithReference(double[] reference) { @Override public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - for(int i = 0; i < colIndexes.size(); i++) { - // very nice... - final int idx = colIndexes.get(i); - c[idx] = counts[i]; - } + for(int i = 0; i < colIndexes.size(); i++) + c[colIndexes.get(i)] += counts[i]; + } @Override @@ -422,18 +431,19 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol) { - getMBDict().addToEntry(v, fr, to, nCol); + // getMBDict().addToEntry(v, fr, to, nCol); + if(!withEmpty) + v[to * nCol + fr] += 1; + else if(fr < nRowCol) + v[to * nCol + fr] += 1; } @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { - if(withEmpty) { - if(fr < nRowCol) - v[to * nCol + fr] += rep; - } - else { + if(!withEmpty) + v[to * nCol + fr] += rep; + else if(fr < nRowCol) v[to * nCol + fr] += rep; - } } @Override @@ -512,16 +522,6 @@ private MatrixBlockDictionary createMBDict() { } } - @Override - public String getString(int colIndexes) { - return "IdentityMatrix of size: " + nRowCol; - } - - @Override - public String toString() { - return "IdentityMatrix of size: " + nRowCol; - } - @Override public IDictionary scaleTuples(int[] scaling, int nCol) { return getMBDict().scaleTuples(scaling, nCol); @@ -545,12 +545,13 @@ public long getExactSizeOnDisk() { @Override public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colIndexes, final IColIndex aggregateColumns, final double[] b, final int cut) { + // return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut); /** * This operations is Essentially a Identity matrix multiplication with a right hand side dense matrix, but we * need to slice out the right hand side from the input. - * + * * ColIndexes specify the rows to slice out of the right matrix. - * + * * aggregate columns specify the columns to slice out from the right. */ final int cs = colIndexes.size(); @@ -637,7 +638,8 @@ public double getSparsity() { @Override public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - getMBDict().multiplyScalar(v, ret, off, dictIdx, cols); + if(!withEmpty || dictIdx < nRowCol) + ret[off + cols.get(dictIdx)] += v; } @Override @@ -665,9 +667,9 @@ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, final double[] resV = result.getDenseBlockValues(); for(int i = 0; i < leftSide; i++) {// rows in left side final int offOut = rowsLeft.get(i) * resCols; - final int leftOff = i * leftSide; + final int leftOff = i; for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity - resV[offOut + colsRight.get(j)] += left[leftOff + j]; + resV[offOut + colsRight.get(j)] += left[leftOff + j * leftSide]; } } } @@ -675,15 +677,14 @@ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, @Override public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) { + // getMBDict().MMDictScalingDense(left, rowsLeft, colsRight, result, scaling); final int leftSide = rowsLeft.size(); final int resCols = result.getNumColumns(); - final int commonDim = Math.min(left.length / leftSide, nRowCol); final double[] resV = result.getDenseBlockValues(); - for(int i = 0; i < leftSide; i++) {// rows in left side + for(int i = 0; i < leftSide; i++) { // rows in left side final int offOut = rowsLeft.get(i) * resCols; - final int leftOff = i * leftSide; - for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity - resV[offOut + colsRight.get(j)] += left[leftOff + j] * scaling[j]; + for(int j = 0; j < nRowCol; j++) { // cols in right side + resV[offOut + colsRight.get(j)] += left[i + j * leftSide] * scaling[j]; } } } @@ -737,19 +738,8 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef @Override public boolean equals(IDictionary o) { if(o instanceof IdentityDictionary) - return ((IdentityDictionary) o).nRowCol == nRowCol; - - MatrixBlock mb = getMBDict().getMatrixBlock(); - if(o instanceof MatrixBlockDictionary) - return mb.equals(((MatrixBlockDictionary) o).getMatrixBlock()); - else if(o instanceof Dictionary) { - if(mb.isInSparseFormat()) - return mb.getSparseBlock().equals(((Dictionary) o)._values, nRowCol); - final double[] dv = mb.getDenseBlockValues(); - return Arrays.equals(dv, ((Dictionary) o)._values); - } - - return false; + return ((IdentityDictionary) o).nRowCol == nRowCol && ((IdentityDictionary) o).withEmpty == withEmpty; + return getMBDict().equals(o); } @Override @@ -762,4 +752,81 @@ public IDictionary reorder(int[] reorder) { return getMBDict().reorder(reorder); } + @Override + protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, + int nColRight) { + final int thisColsSize = thisCols.size(); + final SparseBlockMCSR ret = new SparseBlockMCSR(numVals); + + for(int h = 0; h < thisColsSize; h++) { + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + for(int i = sPos; i < sEnd; i++) { + ret.add(h, sIndexes[i], sValues[i]); + } + + } + + final MatrixBlock retB = new MatrixBlock(numVals, nColRight, -1, ret); + retB.recomputeNonZeros(); + return MatrixBlockDictionary.create(retB, false); + } + + @Override + protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, + IColIndex aggregateColumns) { + + final int thisColsSize = thisCols.size(); + final int aggColSize = aggregateColumns.size(); + final SparseBlockMCSR ret = new SparseBlockMCSR(numVals); + + for(int h = 0; h < thisColsSize; h++) { + final int colIdx = thisCols.get(h); + if(b.isEmpty(colIdx)) + continue; + + final double[] sValues = b.values(colIdx); + final int[] sIndexes = b.indexes(colIdx); + + final int sPos = b.pos(colIdx); + final int sEnd = b.size(colIdx) + sPos; + int retIdx = 0; + for(int i = sPos; i < sEnd; i++) { + while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i]) + retIdx++; + + if(retIdx == aggColSize) + break; + ret.add(h, retIdx, sValues[i]); + } + + } + + final MatrixBlock retB = new MatrixBlock(numVals, aggregateColumns.size(), -1, ret); + retB.recomputeNonZeros(); + return MatrixBlockDictionary.create(retB, false); + } + + @Override + public IDictionary append(double[] row) { + return getMBDict().append(row); + } + + @Override + public String getString(int colIndexes) { + return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; + } + + @Override + public String toString() { + return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index 167328871b4..36ed4251eaa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -206,6 +206,11 @@ public long getNumberNonZeros(int[] counts, int nCol) { return (long) sum(counts, nCol); } + @Override + public int getNumberOfValues(int ncol) { + return nRowCol + (withEmpty ? 1 : 0); + } + @Override public MatrixBlockDictionary getMBDict(int nCol) { if(cache != null) { @@ -316,4 +321,9 @@ else if(o instanceof Dictionary) { return false; } + @Override + public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { + getMBDict().multiplyScalar(v, ret, off, dictIdx, cols); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 2a800837c7c..e6c0eb201e5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -38,11 +38,13 @@ import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; +import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -63,6 +65,7 @@ public class MatrixBlockDictionary extends ADictionary { */ protected MatrixBlockDictionary(MatrixBlock data) { _data = data; + _data.examSparsity(); } public static MatrixBlockDictionary create(MatrixBlock mb) { @@ -420,7 +423,7 @@ public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colInde @Override public IDictionary applyScalarOp(ScalarOperator op) { - MatrixBlock res = _data.scalarOperations(op, new MatrixBlock()); + MatrixBlock res = LibMatrixBincell.bincellOpScalar(_data, null, op, 1); return MatrixBlockDictionary.create(res); } @@ -732,7 +735,16 @@ public MatrixBlockDictionary binOpLeftWithReference(BinaryOperator op, double[] @Override public MatrixBlockDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { + if(op.fn instanceof Divide) { + boolean all1 = true; + for(int i = 0; i < colIndexes.size() && all1; i++) { + all1 = v[colIndexes.get(i)] == 1; + } + if(all1) + return this; + } final MatrixBlock rowVector = Util.extractValues(v, colIndexes); + rowVector.examSparsity(); final MatrixBlock ret = _data.binaryOperations(op, rowVector, null); return MatrixBlockDictionary.create(ret); } @@ -2015,7 +2027,9 @@ public double getSparsity() { @Override public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - if(_data.isInSparseFormat()) + if(v == 0) + return; + else if(_data.isInSparseFormat()) multiplyScalarSparse(v, ret, off, dictIdx, cols); else multiplyScalarDense(v, ret, off, dictIdx, cols); @@ -2077,9 +2091,11 @@ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) { if(_data.isInSparseFormat()) - DictLibMatrixMult.MMDictsScalingDenseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, scaling); + DictLibMatrixMult.MMDictsScalingDenseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, + scaling); else - DictLibMatrixMult.MMDictsScalingDenseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result,scaling); + DictLibMatrixMult.MMDictsScalingDenseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result, + scaling); } @Override @@ -2095,9 +2111,11 @@ public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRig public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) { if(_data.isInSparseFormat()) - DictLibMatrixMult.MMDictsScalingSparseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, scaling); + DictLibMatrixMult.MMDictsScalingSparseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, + scaling); else - DictLibMatrixMult.MMDictsScalingSparseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result, scaling); + DictLibMatrixMult.MMDictsScalingSparseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result, + scaling); } @Override @@ -2160,11 +2178,22 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef public boolean equals(IDictionary o) { if(o instanceof MatrixBlockDictionary) return _data.equals(((MatrixBlockDictionary) o)._data); + + else if(o instanceof IdentityDictionary) + return ((IdentityDictionary) o).equals(this); else if(o instanceof Dictionary) { - if(_data.isInSparseFormat()) - return _data.getSparseBlock().equals(((Dictionary) o)._values, _data.getNumColumns()); + double[] dVals = ((Dictionary) o)._values; + if(_data.isEmpty()) { + for(int i = 0; i < dVals.length; i++) { + if(dVals[i] != 0) + return false; + } + return true; + } + else if(_data.isInSparseFormat()) + return _data.getSparseBlock().equals(dVals, _data.getNumColumns()); final double[] dv = _data.getDenseBlockValues(); - return Arrays.equals(dv, ((Dictionary) o)._values); + return Arrays.equals(dv, dVals); } return false; @@ -2190,4 +2219,44 @@ public IDictionary reorder(int[] reorder) { return create(ret, false); } + + @Override + public IDictionary append(double[] row) { + if(_data.isEmpty()) { + throw new NotImplementedException(); + } + else if(_data.isInSparseFormat()) { + final int nRow = _data.getNumRows(); + if(_data.getSparseBlock() instanceof SparseBlockMCSR) { + MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), true); + mb.allocateBlock(); + SparseBlock sb = mb.getSparseBlock(); + SparseBlockMCSR s = (SparseBlockMCSR) _data.getSparseBlock(); + + for(int i = 0; i < _data.getNumRows(); i++) + sb.set(i, s.get(i), false); + + for(int i = 0; i < row.length; i++) + sb.set(nRow, i, row[i]); + + mb.examSparsity(); + return new MatrixBlockDictionary(mb); + + } + else { + throw new NotImplementedException("Not implemented append for CSR"); + } + + } + else { + // dense + double[] _values = _data.getDenseBlockValues(); + double[] retV = new double[_values.length + row.length]; + System.arraycopy(_values, 0, retV, 0, _values.length); + System.arraycopy(row, 0, retV, _values.length, row.length); + + MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), retV); + return new MatrixBlockDictionary(mb); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 51c41ffeec6..23b4f161d58 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -195,7 +195,7 @@ public DictType getDictType() { } @Override - public int getNumberOfValues(int ncol) { + public int getNumberOfValues(int nCol) { return nVal; } @@ -525,4 +525,24 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex throw new RuntimeException(errMessage); } + @Override + public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns, + int nColRight) { + throw new RuntimeException(errMessage); + } + + @Override + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary append(double[] row) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] getRow(int i, int nCol) { + throw new RuntimeException(errMessage); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index ae833dd7a9f..5527a7d0894 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -631,4 +631,9 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex int[] scaling) { throw new NotImplementedException(); } + + @Override + public IDictionary append(double[] row) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index b66c7ddb877..a3930bca011 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; @@ -239,7 +240,7 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in preAggregateDenseToRowVec8(mV, preAV, rc, off); } - protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off) { preAV[getIndex(rc)] += mV[off]; preAV[getIndex(rc + 1)] += mV[off + 1]; preAV[getIndex(rc + 2)] += mV[off + 2]; @@ -900,6 +901,12 @@ public void verify() { } } + public void lmSparseMatrixRow(final int apos, final int alen, final int[] aix, final double[] aval, final int r, + final int offR, final double[] retV, final IColIndex colIndexes, final IDictionary dict) { + for(int i = apos; i < alen; i++) + dict.multiplyScalar(aval[i], retV, offR, getIndex(aix[i]), colIndexes); + } + @Override public String toString() { final int sz = size(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index fcbc84ce984..456d6f5b551 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -27,6 +27,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -266,6 +268,13 @@ public int getMaxPossible() { return 256; } + @Override + public void lmSparseMatrixRow(final int apos, final int alen, final int[] aix, final double[] aval, final int r, + final int offR, final double[] retV, final IColIndex colIndexes, final IDictionary dict) { + for(int i = apos; i < alen; i++) + dict.multiplyScalar(aval[i], retV, offR, getIndex(aix[i]), colIndexes); + } + @Override public boolean equals(AMapToData e) { return e instanceof MapToByte && // diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index 1f46cc3886f..a23a659672f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -27,6 +27,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; @@ -159,7 +161,7 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in } @Override - protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off) { preAV[getIndex(rc)] += mV[off]; preAV[getIndex(rc + 1)] += mV[off + 1]; preAV[getIndex(rc + 2)] += mV[off + 2]; @@ -318,6 +320,13 @@ protected void preAggregateDDC_DDCSingleCol_vec(AMapToData tm, double[] td, doub super.preAggregateDDC_DDCSingleCol_vec(tm, td, v, r); } + @Override + public final void lmSparseMatrixRow(final int apos, final int alen, final int[] aix, final double[] aval, final int r, + final int offR, final double[] retV, final IColIndex colIndexes, final IDictionary dict) { + for(int i = apos; i < alen; i++) + dict.multiplyScalar(aval[i], retV, offR, getIndex(aix[i]), colIndexes); + } + protected final void preAggregateDDC_DDCSingleCol_vecChar(MapToChar tm, double[] td, double[] v, int r) { final int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7; v[getIndex(r)] += td[tm.getIndex(r)]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index a37d5bc75aa..b78f4e77d2a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -27,6 +27,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -291,7 +293,14 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in } @Override - protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + public void lmSparseMatrixRow(final int apos, final int alen, final int[] aix, final double[] aval, final int r, + final int offR, final double[] retV, final IColIndex colIndexes, final IDictionary dict) { + for(int i = apos; i < alen; i++) + dict.multiplyScalar(aval[i], retV, offR, getIndex(aix[i]), colIndexes); + } + + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off) { preAV[getIndex(rc)] += mV[off]; preAV[getIndex(rc + 1)] += mV[off + 1]; preAV[getIndex(rc + 2)] += mV[off + 2]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java index adefe49a528..3d4ff985e44 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java @@ -50,6 +50,8 @@ public abstract class AOffset implements Serializable { protected static final Log LOG = LogFactory.getLog(AOffset.class.getName()); + protected static final OffsetSliceInfo emptySlice = new OffsetSliceInfo(-1, -1, new OffsetEmpty()); + /** The skip list stride size, aka how many indexes skipped for each index. */ protected static final int skipStride = 1000; @@ -104,7 +106,13 @@ else if(getLength() < skipStride) return getIteratorLargeOffset(row); } - private AIterator getIteratorSkipCache(int row){ + /** + * Get an iterator that is pointing to a specific offset, this method skips looking at our cache of iterators. + * + * @param row The row to look at + * @return The iterator associated with the row. + */ + private AIterator getIteratorSkipCache(int row) { if(row <= getOffsetToFirst()) return getIterator(); else if(row > getOffsetToLast()) @@ -478,37 +486,46 @@ public boolean equals(AOffset b) { public abstract int getLength(); public OffsetSliceInfo slice(int l, int u) { - AIterator it = getIteratorSkipCache(l); - if(it == null || it.value() >= u) - return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); - else if(l <= getOffsetToFirst() && u > getOffsetToLast()) { + final int first = getOffsetToFirst(); + final int last = getOffsetToLast(); + final int s = getSize(); + + if(l <= first && u > last) { if(l == 0) - return new OffsetSliceInfo(0, getSize(), this); + return new OffsetSliceInfo(0, s, this); else - return new OffsetSliceInfo(0, getSize(), moveIndex(l)); + return new OffsetSliceInfo(0, s, moveIndex(l)); } + + final AIterator it = getIteratorSkipCache(l); + if(it == null || it.value() >= u) + return emptySlice; + + if(u >= last) // If including the last do not iterate. + return constructSliceReturn(l, it.getDataIndex(), s - 1, it.getOffsetsIndex(), getLength(), it.value(), last); + else // Have to iterate through until we find last. + return genericSlice(l, u, it); + } + + protected OffsetSliceInfo genericSlice(int l, int u, AIterator it) { final int low = it.getDataIndex(); final int lowOff = it.getOffsetsIndex(); final int lowValue = it.value(); - int high = low; int highOff = lowOff; int highValue = lowValue; - if(u >= getOffsetToLast()) { // If including the last do not iterate. - high = getSize() - 1; - highOff = getLength(); - highValue = getOffsetToLast(); - } - else { // Have to iterate through until we find last. - while(it.value() < u) { - // TODO add previous command that would allow us to simplify this loop. - high = it.getDataIndex(); - highOff = it.getOffsetsIndex(); - highValue = it.value(); - it.next(); - } + while(it.value() < u) { + // TODO add previous command that would allow us to simplify this loop. + high = it.getDataIndex(); + highOff = it.getOffsetsIndex(); + highValue = it.value(); + it.next(); } - + return constructSliceReturn(l, low, high, lowOff, highOff, lowValue, highValue); + } + + protected final OffsetSliceInfo constructSliceReturn(int l, int low, int high, int lowOff, int highOff, int lowValue, + int highValue) { if(low == high) return new OffsetSliceInfo(low, high + 1, new OffsetSingle(lowValue - l)); else if(low + 1 == high) @@ -582,6 +599,24 @@ public AOffset appendN(AOffsetsGroup[] g, int s) { } } + public void verify(int size) { + AIterator it = getIterator(); + if(it != null) { + final int last = getOffsetToLast(); + while(it.value() < last) { + it.next(); + if(it.getDataIndex() > size) + throw new DMLCompressionException("Invalid index"); + } + if(it.getDataIndex() > size) + throw new DMLCompressionException("Invalid index"); + } + else { + if(size != 0) + throw new DMLCompressionException("Invalid index"); + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java index 67a6ad55d26..a1238ee928f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java @@ -86,7 +86,7 @@ public int getSize() { @Override public OffsetSliceInfo slice(int l, int u) { - return new OffsetSliceInfo(-1, -1, this); + return emptySlice; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java index 319b7ce89f9..8f7bef84331 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java @@ -245,7 +245,7 @@ private static AOffset createByte(int[] indexes, int apos, int alen) { final int nv = indexes[i]; final int offsetSize = nv - ov; if(offsetSize <= 0) - throw new DMLCompressionException("invalid offset construction with negative sequences"); + throw new DMLCompressionException("invalid offset construction with negative sequences Byte"); final byte mod = (byte) (offsetSize % mp1); offsets[p++] = mod; ov = nv; @@ -304,7 +304,7 @@ private static AOffset createChar(int[] indexes, int apos, int alen) { final int nv = indexes[i]; final int offsetSize = (nv - ov); if(offsetSize <= 0) - throw new DMLCompressionException("invalid offset construction with negative sequences"); + throw new DMLCompressionException("invalid offset construction with negative sequences Char"); final int mod = offsetSize % mp1; offsets[p++] = (char) (mod); ov = nv; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetSingle.java index d77fec4a254..2ce25e013ab 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetSingle.java @@ -90,12 +90,10 @@ public static OffsetSingle readFields(DataInput in) throws IOException { @Override public OffsetSliceInfo slice(int l, int u) { - if(l <= off && u > off) return new OffsetSliceInfo(0, 1, new OffsetSingle(off - l)); else - return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); - + return emptySlice; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java index a76fc2112d0..f8b573fde56 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java @@ -36,7 +36,7 @@ public OffsetTwo(int first, int last) { this.first = first; this.last = last; if(last <= first) - throw new DMLCompressionException("Invalid offsets last should be greater than first"); + throw new DMLCompressionException("Invalid offsets last should be greater than first: " + first + "->" + last); } @Override @@ -98,7 +98,7 @@ public static OffsetTwo readFields(DataInput in) throws IOException { public OffsetSliceInfo slice(int l, int u) { if(l <= first) { if(u < first) - return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); + return emptySlice; else if(u > last) return new OffsetSliceInfo(0, 2, moveIndex(l)); else @@ -107,7 +107,7 @@ else if(u > last) else if(l <= last && u > last) return new OffsetSliceInfo(1, 2, new OffsetSingle(last - l)); else - return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); + return emptySlice; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java index 5df202fcbb1..ab078480686 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java @@ -50,8 +50,6 @@ protected List CompressedSizeInfoColGroup(int clen, @Override public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - - // final IEncode map = throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'"); } @@ -69,11 +67,11 @@ protected int worstCaseUpperBound(IColIndex columns) { } else { List groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups()); - int nVals = 1; + long nVals = 1; for(AColGroup g : groups) nVals *= g.getNumValues(); - return Math.min(_data.getNumRows(), nVals); + return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE)); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java new file mode 100644 index 00000000000..521b0485fd1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.estim; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; + +public class ComEstCompressedSample extends ComEstSample { + + public ComEstCompressedSample(CompressedMatrixBlock sample, CompressionSettings cs, CompressedMatrixBlock full, + int k) { + super(sample, cs, full, k); + // cData = sample; + } + + @Override + protected List CompressedSizeInfoColGroup(int clen, int k) { + List ret = new ArrayList<>(); + final int nRow = _data.getNumRows(); + final List fg = ((CompressedMatrixBlock) _data).getColGroups(); + final List sg = ((CompressedMatrixBlock) _sample).getColGroups(); + + for(int i = 0; i < fg.size(); i++) { + CompressedSizeInfoColGroup r = fg.get(i).getCompressionInfo(nRow); + r.setMap(sg.get(i).getCompressionInfo(_sampleSize).getMap()); + ret.add(r); + } + + return ret; + } + + @Override + public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { + throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'"); + } + + @Override + public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { + throw new UnsupportedOperationException("Unimplemented method 'getDeltaColGroupInfo'"); + } + + @Override + protected int worstCaseUpperBound(IColIndex columns) { + CompressedMatrixBlock cData = ((CompressedMatrixBlock) _data); + if(columns.size() == 1) { + int id = columns.get(0); + AColGroup g = cData.getColGroupForColumn(id); + return g.getNumValues(); + } + else { + List groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups()); + long nVals = 1; + for(AColGroup g : groups) + nVals *= g.getNumValues(); + + return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE)); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java index e497b580ce3..1fe97c7e1ee 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.lib.CLALibSlice; import org.apache.sysds.runtime.matrix.data.MatrixBlock; public interface ComEstFactory { @@ -37,13 +38,13 @@ public interface ComEstFactory { * @return A new CompressionSizeEstimator used to extract information of column groups */ public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int k) { - if(data instanceof CompressedMatrixBlock) - return createCompressedEstimator((CompressedMatrixBlock) data, cs); - final int nRows = cs.transposed ? data.getNumColumns() : data.getNumRows(); final int nCols = cs.transposed ? data.getNumRows() : data.getNumColumns(); final double sparsity = data.getSparsity(); final int sampleSize = getSampleSize(cs, nRows, nCols, sparsity); + if(data instanceof CompressedMatrixBlock) + return createCompressedEstimator((CompressedMatrixBlock) data, cs, sampleSize, k); + if(data.isEmpty()) return createExactEstimator(data, cs); return createEstimator(data, cs, sampleSize, k, nRows); @@ -75,8 +76,17 @@ private static ComEstExact createExactEstimator(MatrixBlock data, CompressionSet return new ComEstExact(data, cs); } - private static ComEstCompressed createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs) { - LOG.debug("Using Compressed Estimator"); + private static AComEst createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs, int sampleSize, + int k) { + if(sampleSize < data.getNumRows()) { + LOG.debug("Trying to sample"); + final MatrixBlock slice = CLALibSlice.sliceRowsCompressed(data, 0, sampleSize); + if(slice instanceof CompressedMatrixBlock) { + LOG.debug("Using Sampled Compressed Estimator " + sampleSize); + return new ComEstCompressedSample((CompressedMatrixBlock) slice, cs, data, k); + } + } + LOG.debug("Using Full Compressed Estimator"); return new ComEstCompressed(data, cs); } @@ -114,7 +124,7 @@ private static int getSampleSize(CompressionSettings cs, int nRows, int nCols, d * @param maxSampleSize The maximum sample size * @return The sample size to use. */ - private static int getSampleSize(double samplePower, int nRows, int nCols, double sparsity, int minSampleSize, + public static int getSampleSize(double samplePower, int nRows, int nCols, double sparsity, int minSampleSize, int maxSampleSize) { // Start sample size at the min sample size as the basis sample. diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java index 6306b04e8c1..9557c5c94c8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java @@ -23,6 +23,7 @@ import java.util.Random; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -42,13 +43,22 @@ public class ComEstSample extends AComEst { /** Sample extracted from the input data */ - private final MatrixBlock _sample; + protected final MatrixBlock _sample; /** Parallelization degree */ - private final int _k; + protected final int _k; /** Sample size */ - private final int _sampleSize; + protected final int _sampleSize; /** Boolean specifying if the sample is in transposed format. */ - private boolean _transposed; + protected boolean _transposed; + + public ComEstSample(MatrixBlock sample, CompressionSettings cs, MatrixBlock full, int k){ + super(full, cs); + _k = k; + _transposed = cs.transposed; + _sample = sample; + _sampleSize = sample.getNumRows(); + + } /** * CompressedSizeEstimatorSample, samples from the input data and estimates the size of the compressed matrix. @@ -95,7 +105,7 @@ public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int @Override protected int worstCaseUpperBound(IColIndex columns) { if(getNumColumns() == columns.size()) - return Math.min(getNumRows(), (int) _data.getNonZeros()); + return Math.min(getNumRows(), (int) Math.min(_data.getNonZeros(),Integer.MAX_VALUE)); return getNumRows(); } @@ -107,10 +117,15 @@ protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, Compress } private CompressedSizeInfoColGroup extractInfo(IEncode map, IColIndex colIndexes, int maxDistinct) { - final double spar = _data.getSparsity(); - final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs); - final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense()); - return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); + try{ + final double spar = _data.getSparsity(); + final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs); + final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense()); + return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); + } + catch(Exception e){ + throw new DMLCompressionException(map + "", e); + } } private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex colIndexes, int maxDistinct, @@ -125,6 +140,8 @@ private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex final long nnz = calculateNNZ(colIndexes, scalingFactor); final int numOffs = calculateOffs(sampleFacts, numRows, scalingFactor, colIndexes, (int) nnz); final int estDistinct = distinctCountScale(sampleFacts, numOffs, numRows, maxDistinct, dense, nCol); + if(estDistinct < sampleFacts.numVals) + throw new DMLCompressionException("Failed estimating distinct: " + estDistinct ); // calculate the largest instance count. final int maxLargestInstanceCount = numRows - estDistinct + 1; @@ -137,7 +154,6 @@ private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex sampleFacts.overAllSparsity); // For robustness safety add 10 percent more tuple sparsity final double tupleSparsity = Math.min(overallSparsity * 1.3, 1.0); // increase sparsity by 30%. - if(_cs.isRLEAllowed()) { final int scaledRuns = Math.max(estDistinct, calculateRuns(sampleFacts, scalingFactor, numOffs, estDistinct)); @@ -161,6 +177,7 @@ private int distinctCountScale(EstimationFactors sampleFacts, int numOffs, int n final int[] freq = sampleFacts.frequencies; if(freq == null || freq.length == 0) return numOffs; // very aggressive number of distinct + // sampled size is smaller than actual if there was empty rows. // and the more we can reduce this value the more accurate the estimation will become. final int sampledSize = sampleFacts.numOffs; diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 1168147b3d2..56d58cd8dc5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -206,6 +206,10 @@ public IEncode getMap() { return _map; } + public void setMap(IEncode map){ + _map = map; + } + public boolean containsZeros() { return _facts.numOffs < _facts.numRows; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java index 130d0f77f82..904a228ae71 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java @@ -87,16 +87,14 @@ public EstimationFactors(int numVals, int numOffs, int largestOff, int[] frequen this.tupleSparsity = tupleSparsity; if(overAllSparsity > 1 || overAllSparsity < 0) - throw new DMLCompressionException("Invalid OverAllSparsity of: " + overAllSparsity); + overAllSparsity = Math.max(0, Math.min(1, overAllSparsity)); else if(tupleSparsity > 1 || tupleSparsity < 0) - throw new DMLCompressionException("Invalid TupleSparsity of:" + tupleSparsity); + tupleSparsity = Math.max(0, Math.min(1, tupleSparsity)); else if(largestOff > numRows) - throw new DMLCompressionException( - "Invalid number of instance of most common element should be lower than number of rows. " + largestOff - + " > numRows: " + numRows); + largestOff = numRows; else if(numVals > numOffs) - throw new DMLCompressionException( - "Num vals cannot be greater than num offs: vals: " + numVals + " offs: " + numOffs); + numVals = numOffs; + if(CompressedMatrixBlock.debug && frequencies != null) { for(int i = 0; i < frequencies.length; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java index 11e23000ba8..5ef600dbe02 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java @@ -37,16 +37,24 @@ */ public class DenseEncoding extends AEncode { + private static boolean zeroWarn = false; + private final AMapToData map; public DenseEncoding(AMapToData map) { this.map = map; if(CompressedMatrixBlock.debug) { - int[] freq = map.getCounts(); - for(int i = 0; i < freq.length; i++) { - if(freq[i] == 0) - throw new DMLCompressionException("Invalid counts in fact contains 0"); + if(!zeroWarn) { + int[] freq = map.getCounts(); + for(int i = 0; i < freq.length; i++) { + if(freq[i] == 0) { + LOG.warn("Dense encoding contains zero encoding, indicating not all dictionary entries are in use"); + zeroWarn = true; + break; + } + // throw new DMLCompressionException("Invalid counts in fact contains 0"); + } } } } @@ -146,27 +154,40 @@ protected DenseEncoding combineDense(final DenseEncoding other) { final int nVL = lm.getUnique(); final int nVR = rm.getUnique(); final int size = map.size(); - final int maxUnique = nVL * nVR; - - final AMapToData ret = MapToFactory.create(size, maxUnique); - - if(maxUnique > size && maxUnique > 2048) { + int maxUnique = nVL * nVR; + DenseEncoding retE = null; + if(maxUnique < Math.max(nVL, nVR)) {// overflow + maxUnique = size; + final AMapToData ret = MapToFactory.create(size, maxUnique); + final Map m = new HashMap<>(size); + retE = combineDenseWithHashMapLong(lm, rm, size, nVL, ret, m); + } + else if(maxUnique > size && maxUnique > 2048) { + final AMapToData ret = MapToFactory.create(size, maxUnique); // aka there is more maxUnique than rows. final Map m = new HashMap<>(size); - return combineDenseWithHashMap(lm, rm, size, nVL, ret, m); + retE = combineDenseWithHashMap(lm, rm, size, nVL, ret, m); } else { + final AMapToData ret = MapToFactory.create(size, maxUnique); final AMapToData m = MapToFactory.create(maxUnique, maxUnique + 1); - return combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m); + retE = combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m); + } + + if(retE.getUnique() < 0) { + throw new DMLCompressionException( + "Failed to combine dense encodings correctly: Number unique values is lower than max input: \n\n" + this + + "\n\n" + other + "\n\n" + retE); } + return retE; } private Pair> combineDenseNoResize(final DenseEncoding other) { - if(map == other.map) { + if(map.equals(other.map)) { LOG.warn("Constructing perfect mapping, this could be optimized to skip hashmap"); final Map m = new HashMap<>(map.size()); for(int i = 0; i < map.getUnique(); i++) - m.put(i * i, i); + m.put(i * (map.getUnique() + 1) , i); return new ImmutablePair<>(this, m); // same object } @@ -176,16 +197,12 @@ private Pair> combineDenseNoResize(final DenseEnc final int nVL = lm.getUnique(); final int nVR = rm.getUnique(); final int size = map.size(); - final int maxUnique = nVL * nVR; + final int maxUnique = (int) Math.min((long) nVL * nVR, (long) size); final AMapToData ret = MapToFactory.create(size, maxUnique); - final Map m = new HashMap<>(Math.min(size, maxUnique)); + final Map m = new HashMap<>(maxUnique); return new ImmutablePair<>(combineDenseWithHashMap(lm, rm, size, nVL, ret, m), m); - - // there can be less unique. - - // return new DenseEncoding(ret); } private Pair> combineSparseNoResize(final SparseEncoding other) { @@ -193,6 +210,14 @@ private Pair> combineSparseNoResize(final SparseE return combineSparseHashMap(a); } + protected final DenseEncoding combineDenseWithHashMapLong(final AMapToData lm, final AMapToData rm, final int size, + final int nVL, final AMapToData ret, Map m) { + + for(int r = 0; r < size; r++) + addValHashMap((long) lm.getIndex(r) + (long) rm.getIndex(r) * (long) nVL, r, m, ret); + return new DenseEncoding(MapToFactory.resize(ret, m.size())); + } + protected final DenseEncoding combineDenseWithHashMap(final AMapToData lm, final AMapToData rm, final int size, final int nVL, final AMapToData ret, Map m) { @@ -218,8 +243,16 @@ protected static int addValMapToData(final int nv, final int r, final AMapToData return newId; } - protected static void addValHashMap(final int nv, final int r, final Map map, - final AMapToData d) { + protected static void addValHashMap(final int nv, final int r, final Map map, final AMapToData d) { + final int v = map.size(); + final Integer mv = map.putIfAbsent(nv, v); + if(mv == null) + d.set(r, v); + else + d.set(r, mv); + } + + protected static void addValHashMap(final long nv, final int r, final Map map, final AMapToData d) { final int v = map.size(); final Integer mv = map.putIfAbsent(nv, v); if(mv == null) diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java index d8ab0f0f7c3..257ddf6f3c2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java @@ -229,8 +229,16 @@ else if(alen - apos > nCol / 4) { // return a dense encoding // Iteration 3 of non zero indexes, make a Offset Encoding to know what cells are zero and not. // not done yet - final AOffset o = OffsetFactory.createOffset(aix, apos, alen); - return new SparseEncoding(d, o, m.getNumColumns()); + try{ + + final AOffset o = OffsetFactory.createOffset(aix, apos, alen); + return new SparseEncoding(d, o, m.getNumColumns()); + } + catch(Exception e){ + String mes = Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " + alen; + mes += Arrays.toString(Arrays.copyOfRange(avals, apos, alen)); + throw new DMLRuntimeException(mes, e); + } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java index 9c934592089..864d32b4f3e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java @@ -57,6 +57,7 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.HDFSTool; public final class WriterCompressed extends MatrixWriter { @@ -146,7 +147,7 @@ private void write(MatrixBlock src, final String fname, final int blen) throws I } fs = IOUtilFunctions.getFileSystem(new Path(fname), job); - + int k = OptimizerUtils.getParallelBinaryWriteParallelism(); k = Math.min(k, (int)(src.getInMemorySize() / InfrastructureAnalyzer.getBlockSize(fs))); @@ -213,8 +214,6 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle throws IOException { try { final CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; - - setupWrite(); final Path path = new Path(fname); Writer w = generateWriter(job, path, fs); for(int bc = 0; bc * blen < clen; bc++) {// column blocks @@ -244,7 +243,6 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, final int clen, final int blen, int k) throws IOException { - setupWrite(); final ExecutorService pool = CommonThreadPool.get(k); final ArrayList> tasks = new ArrayList<>(); try { @@ -265,7 +263,8 @@ private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, fi final int colBlocks = (int) Math.ceil((double) clen / blen ); final int nBlocks = (int) Math.ceil((double) rlen / blen); final int blocksPerThread = Math.max(1, nBlocks * colBlocks / k ); - + HDFSTool.deleteFileIfExistOnHDFS(new Path(fname + ".dict"), job); + int i = 0; for(int bc = 0; bc * blen < clen; bc++) {// column blocks final int sC = bc * blen; @@ -307,13 +306,6 @@ private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, fi } } - private void setupWrite() throws IOException { - // final Path path = new Path(fname); - // final JobConf job = ConfigurationManager.getCachedJobConf(); - // HDFSTool.deleteFileIfExistOnHDFS(path, job); - // HDFSTool.createDirIfNotExistOnHDFS(path, DMLConfig.DEFAULT_SHARED_DIR_PERMISSION); - } - private Path getPath(int id) { return new Path(fname, IOUtilFunctions.getPartFileName(id)); } @@ -397,6 +389,7 @@ protected DictWriteTask(String fname, List dicts, int id) { public Object call() throws Exception { Path p = new Path(fname + ".dict", IOUtilFunctions.getPartFileName(id)); + HDFSTool.deleteFileIfExistOnHDFS(p, job); try(Writer w = SequenceFile.createWriter(job, Writer.file(p), // Writer.bufferSize(4096), // Writer.keyClass(DictWritable.K.class), // diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index ede9ca46aad..a6fc0d51301 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -29,6 +29,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; @@ -39,7 +40,9 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -59,6 +62,7 @@ import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.utils.DMLCompressionStatistics; public final class CLALibBinaryCellOp { private static final Log LOG = LogFactory.getLog(CLALibBinaryCellOp.class.getName()); @@ -191,7 +195,6 @@ private static MatrixBlock rowBinCellOp(CompressedMatrixBlock m1, MatrixBlock m2 overlappingBinaryCellOp(m1, m2, cRet, op, left); else nonOverlappingBinaryCellOp(m1, m2, cRet, op, left); - cRet.recomputeNonZeros(); return cRet; } @@ -234,7 +237,7 @@ private static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, doubl binaryMVRowMultiThread(oldColGroups, v, op, left, newColGroups, isRowSafe, k); ret.allocateColGroupList(newColGroups); - ret.setNonZeros(m1.getNumColumns() * m1.getNumRows()); + ret.examSparsity(op.getNumThreads()); return ret; } @@ -350,62 +353,123 @@ private static MatrixBlock binaryMVCol(CompressedMatrixBlock m1, MatrixBlock m2, final int nRows = m1.getNumRows(); m1 = morph(m1); - MatrixBlock ret = new MatrixBlock(nRows, nCols, false, -1).allocateBlock(); - final int k = op.getNumThreads(); long nnz = 0; - if(k <= 1) - nnz = binaryMVColSingleThread(m1, m2, op, left, ret); - else - nnz = binaryMVColMultiThread(m1, m2, op, left, ret); + boolean shouldBeSparseOut = false; + if(op.fn.isBinary()) { + // maybe it is good if this is a sparse output. + // evaluate if it is good + double est = evaluateSparsityMVCol(m1, m2, op, left); + shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (est * nRows * nCols)); + + } + MatrixBlock ret = new MatrixBlock(nRows, nCols, shouldBeSparseOut, -1).allocateBlock(); + + if(shouldBeSparseOut) { + if(k <= 1) + nnz = binaryMVColSingleThreadSparse(m1, m2, op, left, ret); + else + nnz = binaryMVColMultiThreadSparse(m1, m2, op, left, ret); + } + else { + if(k <= 1) + nnz = binaryMVColSingleThreadDense(m1, m2, op, left, ret); + else + nnz = binaryMVColMultiThreadDense(m1, m2, op, left, ret); + } + + // LOG.error(ret); if(op.fn instanceof ValueComparisonFunction) { - if(nnz == (long) nRows * nCols) + if(nnz == (long) nRows * nCols)// all was 1 return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 1.0); - - else if(nnz == 0) + else if(nnz == 0) // all was 0 return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 0.0); } + ret.setNonZeros(nnz); - ret.examSparsity(); + ret.examSparsity(op.getNumThreads()); + // throw new NotImplementedException(); return ret; } - private static long binaryMVColSingleThread(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + private static long binaryMVColSingleThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, MatrixBlock ret) { final int nRows = m1.getNumRows(); long nnz = 0; if(left) - nnz += new BinaryMVColLeftTask(m1, m2, ret, 0, nRows, op).call(); + nnz += new BinaryMVColLeftTaskDense(m1, m2, ret, 0, nRows, op).call(); else - nnz += new BinaryMVColTask(m1, m2, ret, 0, nRows, op).call(); + nnz += new BinaryMVColTaskDense(m1, m2, ret, 0, nRows, op).call(); return nnz; } - private static long binaryMVColMultiThread(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, - MatrixBlock ret) { + private static long binaryMVColSingleThreadSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { + final int nRows = m1.getNumRows(); + long nnz = 0; + if(left) + throw new NotImplementedException(); + // nnz += new BinaryMVColLeftTaskSparse(m1, m2, ret, 0, nRows, op).call(); + else + nnz += new BinaryMVColTaskSparse(m1, m2, ret, 0, nRows, op).call(); + return nnz; + } + + private static long binaryMVColMultiThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { final int nRows = m1.getNumRows(); final int k = op.getNumThreads(); final int blkz = ret.getNumRows() / k; long nnz = 0; final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); - final ArrayList> tasks = new ArrayList<>(); + final ArrayList> tasks = new ArrayList<>(); try { for(int i = 0; i < nRows; i += blkz) { if(left) - tasks.add(new BinaryMVColLeftTask(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); + tasks.add(new BinaryMVColLeftTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); else - tasks.add(new BinaryMVColTask(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); + tasks.add(new BinaryMVColTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); } - for(Future f : pool.invokeAll(tasks)) + for(Future f : pool.invokeAll(tasks)) nnz += f.get(); + } + catch(InterruptedException | ExecutionException e) { + throw new DMLRuntimeException(e); + } + finally { pool.shutdown(); } + return nnz; + } + + private static long binaryMVColMultiThreadSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { + final int nRows = m1.getNumRows(); + final int k = op.getNumThreads(); + final int blkz = Math.max(nRows / k, 64); + long nnz = 0; + final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + final ArrayList> tasks = new ArrayList<>(); + try { + for(int i = 0; i < nRows; i += blkz) { + if(left) + throw new NotImplementedException(); + // tasks.add(new BinaryMVColLeftTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); + else + tasks.add(new BinaryMVColTaskSparse(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); + } + for(Future f : pool.invokeAll(tasks)) + nnz += f.get(); + } catch(InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); } + finally { + pool.shutdown(); + } return nnz; } @@ -420,7 +484,7 @@ private static MatrixBlock binaryMM(CompressedMatrixBlock m1, MatrixBlock m2, Bi long nnz = binaryMMMultiThread(m1, m2, op, left, ret); ret.setNonZeros(nnz); - ret.examSparsity(); + ret.examSparsity(op.getNumThreads()); return ret; } @@ -462,7 +526,7 @@ private static CompressedMatrixBlock morph(CompressedMatrixBlock m) { return m; } - private static class BinaryMVColTask implements Callable { + private static class BinaryMVColTaskDense implements Callable { private final int _rl; private final int _ru; private final CompressedMatrixBlock _m1; @@ -470,7 +534,7 @@ private static class BinaryMVColTask implements Callable { private final MatrixBlock _ret; private final BinaryOperator _op; - protected BinaryMVColTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + protected BinaryMVColTaskDense(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) { _m1 = m1; _m2 = m2; @@ -481,7 +545,7 @@ protected BinaryMVColTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock } @Override - public Integer call() { + public Long call() { final int _blklen = Math.max(16384 / _ret.getNumColumns(), 64); final List groups = _m1.getColGroups(); @@ -490,18 +554,51 @@ public Integer call() { for(int r = _rl; r < _ru; r += _blklen) processBlock(r, Math.min(r + _blklen, _ru), groups, its); - return _ret.getNumColumns() * _ret.getNumRows(); + return _ret.recomputeNonZeros(_rl, _ru - 1); } private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { // unsafe decompress, since we count nonzeros afterwards. final DenseBlock db = _ret.getDenseBlock(); decompressToSubBlock(rl, ru, db, groups, its); + if(db.isContiguous()) { - if(_m2.isInSparseFormat()) - throw new NotImplementedException("Not Implemented sparse Format execution for MM."); - else - processDense(rl, ru); + if(_m2.isEmpty()) + processEmpty(rl, ru); + else if(_m2.isInSparseFormat()) + throw new NotImplementedException("Not implemented sparse format execution for mm."); + else + processDense(rl, ru); + } + else { + if(_m2.isEmpty()) { + processGenericEmpty(rl, ru); + } + else if(_m2.isInSparseFormat()) + throw new NotImplementedException("Not implemented sparse format execution for mm."); + else + processGenericDense(rl, ru); + } + } + + private final void processEmpty(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final double[] _retDense = _ret.getDenseBlockValues(); + for(int i = rl * nCol; i < ru * nCol; i++) { + _retDense[i] = _op.fn.execute(_retDense[i], 0); + } + } + + private final void processGenericEmpty(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final DenseBlock db = _ret.getDenseBlock(); + for(int r = rl; r < ru; r++) { + final double[] row = db.values(r); + final int pos = db.pos(r); + for(int c = pos; c < pos + nCol; c++) { + row[c] = _op.fn.execute(row[c], 0); + } + } } private final void processDense(final int rl, final int ru) { @@ -516,6 +613,100 @@ private final void processDense(final int rl, final int ru) { } } } + + private final void processGenericDense(final int rl, final int ru) { + final DenseBlock rd = _ret.getDenseBlock(); + final DenseBlock m2d = _m2.getDenseBlock(); + + for(int row = rl; row < ru; row++) { + final double[] _retDense = rd.values(row); + final double[] _m2Dense = m2d.values(row); + final int posR = rd.pos(row); + final int posM = m2d.pos(row); + final double vr = _m2Dense[posM]; + for(int col = 0; col < _m1.getNumColumns(); col++) { + _retDense[posR + col] = _op.fn.execute(_retDense[posR + col], vr); + } + } + } + + } + + private static class BinaryMVColTaskSparse implements Callable { + private final int _rl; + private final int _ru; + private final CompressedMatrixBlock _m1; + private final MatrixBlock _m2; + private final MatrixBlock _ret; + private final BinaryOperator _op; + + private MatrixBlock tmp; + + protected BinaryMVColTaskSparse(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + BinaryOperator op) { + _m1 = m1; + _m2 = m2; + _ret = ret; + _op = op; + _rl = rl; + _ru = ru; + } + + @Override + public Long call() { + final int _blklen = Math.max(16384 / _ret.getNumColumns(), 64); + final List groups = _m1.getColGroups(); + final AIterator[] its = getIterators(groups, _rl); + tmp = new MatrixBlock(_blklen, _m1.getNumColumns(), false); + tmp.allocateBlock(); + + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + + return _ret.recomputeNonZeros(_rl, _ru - 1); + } + + private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); + + if(_m2.isEmpty()) + processEmpty(rl, ru); + else if(_m2.isInSparseFormat()) + throw new NotImplementedException("Not implemented sparse format execution for mm."); + else + processDense(rl, ru); + tmp.reset(); + } + + private final void processEmpty(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final double[] _tmpDense = tmp.getDenseBlockValues(); + for(int i = rl; i < ru; i++) { + final int tmpOff = (i - rl) * nCol; + for(int j = 0; j < nCol; j++) { + double v = _op.fn.execute(_tmpDense[tmpOff + j], 0); + if(v != 0) + sb.append(i, j, v); + } + } + } + + private final void processDense(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final double[] _tmpDense = tmp.getDenseBlockValues(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + for(int col = 0; col < nCol; col++) { + double v = _op.fn.execute(_tmpDense[tmpOff + col], vr); + if(v != 0) + sb.append(row, col, v); + } + } + } } private static class BinaryMMTask implements Callable { @@ -555,7 +746,6 @@ public Long call() { } private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { - // unsafe decompress, since we count nonzeros afterwards. final DenseBlock db = _ret.getDenseBlock(); decompressToSubBlock(rl, ru, db, groups, its); @@ -682,7 +872,7 @@ private final void processRightEmpty(final int rl, final int ru) { } } - private static class BinaryMVColLeftTask implements Callable { + private static class BinaryMVColLeftTaskDense implements Callable { private final int _rl; private final int _ru; private final CompressedMatrixBlock _m1; @@ -690,7 +880,7 @@ private static class BinaryMVColLeftTask implements Callable { private final MatrixBlock _ret; private final BinaryOperator _op; - protected BinaryMVColLeftTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + protected BinaryMVColLeftTaskDense(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) { _m1 = m1; _m2 = m2; @@ -701,8 +891,7 @@ protected BinaryMVColLeftTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBl } @Override - public Integer call() { - // unsafe decompress, since we count nonzeros afterwards. + public Long call() { for(AColGroup g : _m1.getColGroups()) g.decompressToDenseBlock(_ret.getDenseBlock(), _rl, _ru); @@ -721,7 +910,7 @@ public Integer call() { } } - return _ret.getNumColumns() * _ret.getNumRows(); + return _ret.recomputeNonZeros(_rl, _ru - 1); } } } @@ -764,6 +953,7 @@ public AColGroup call() { protected static void decompressToSubBlock(final int rl, final int ru, final DenseBlock db, final List groups, final AIterator[] its) { + Timing time = new Timing(true); for(int i = 0; i < groups.size(); i++) { final AColGroup g = groups.get(i); if(g.getCompType() == CompressionType.SDC) @@ -771,6 +961,33 @@ protected static void decompressToSubBlock(final int rl, final int ru, final Den else g.decompressToDenseBlock(db, rl, ru, 0, 0); } + + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressToBlockTime(t, 1); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + 1 + " in " + t + "ms."); + } + } + + protected static void decompressToTmpBlock(final int rl, final int ru, final DenseBlock db, + final List groups, final AIterator[] its) { + Timing time = new Timing(true); + // LOG.error(rl + " " + ru); + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + if(g.getCompType() == CompressionType.SDC) + ((ASDCZero) g).decompressToDenseBlock(db, rl, ru, -rl, 0, its[i]); + else + g.decompressToDenseBlock(db, rl, ru, -rl, 0); + } + + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressToBlockTime(t, 1); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + 1 + " in " + t + "ms."); + } } protected static AIterator[] getIterators(final List groups, final int rl) { @@ -783,4 +1000,67 @@ protected static AIterator[] getIterators(final List groups, final in } return its; } + + private static double evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left) { + final List groups = m1.getColGroups(); + final int nCol = m1.getNumColumns(); + final int nRow = m1.getNumRows(); + final int sampleRow = Math.min(nRow, 5); + final int sampleCol = nCol; + double[] dv = new double[sampleRow * sampleCol]; + + double[] m2v = m2.getDenseBlockValues(); + + DenseBlock db = new DenseBlockFP64(new int[] {sampleRow, sampleCol}, dv); + + for(int i = 0; i < groups.size(); i++) { + groups.get(i).decompressToDenseBlock(db, 0, sampleRow); + } + + int nnz = 0; + + if(m2v == null) { // right side is empty. + if(left) { + for(int r = 0; r < sampleRow; r++) { + int off = r * sampleCol; + for(int c = 0; c < sampleCol; c++) { + nnz += op.fn.execute(0, dv[off + c]) != 0 ? 1 : 0; + } + } + } + else { + for(int r = 0; r < sampleRow; r++) { + int off = r * sampleCol; + for(int c = 0; c < sampleCol; c++) { + nnz += op.fn.execute(dv[off + c], 0) != 0 ? 1 : 0; + } + } + } + } + else { + if(left) { + + for(int r = 0; r < sampleRow; r++) { + double m = m2v[r]; + int off = r * sampleCol; + for(int c = 0; c < sampleCol; c++) { + nnz += op.fn.execute(m, dv[off + c]) != 0 ? 1 : 0; + } + } + } + else { + for(int r = 0; r < sampleRow; r++) { + double m = m2v[r]; + int off = r * sampleCol; + for(int c = 0; c < sampleCol; c++) { + nnz += op.fn.execute(dv[off + c], m) != 0 ? 1 : 0; + } + } + } + } + + return (double) nnz / (sampleRow * sampleCol); + + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCBind.java similarity index 75% rename from src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java rename to src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCBind.java index cedf98494c6..a8d37c93989 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCBind.java @@ -22,26 +22,44 @@ import java.util.ArrayList; import java.util.List; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -public final class CLALibAppend { +public final class CLALibCBind { - private CLALibAppend(){ + private CLALibCBind() { // private constructor. } - private static final Log LOG = LogFactory.getLog(CLALibAppend.class.getName()); + private static final Log LOG = LogFactory.getLog(CLALibCBind.class.getName()); - public static MatrixBlock append(MatrixBlock left, MatrixBlock right, int k) { + public static MatrixBlock cbind(MatrixBlock left, MatrixBlock[] right, int k) { + if(right.length == 1) { + return cbind(left, right[0], k); + } + else { + boolean allCompressed = true; + for(int i = 0; i < right.length && allCompressed; i++) + allCompressed = right[i] instanceof CompressedMatrixBlock; + if(allCompressed) { + return cbindAllCompressed((CompressedMatrixBlock) left, right, k); + } + + } + throw new NotImplementedException(); + } + + public static MatrixBlock cbind(MatrixBlock left, MatrixBlock right, int k) { final int m = left.getNumRows(); final int n = left.getNumColumns() + right.getNumColumns(); @@ -77,6 +95,48 @@ else if(left instanceof CompressedMatrixBlock) return append((CompressedMatrixBlock) left, (CompressedMatrixBlock) right, m, n); } + private static CompressedMatrixBlock cbindAllCompressed(CompressedMatrixBlock left, MatrixBlock[] right, int k) { + boolean allSameColumnGroupIndex = true; + List gl = left.getColGroups(); + final int nCol = left.getNumColumns(); + for(int i = 0; i < right.length && allSameColumnGroupIndex; i++) { + allSameColumnGroupIndex = nCol == right[i].getNumColumns(); + List gr = ((CompressedMatrixBlock) right[i]).getColGroups(); + for(int j = 0; j < gl.size() && allSameColumnGroupIndex; j++) { + allSameColumnGroupIndex = gl.get(i).sameIndexStructure(gr.get(i)); + } + } + + if(allSameColumnGroupIndex) + return cbindAllCompressedAligned(left, right, k); + + throw new NotImplementedException(); + } + + private static CompressedMatrixBlock cbindAllCompressedAligned(CompressedMatrixBlock left, MatrixBlock[] right, + int k) { + + List gl = left.getColGroups(); + List> gr = new ArrayList<>(right.length); + + List rg = new ArrayList<>(gl.size()); + for(int i = 0; i < right.length; i++) { + gr.add(((CompressedMatrixBlock) right[i]).getColGroups()); + } + final int nCol = left.getNumColumns(); + + for(int j = 0; j < gl.size(); j++) { + rg.add(combine((AColGroupCompressed) gl.get(j), j, nCol, gr)); + } + + return new CompressedMatrixBlock(left.getNumRows(), nCol * right.length + nCol, -1, left.isOverlapping(), rg); + + } + + private static AColGroup combine(AColGroupCompressed cg, int index, int nCol, List> right) { + return cg.combineWithSameIndex(index, nCol, right); + } + private static MatrixBlock appendLeftUncompressed(MatrixBlock left, CompressedMatrixBlock right, final int m, final int n) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 32ec9c0f327..f581e61f705 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -22,7 +22,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; @@ -39,10 +41,11 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple; import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding; @@ -51,8 +54,8 @@ import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding; import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Library functions to combine column groups inside a compressed matrix. @@ -64,40 +67,68 @@ private CLALibCombineGroups() { // private constructor } - public static List combine(CompressedMatrixBlock cmb, int k) { - ExecutorService pool = null; - try { - pool = (k > 1) ? CommonThreadPool.get(k) : null; - return combine(cmb, null, pool); - } - catch(Exception e) { - throw new DMLCompressionException("Compression Failed", e); - } - finally { - if(pool != null) - pool.shutdown(); - } + public static List combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) { + if(pool == null) + return combineSingleThread(cmb, csi); + else + return combineParallel(cmb, csi, pool); } - public static List combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) { + private static List combineSingleThread(CompressedMatrixBlock cmb, CompressedSizeInfo csi) { List input = cmb.getColGroups(); - + final int nRow = cmb.getNumRows(); final boolean filterFor = CLALibUtils.shouldFilterFOR(input); double[] c = filterFor ? new double[cmb.getNumColumns()] : null; if(filterFor) input = CLALibUtils.filterFOR(input, c); - List> combinations = new ArrayList<>(); - for(CompressedSizeInfoColGroup gi : csi.getInfo()) - combinations.add(findGroupsInIndex(gi.getColumns(), input)); + final List csiI = csi.getInfo(); + final List ret = new ArrayList<>(csiI.size()); + for(CompressedSizeInfoColGroup gi : csiI) { + List groupsToCombine = findGroupsInIndex(gi.getColumns(), input); + AColGroup combined = combineN(groupsToCombine); + combined = combined.morph(gi.getBestCompressionType(), nRow); + combined = filterFor ? combined.addVector(c) : combined; + ret.add(combined); + } - List ret = new ArrayList<>(); + return ret; + } + + private static List combineParallel(CompressedMatrixBlock cmb, CompressedSizeInfo csi, + ExecutorService pool) { + List input = cmb.getColGroups(); + final int nRow = cmb.getNumRows(); + final boolean filterFor = CLALibUtils.shouldFilterFOR(input); + double[] c = filterFor ? new double[cmb.getNumColumns()] : null; if(filterFor) - for(List combine : combinations) - ret.add(combineN(combine).addVector(c)); - else - for(List combine : combinations) - ret.add(combineN(combine)); + input = CLALibUtils.filterFOR(input, c); + + final List filteredGroups = input; + final List csiI = csi.getInfo(); + final List> tasks = new ArrayList<>(); + for(CompressedSizeInfoColGroup gi : csiI) { + Future fcg = pool.submit(() -> { + List groupsToCombine = findGroupsInIndex(gi.getColumns(), filteredGroups); + AColGroup combined = combineN(groupsToCombine); + combined = combined.morph(gi.getBestCompressionType(), nRow); + combined = filterFor ? combined.addVector(c) : combined; + return combined; + }); + + tasks.add(fcg); + + } + final List ret = new ArrayList<>(csiI.size()); + try { + for(Future fcg : tasks) { + ret.add(fcg.get()); + } + } + catch(InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + return ret; } @@ -111,13 +142,10 @@ public static List findGroupsInIndex(IColIndex idx, List g } public static AColGroup combineN(List groups) { - AColGroup base = groups.get(0); // Inefficient combine N but base line - for(int i = 1; i < groups.size(); i++) { + for(int i = 1; i < groups.size(); i++) base = combine(base, groups.get(i)); - } - return base; } @@ -147,21 +175,37 @@ public static AColGroup combine(AColGroup a, AColGroup b) { if(b instanceof ColGroupUncompressed) b = b.recompress(); - if(a instanceof AColGroupCompressed && b instanceof AColGroupCompressed) + long maxEst = (long) a.getNumValues() * b.getNumValues(); + + if(a instanceof AColGroupCompressed && b instanceof AColGroupCompressed // + && (long) Integer.MAX_VALUE > maxEst) return combineCompressed(combinedColumns, (AColGroupCompressed) a, (AColGroupCompressed) b); - else if(a instanceof ColGroupUncompressed || b instanceof ColGroupUncompressed) - // either side is uncompressed + else return combineUC(combinedColumns, a, b); - - throw new NotImplementedException( - "Not implemented combine for " + a.getClass().getSimpleName() + " - " + b.getClass().getSimpleName()); + } + catch(NotImplementedException e) { + throw e; } catch(Exception e) { StringBuilder sb = new StringBuilder(); sb.append("Failed to combine:\n\n"); - sb.append(a); + String as = a.toString(); + if(as.length() < 10000) + sb.append(as); + else { + sb.append(as.substring(0, 10000)); + sb.append("..."); + } sb.append("\n\n"); - sb.append(b); + + String bs = b.toString(); + if(as.length() < 10000) + sb.append(bs); + else { + sb.append(bs.substring(0, 10000)); + sb.append("..."); + } + throw new DMLCompressionException(sb.toString(), e); } @@ -169,32 +213,37 @@ else if(a instanceof ColGroupUncompressed || b instanceof ColGroupUncompressed) private static AColGroup combineCompressed(IColIndex combinedColumns, AColGroupCompressed ac, AColGroupCompressed bc) { - IEncode ae = ac.getEncoding(); - IEncode be = bc.getEncoding(); + final IEncode ae = ac.getEncoding(); + final IEncode be = bc.getEncoding(); + + // if(ae.equals(be)) + // throw new NotImplementedException("Equivalent encodings combine"); + if(ae instanceof SparseEncoding && !(be instanceof SparseEncoding)) { // the order must be sparse second unless both sparse. return combineCompressed(combinedColumns, bc, ac); } + // add if encodings are equal make shortcut. + final Pair> cec = ae.combineWithMap(be); + final IEncode ce = cec.getLeft(); + final Map filter = cec.getRight(); - Pair> cec = ae.combineWithMap(be); - IEncode ce = cec.getLeft(); - Map filter = cec.getRight(); - if(ce instanceof DenseEncoding) { - DenseEncoding ced = (DenseEncoding) (ce); - IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter); - return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null); - } - else if(ce instanceof EmptyEncoding) { + if(ce instanceof EmptyEncoding) { return new ColGroupEmpty(combinedColumns); } else if(ce instanceof ConstEncoding) { IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter); return ColGroupConst.create(combinedColumns, cd); } + else if(ce instanceof DenseEncoding) { + DenseEncoding ced = (DenseEncoding) (ce); + IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter); + return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null); + } else if(ce instanceof SparseEncoding) { SparseEncoding sed = (SparseEncoding) ce; - IDictionary cd = DictionaryFactory.combineDictionariesSparse(ac, bc); + IDictionary cd = DictionaryFactory.combineDictionariesSparse(ac, bc, filter); double[] defaultTuple = constructDefaultTuple(ac, bc); return ColGroupSDC.create(combinedColumns, sed.getNumRows(), cd, defaultTuple, sed.getOffsets(), sed.getMap(), null); @@ -205,11 +254,53 @@ else if(ce instanceof SparseEncoding) { } - private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColGroup b) { - int nRow = a instanceof ColGroupUncompressed ? // - ((ColGroupUncompressed) a).getData().getNumRows() : // - ((ColGroupUncompressed) b).getData().getNumRows(); - // step 1 decompress both into target uncompressed MatrixBlock; + private static AColGroup combineUC(IColIndex combineColumns, AColGroup a, AColGroup b) { + int nRow = 0; + if(a instanceof ColGroupUncompressed) { + nRow = ((ColGroupUncompressed) a).getData().getNumRows(); + } + else if(b instanceof ColGroupUncompressed) { + nRow = ((ColGroupUncompressed) b).getData().getNumRows(); + } + else if(a instanceof ColGroupDDC) { + nRow = ((ColGroupDDC) a).getMapToData().size(); + } + else if(b instanceof ColGroupDDC) { + nRow = ((ColGroupDDC) b).getMapToData().size(); + } + else + throw new NotImplementedException(); + + return combineUC(combineColumns, a, b, nRow); + } + + private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) { + double sparsityCombined = (a.getSparsity() * a.getNumCols() + b.getSparsity() * b.getNumCols()) / + combinedColumns.size(); + + if(sparsityCombined < 0.4) + return combineUCSparse(combinedColumns, a, b, nRow); + else + return combineUCDense(combinedColumns, a, b, nRow); + } + + private static AColGroup combineUCSparse(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) { + MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), true); + target.allocateBlock(); + + SparseBlock db = target.getSparseBlock(); + + IColIndex aTempCols = ColIndexFactory.getColumnMapping(combinedColumns, a.getColIndices()); + a.copyAndSet(aTempCols).decompressToSparseBlock(db, 0, nRow, 0, 0); + IColIndex bTempCols = ColIndexFactory.getColumnMapping(combinedColumns, b.getColIndices()); + b.copyAndSet(bTempCols).decompressToSparseBlock(db, 0, nRow, 0, 0); + + target.recomputeNonZeros(); + + return ColGroupUncompressed.create(combinedColumns, target, false); + } + + private static AColGroup combineUCDense(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) { MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), false); target.allocateBlock(); DenseBlock db = target.getDenseBlock(); @@ -222,19 +313,37 @@ private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColG target.recomputeNonZeros(); return ColGroupUncompressed.create(combinedColumns, target, false); - } public static double[] constructDefaultTuple(AColGroupCompressed ac, AColGroupCompressed bc) { - double[] ret = new double[ac.getNumCols() + bc.getNumCols()]; - if(ac instanceof IContainDefaultTuple) { - double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple(); - System.arraycopy(defa, 0, ret, 0, defa.length); + final double[] ret = new double[ac.getNumCols() + bc.getNumCols()]; + final IIterate ai = ac.getColIndices().iterator(); + final IIterate bi = bc.getColIndices().iterator(); + final double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple(); + final double[] defb = ((IContainDefaultTuple) bc).getDefaultTuple(); + + int i = 0; + while(ai.hasNext() && bi.hasNext()) { + if(ai.v() < bi.v()) { + ret[i++] = defa[ai.i()]; + ai.next(); + } + else { + ret[i++] = defb[bi.i()]; + bi.next(); + } } - if(bc instanceof IContainDefaultTuple) { - double[] defb = ((IContainDefaultTuple) bc).getDefaultTuple(); - System.arraycopy(defb, 0, ret, ac.getNumCols(), defb.length); + + while(ai.hasNext()) { + ret[i++] = defa[ai.i()]; + ai.next(); + } + + while(bi.hasNext()) { + ret[i++] = defb[bi.i()]; + bi.next(); } + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index 35656f2ea2c..f38c68dee13 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -254,19 +254,27 @@ private static void decompressDenseSingleThread(MatrixBlock ret, List } } - protected static void decompressDenseMultiThread(MatrixBlock ret, List groups, double[] constV, int k, - boolean overlapping) { - final int nRows = ret.getNumRows(); - final double eps = getEps(constV); - final int blklen = Math.max(nRows / k, 512); - decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); - } + // private static void decompressDenseMultiThread(MatrixBlock ret, List groups, double[] constV, int k, + // boolean overlapping) { + // final int nRows = ret.getNumRows(); + // final double eps = getEps(constV); + // final int blklen = Math.max(nRows / k, 512); + // decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + // } protected static void decompressDenseMultiThread(MatrixBlock ret, List groups, double[] constV, double eps, int k, boolean overlapping) { + + Timing time = new Timing(true); final int nRows = ret.getNumRows(); final int blklen = Math.max(nRows / k, 512); decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressTime(t, k); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms."); + } } private static void decompressDenseMultiThread(MatrixBlock ret, List filteredGroups, int rlen, int blklen, diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index d0983d4ae06..ee11483461d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -37,6 +37,7 @@ import org.apache.sysds.runtime.compress.colgroup.APreAgg; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCSR; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; @@ -111,11 +112,52 @@ public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right * @return The result of the matrix multiplication */ public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) { - if(left.isEmpty() || right.isEmpty()) - return prepareEmptyReturnMatrix(right, left, ret, false); - ret = prepareReturnMatrix(right, left, ret, false); - ret = LMM(right.getColGroups(), left, ret, k, right.isOverlapping()); - return ret; + try { + if(left.isEmpty() || right.isEmpty()) + return prepareEmptyReturnMatrix(right, left, ret, false); + + if(isSelectionMatrix(left)) + return ret = CLALibSelectionMult.leftSelection(right, left, ret, k); + + ret = prepareReturnMatrix(right, left, ret, false); + ret = LMM(right.getColGroups(), left, ret, k, right.isOverlapping()); + + return ret; + } + catch(Exception e) { + e.printStackTrace(); + throw new DMLCompressionException("Failed CLA LLM", e); + } + } + + private static boolean isSelectionMatrix(MatrixBlock mb) { + if(mb.getNonZeros() <= mb.getNumRows() && mb.isInSparseFormat()) {// good start. + SparseBlock sb = mb.getSparseBlock(); + for(int i = 0; i < mb.getNumRows(); i++) { + if(sb.isEmpty(i)) + continue; + else if(sb.size(i) != 1) + return false; + else if(!(sb instanceof SparseBlockCSR)) { + double[] values = sb.values(i); + final int spos = sb.pos(i); + final int sEnd = spos + sb.size(i); + for(int j = spos; j < sEnd; j++) { + if(values[j] != 1) { + return false; + } + } + } + } + if(sb instanceof SparseBlockCSR) { + for(double d : sb.values(0)) + if(d != 1) + return false; + } + + return true; + } + return false; } private static MatrixBlock prepareEmptyReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, @@ -188,22 +230,26 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(Compress } try { - final double[] retV = ret.getDenseBlockValues(); if(containsLeft && containsRight) // if both -- multiply the left and right vectors scaling by number of shared dim - outerProductWithScaling(cL, cR, sd, retV); + outerProductWithScaling(cL, cR, sd, ret); if(containsLeft) // if left -- multiply left with right sum - outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); + for(Future f : outerProductParallelTasks(cL, CLALibUtils.getColSum(fRight, cr, sd), ret, ex)) + f.get(); + if(containsRight)// if right -- multiply right with left sum - outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + for(Future f : outerProductParallelTasks(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret, ex)) + f.get(); for(Future f : t) { MatrixBlock mb = f.get(); if(!mb.isEmpty()) { if(mb.isInSparseFormat()) LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); - else if(mb.getDenseBlock().isContiguous()) + else if(mb.getDenseBlock().isContiguous()) { + final double[] retV = ret.getDenseBlockValues(); LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length); + } else LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); } @@ -242,20 +288,20 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(Comp for(int j = 0; j < fLeft.size(); j++) for(int i = 0; i < fRight.size(); i++) fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret, sd); - final double[] retV = ret.getDenseBlockValues(); + if(containsLeft && containsRight) // if both -- multiply the left and right vectors scaling by number of shared dim - outerProductWithScaling(cL, cR, sd, retV); + outerProductWithScaling(cL, cR, sd, ret); if(containsLeft) // if left -- multiply left with right sum - outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); + outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), ret); if(containsRight)// if right -- multiply right with left sum - outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret); ret.recomputeNonZeros(); return ret; } private static MatrixBlock LMM(List colGroups, MatrixBlock that, MatrixBlock ret, int k, - boolean overlapping) { + boolean overlapping) throws Exception { final int numColumnsOut = ret.getNumColumns(); final int lr = that.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups); @@ -286,7 +332,8 @@ private static MatrixBlock LMM(List colGroups, MatrixBlock that, Matr else ret.sparseToDense(); - outerProduct(rowSums, constV, ret.getDenseBlockValues()); + outerProductParallel(rowSums, constV, ret, k); + } } else { @@ -300,7 +347,7 @@ private static MatrixBlock LMM(List colGroups, MatrixBlock that, Matr } ret.recomputeNonZeros(k); - ret.examSparsity(); + ret.examSparsity(k); return ret; } @@ -331,10 +378,8 @@ private static void LMMParallel(List npa, List pa, MatrixBlo else tasks.add(new LMMPreAggTask(pa, that, ret, blo, end, off, s, null, 1)); } - if(pa.isEmpty() && rowSums != null) // row sums task tasks.add(new LMMRowSums(that, blo, end, rowSums)); - } for(Future future : pool.invokeAll(tasks)) @@ -379,7 +424,7 @@ private static void LMMParallel(List npa, List pa, MatrixBlo } private static void LMMTaskExec(List npa, List pa, MatrixBlock that, MatrixBlock ret, int rl, - int ru, double[] rowSums, int k) { + int ru, double[] rowSums, int k) throws Exception { if(npa.isEmpty() && pa.isEmpty()) { rowSum(that, rowSums, rl, ru, 0, that.getNumColumns()); return; @@ -395,17 +440,89 @@ private static void LMMTaskExec(List npa, List pa, MatrixBlo } } - private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, final double[] result) { - for(int row = 0; row < leftRowSum.length; row++) { + private static void outerProductParallel(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, int k) { + ExecutorService pool = CommonThreadPool.get(k); + try { + for(Future t : outerProductParallelTasks(leftRowSum, rightColumnSum, result, pool)) { + t.get(); + } + } + catch(Exception e) { + throw new RuntimeException(); + } + finally { + pool.shutdown(); + } + } + + private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, MatrixBlock result) { + outerProductRange(leftRowSum, rightColumnSum, result, 0, leftRowSum.length, 0, rightColumnSum.length); + } + + private static void outerProductRange(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, int rl, int ru, int cl, int cu) { + if(result.getDenseBlock().isContiguous()) + outerProductRangeContiguous(leftRowSum, rightColumnSum, result.getDenseBlockValues(), rl, ru, cl, cu); + else + outerProductRangeGeneric(leftRowSum, rightColumnSum, result.getDenseBlock(), rl, ru, cl, cu); + } + + private static void outerProductRangeContiguous(final double[] leftRowSum, final double[] rightColumnSum, + final double[] result, int rl, int ru, int cl, int cu) { + for(int row = rl; row < ru; row++) { final int offOut = rightColumnSum.length * row; final double vLeft = leftRowSum[row]; - for(int col = 0; col < rightColumnSum.length; col++) { - result[offOut + col] += vLeft * rightColumnSum[col]; + if(vLeft != 0) { + for(int col = cl; col < cu; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } + } + } + } + + private static void outerProductRangeGeneric(final double[] leftRowSum, final double[] rightColumnSum, + final DenseBlock res, int rl, int ru, int cl, int cu) { + for(int row = rl; row < ru; row++) { + final int offOut = res.pos(row); + final double[] result = res.values(row); + final double vLeft = leftRowSum[row]; + if(vLeft != 0) { + for(int col = cl; col < cu; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } + } + } + } + + private static List> outerProductParallelTasks(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, ExecutorService pool) { + // windows of 1024 each + final int blkz = 1024; + List> tasks = new ArrayList<>(); + for(int row = 0; row < leftRowSum.length; row += blkz) { + final int rl = row; + final int ru = Math.min(leftRowSum.length, row + blkz); + for(int col = 0; col < rightColumnSum.length; col += blkz) { + final int cl = col; + final int cu = Math.min(rightColumnSum.length, col + blkz); + tasks.add(pool.submit(() -> { + outerProductRange(leftRowSum, rightColumnSum, result, rl, ru, cl, cu); + })); } } + return tasks; } private static void outerProductWithScaling(final double[] leftRowSum, final double[] rightColumnSum, + final int scaling, final MatrixBlock result) { + if(result.getDenseBlock().isContiguous()) + outerProductWithScalingContiguous(leftRowSum, rightColumnSum, scaling, result.getDenseBlockValues()); + else + outerProductWithScalingGeneric(leftRowSum, rightColumnSum, scaling, result.getDenseBlock()); + } + + private static void outerProductWithScalingContiguous(final double[] leftRowSum, final double[] rightColumnSum, final int scaling, final double[] result) { for(int row = 0; row < leftRowSum.length; row++) { final int offOut = rightColumnSum.length * row; @@ -416,12 +533,24 @@ private static void outerProductWithScaling(final double[] leftRowSum, final dou } } + private static void outerProductWithScalingGeneric(final double[] leftRowSum, final double[] rightColumnSum, + final int scaling, final DenseBlock res) { + for(int row = 0; row < leftRowSum.length; row++) { + final int offOut = res.pos(row); + final double[] result = res.values(row); + final double vLeft = leftRowSum[row] * scaling; + for(int col = 0; col < rightColumnSum.length; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } + } + } + private static void LMMNoPreAgg(AColGroup g, MatrixBlock that, MatrixBlock ret, int rl, int ru) { g.leftMultByMatrixNoPreAgg(that, ret, rl, ru, 0, that.getNumColumns()); } private static void LMMWithPreAgg(List preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, - int off, int skip, double[] rowSums, int k) { + int off, int skip, double[] rowSums, int k) throws Exception { if(!that.isInSparseFormat()) LMMWithPreAggDense(preAggCGs, that, ret, rl, ru, off, skip, rowSums); else @@ -429,31 +558,38 @@ private static void LMMWithPreAgg(List preAggCGs, MatrixBlock that, Mat } private static void LMMWithPreAggSparse(List preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, - int off, int skip, double[] rowSum) { + int off, int skip, double[] rowSum) throws Exception { // row multiplication - final MatrixBlock tmpRes = new MatrixBlock(1, ret.getNumColumns(), false); - final int maxV = preAggCGs.get(off).getNumValues(); - final MatrixBlock preA = new MatrixBlock(1, maxV, false); - // final DenseBlock db = preA.getDenseBlock(); - preA.allocateDenseBlock(); - final double[] preAV = preA.getDenseBlockValues(); - tmpRes.allocateDenseBlock(); + // allocate the preAggregate on demand; + MatrixBlock preA = null; + MatrixBlock fTmp = null; final SparseBlock sb = that.getSparseBlock(); + final int nGroupsToMultiply = preAggCGs.size() / skip; - for(int j = off; j < preAggCGs.size(); j += skip) { + for(int j = off; j < preAggCGs.size(); j += skip) { // selected column groups for this thread. + final int nCol = preAggCGs.get(j).getNumCols(); + final int nVal = preAggCGs.get(j).getNumValues(); for(int r = rl; r < ru; r++) { if(sb.isEmpty(r)) continue; final int rcu = r + 1; - final int nCol = preAggCGs.get(j).getNumCols(); - final int nVal = preAggCGs.get(j).getNumValues(); - if(nCol == 1 || (sb.size(r) * nCol < sb.size(r) + nCol * nVal)) + if(nCol == 1 || (sb.size(r) * nCol < sb.size(r) + (long) nCol * nVal) || nGroupsToMultiply <= 1) LMMNoPreAgg(preAggCGs.get(j), that, ret, r, rcu); else { + if(preA == null) { + // only allocate if we use this path. + // also only allocate the size needed. + preA = new MatrixBlock(1, nVal, false); + fTmp = new MatrixBlock(1, ret.getNumColumns(), false); + fTmp.allocateDenseBlock(); + preA.allocateDenseBlock(); + } + + final double[] preAV = preA.getDenseBlockValues(); final APreAgg g = preAggCGs.get(j); preA.reset(1, g.getPreAggregateSize(), false); g.preAggregateSparse(sb, preAV, r, rcu); - g.mmWithDictionary(preA, tmpRes, ret, 1, r, rcu); + g.mmWithDictionary(preA, fTmp, ret, 1, r, rcu); } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java new file mode 100644 index 00000000000..b497ac9c474 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.List; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; + +public class CLALibReorg { + + protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName()); + + public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow, + int startColumn, int length) { + // SwapIndex is transpose + if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) { + MatrixBlock tmp = cmb.decompress(op.getNumThreads()); + long nz = tmp.setNonZeros(tmp.getNonZeros()); + tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); + tmp.setNonZeros(nz); + return tmp; + } + else if(op.fn instanceof SwapIndex) { + if(cmb.getCachedDecompressed() != null) + return cmb.getCachedDecompressed().reorgOperations(op, ret, startRow, startColumn, length); + + return transpose(cmb, ret, op.getNumThreads()); + } + else { + // Allow transpose to be compressed output. In general we need to have a transposed flag on + // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 + String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName(); + MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads()); + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + } + } + + private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + + final long nnz = cmb.getNonZeros(); + final int nRow = cmb.getNumRows(); + final int nCol = cmb.getNumColumns(); + final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nRow, nCol, nnz); + if(sparseOut) + return transposeSparse(cmb, ret, k); + else + return transposeDense(cmb, ret, k, nRow, nCol, nnz); + } + + private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + throw new NotImplementedException(); + } + + private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, + long nnz) { + if(ret == null) + ret = new MatrixBlock(nCol, nRow, false, nnz); + else + ret.reset(nCol, nRow, false, nnz); + + ret.allocateDenseBlock(); + + decompressToTransposedDense(ret, cmb.getColGroups(), nRow, 0, nRow); + return ret; + } + + private static void decompressToTransposedDense(MatrixBlock ret, List groups, int rlen, int rl, int ru) { + for(int i = 0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + g.decompressToDenseBlockTransposed(ret.getDenseBlock(), rl, ru); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index 3dea7f577a9..5b9ce91d9fd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -57,6 +57,7 @@ private CLALibScalar() { } public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) { + // Timing time = new Timing(true); if(isInvalidForCompressedOutput(m1, sop)) { LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName()); MatrixBlock m1d = m1.decompress(sop.getNumThreads()); @@ -78,7 +79,7 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB int threadsAvailable = (sop.getNumThreads() > 1) ? sop.getNumThreads() : OptimizerUtils .getConstrainedNumThreads(-1); if(threadsAvailable > 1) - parallelScalarOperations(sop, colGroups, ret, threadsAvailable); + parallelScalarOperations(sop, colGroups, ret, threadsAvailable ); else { // Apply the operation to each of the column groups. // Most implementations will only modify metadata. @@ -90,8 +91,15 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB ret.setOverlapping(m1.isOverlapping()); } - ret.recomputeNonZeros(); + if(sop.fn instanceof Divide){ + ret.setNonZeros(m1.getNonZeros()); + } + else{ + ret.recomputeNonZeros(); + } + // System.out.println("CLA Scalar: " + sop + " " + m1.getNumRows() + ", " + m1.getNumColumns() + ", " + m1.getColGroups().size() + // + " -- " + "\t\t" + time.stop()); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java new file mode 100644 index 00000000000..94a585e993e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.List; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +/** + * This lib is responsible for selecting and extracting specific rows or columns from a compressed matrix. + * + * The operation performed is like a left matrix multiplication where the left side only have max 1 non zero per row. + * + */ +public class CLALibSelectionMult { + protected static final Log LOG = LogFactory.getLog(CLALibSelectionMult.class.getName()); + + /** + * Left selection where the left matrix is sparse with a max 1 non zero per row and that non zero is a 1. + * + * @param right Right hand side compressed matrix + * @param left Left hand side matrix + * @param ret Output matrix to put the result into. + * @param k The parallelization degree. + * @return The selected rows and columns of the input matrix + */ + public static MatrixBlock leftSelection(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) { + if(right.getNonZeros() <= -1) + right.recomputeNonZeros(); + + boolean sparseOut = right.getSparsity() < 0.3; + ret.reset(left.getNumRows(), right.getNumColumns(), sparseOut); + ret.allocateBlock(); + final List preFilter = right.getColGroups(); + final boolean shouldFilter = CLALibUtils.shouldPreFilter(preFilter); + if(shouldFilter) { + final double[] constV = new double[ret.getNumColumns()]; + // final List noPreAggGroups = new ArrayList<>(); + // final List preAggGroups = new ArrayList<>(); + final List morphed = CLALibUtils.filterGroups(preFilter, constV); + + if(sparseOut) { + leftSparseSelection(morphed, left, ret, k); + double[] rowSums = left.rowSum(k).getDenseBlockValues(); + outerProductSparse(rowSums, constV, ret); + } + else { + leftDenseSelection(morphed, left, ret, k); + } + + } + else { + if(sparseOut) + leftSparseSelection(preFilter, left, ret, k); + else + leftDenseSelection(preFilter, left, ret, k); + } + + ret.recomputeNonZeros(k); + return ret; + } + + private static void leftSparseSelection(List right, MatrixBlock left, MatrixBlock ret, int k) { + for(AColGroup g : right) + g.sparseSelection(left, ret, 0, left.getNumRows()); + left.getSparseBlock().sort(); + } + + private static void leftDenseSelection(List right, MatrixBlock left, MatrixBlock ret, int k) { + throw new NotImplementedException(); + } + + private static void outerProductSparse(double[] rows, double[] cols, MatrixBlock ret) { + SparseBlock sb = ret.getSparseBlock(); + + IntArrayList skipCols = new IntArrayList(); + for(int c = 0; c < cols.length; c++) + if(cols[c] != 0) + skipCols.appendValue(c); + + final int skipSz = skipCols.size(); + if(skipSz == 0) + return; + + final int[] skipC = skipCols.extractValues(); + for(int r = 0; r < rows.length; r++) { + final double rv = rows[r]; + if(rv != 0) { + for(int ci = 0; ci < skipSz; ci++) { + final int c = skipC[ci]; + sb.add(r, c, rv * cols[c]); + } + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java index 1c9ef3082cb..00fa67f6b6e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java @@ -113,7 +113,7 @@ public void appendValue(double key, int value) { } else { for(DIListEntry e = _data[ix]; e != null; e = e.next) { - if(e.key == key) { + if(Util.eq(e.key , key)) { IntArrayList lstPtr = e.value; lstPtr.appendValue(value); break; diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/Util.java b/src/main/java/org/apache/sysds/runtime/compress/utils/Util.java index 024ed71a862..7fc76cbeacf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/Util.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/Util.java @@ -103,6 +103,13 @@ public static MatrixBlock extractValues(double[] v, IColIndex colIndexes) { return rowVector; } + /** + * Nan Enabled equals operator returns true on Nan == Nan. + * + * @param a value 1 + * @param b value 2 + * @return if they are equal on the bit level. + */ public static boolean eq(double a, double b) { long al = Double.doubleToRawLongBits(a); long bl = Double.doubleToRawLongBits(b); diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java index a4c15b2b533..24b14c152fc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.Stack; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -35,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.OpOp3; import org.apache.sysds.common.Types.OpOpData; +import org.apache.sysds.common.Types.ParamBuiltinOp; import org.apache.sysds.common.Types.ReOrgOp; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; @@ -78,6 +80,7 @@ public class WorkloadAnalyzer { private final Set overlapping; private final DMLProgram prog; private final Map treeLookup; + private final Stack stack; public static Map getAllCandidateWorkloads(DMLProgram prog) { // extract all compression candidates from program (in program order) @@ -94,7 +97,7 @@ public static Map getAllCandidateWorkloads(DMLProgram prog) { // construct workload tree for candidate WorkloadAnalyzer wa = new WorkloadAnalyzer(prog); - WTreeRoot tree = wa.createWorkloadTree(cand); + WTreeRoot tree = wa.createWorkloadTreeRoot(cand); map.put(cand.getHopID(), tree); allWAs.add(wa); @@ -111,6 +114,7 @@ private WorkloadAnalyzer(DMLProgram prog) { this.transientCompressed = new HashMap<>(); this.overlapping = new HashSet<>(); this.treeLookup = new HashMap<>(); + this.stack = new Stack<>(); } private WorkloadAnalyzer(DMLProgram prog, Set compressed, HashMap transientCompressed, @@ -122,13 +126,20 @@ private WorkloadAnalyzer(DMLProgram prog, Set compressed, HashMap(); } - private WTreeRoot createWorkloadTree(Hop candidate) { + private WTreeRoot createWorkloadTreeRoot(Hop candidate) { WTreeRoot main = new WTreeRoot(candidate); compressed.add(candidate.getHopID()); + if(HopRewriteUtils.isTransformEncode(candidate)) { + Hop matrix = ((FunctionOp) candidate).getOutputs().get(0); + compressed.add(matrix.getHopID()); + transientCompressed.put(matrix.getName(), matrix.getHopID()); + } for(StatementBlock sb : prog.getStatementBlocks()) - createWorkloadTree(main, sb, prog, new HashSet<>()); + createWorkloadTreeNodes(main, sb, prog, new HashSet<>()); + pruneWorkloadTree(main); return main; } @@ -222,14 +233,14 @@ private static void getCandidates(Hop hop, DMLProgram prog, List cands, Set hop.setVisited(); } - private void createWorkloadTree(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set fStack) { + private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set fStack) { WTreeNode node; if(sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); node = new WTreeNode(WTNodeType.FCALL, 1); for(StatementBlock csb : fstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; @@ -238,7 +249,7 @@ else if(sb instanceof WhileStatementBlock) { createWorkloadTree(wsb.getPredicateHops(), prog, node, fStack); for(StatementBlock csb : wstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; @@ -247,9 +258,9 @@ else if(sb instanceof IfStatementBlock) { createWorkloadTree(isb.getPredicateHops(), prog, node, fStack); for(StatementBlock csb : istmt.getIfBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); for(StatementBlock csb : istmt.getElseBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof ForStatementBlock) { // incl parfor ForStatementBlock fsb = (ForStatementBlock) sb; @@ -260,7 +271,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor createWorkloadTree(fsb.getToHops(), prog, node, fStack); createWorkloadTree(fsb.getIncrementHops(), prog, node, fStack); for(StatementBlock csb : fstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else { // generic (last-level) @@ -269,14 +280,19 @@ else if(sb instanceof ForStatementBlock) { // incl parfor if(hops != null) { // process hop DAG to collect operations that are compressed. - for(Hop hop : hops) + for(Hop hop : hops) { createWorkloadTree(hop, prog, n, fStack); + // createStack(hop); + // processStack(prog, n, fStack); + } // maintain hop DAG outputs (compressed or not compressed) for(Hop hop : hops) { if(hop instanceof FunctionOp) { FunctionOp fop = (FunctionOp) hop; - if(!fStack.contains(fop.getFunctionKey())) { + if(HopRewriteUtils.isTransformEncode(fop)) + return; + else if(!fStack.contains(fop.getFunctionKey())) { fStack.add(fop.getFunctionKey()); FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey()); if(fsb == null) @@ -295,7 +311,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor WorkloadAnalyzer fa = new WorkloadAnalyzer(prog, compressed, fCompressed, transposed, overlapping, treeLookup); - fa.createWorkloadTree(n, fsb, prog, fStack); + fa.createWorkloadTreeNodes(n, fsb, prog, fStack); String[] outs = fop.getOutputVariableNames(); for(int i = 0; i < outs.length; i++) { Long id = fCompressed.get(outs[i]); @@ -305,7 +321,6 @@ else if(sb instanceof ForStatementBlock) { // incl parfor fStack.remove(fop.getFunctionKey()); } } - } } return; @@ -313,27 +328,42 @@ else if(sb instanceof ForStatementBlock) { // incl parfor n.addChild(node); } - private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode parent, Set fStack) { + private void createStack(Hop hop) { if(hop == null || visited.contains(hop) || isNoOp(hop)) return; - - // DFS: recursively process children (inputs first for propagation of compression status) + stack.add(hop); for(Hop c : hop.getInput()) - createWorkloadTree(c, prog, parent, fStack); + createStack(c); - // map statement block propagation to hop propagation - if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD) && - transientCompressed.containsKey(hop.getName())) { - compressed.add(hop.getHopID()); - treeLookup.put(hop.getHopID(), treeLookup.get(transientCompressed.get(hop.getName()))); - } + visited.add(hop); + } - // collect operations on compressed intermediates or inputs - // if any input is compressed we collect this hop as a compressed operation - if(hop.getInput().stream().anyMatch(h -> compressed.contains(h.getHopID()))) - createOp(hop, parent); + private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode parent, Set fStack) { + createStack(hop); + processStack(prog, parent, fStack); + } + + private void processStack(DMLProgram prog, AWTreeNode parent, Set fStack) { + + while(!stack.isEmpty()) { + Hop hop = stack.pop(); + + // map statement block propagation to hop propagation + if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD) && + transientCompressed.containsKey(hop.getName())) { + compressed.add(hop.getHopID()); + treeLookup.put(hop.getHopID(), treeLookup.get(transientCompressed.get(hop.getName()))); + } + else { + + // collect operations on compressed intermediates or inputs + // if any input is compressed we collect this hop as a compressed operation + if(hop.getInput().stream().anyMatch(h -> compressed.contains(h.getHopID()))) + createOp(hop, parent); + + } + } - visited.add(hop); } private void createOp(Hop hop, AWTreeNode parent) { @@ -369,11 +399,16 @@ else if(hop instanceof AggUnaryOp) { o = new OpNormal(hop, false); } } - else if(hop instanceof UnaryOp && - !HopRewriteUtils.isUnary(hop, OpOp1.MULT2, OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) { - if(isOverlapping(hop.getInput(0))) { - treeLookup.get(hop.getInput(0).getHopID()).setDecompressing(); - return; + else if(hop instanceof UnaryOp) { + if(!HopRewriteUtils.isUnary(hop, OpOp1.MULT2, OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) { + if(isOverlapping(hop.getInput(0))) { + treeLookup.get(hop.getInput(0).getHopID()).setDecompressing(); + return; + } + + } + else if(HopRewriteUtils.isUnary(hop, OpOp1.DETECTSCHEMA)) { + o = new OpNormal(hop, false); } } else if(hop instanceof AggBinaryOp) { @@ -411,6 +446,9 @@ else if(HopRewriteUtils.isBinary(hop, OpOp2.RBIND)) { setDecompressionOnAllInputs(hop, parent); return; } + else if(HopRewriteUtils.isBinary(hop, OpOp2.APPLY_SCHEMA)) { + o = new OpNormal(hop, true); + } else { ArrayList in = hop.getInput(); final boolean ol0 = isOverlapping(in.get(0)); @@ -461,22 +499,11 @@ else if(ol0 || ol1) { } else if(hop instanceof IndexingOp) { - IndexingOp idx = (IndexingOp) hop; final boolean isOverlapping = isOverlapping(hop.getInput(0)); - final boolean fullColumn = HopRewriteUtils.isFullColumnIndexing(idx); - - if(fullColumn) { - o = new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)); - if(isOverlapping) { - overlapping.add(hop.getHopID()); - o.setOverlapping(); - } - } - else { - // This decompression is a little different, since it does not decompress the entire matrix - // but only a sub part. therefore create a new op node and set it to decompressing. - o = new OpNormal(hop, false); - o.setDecompressing(); + o = new OpNormal(hop, true); + if(isOverlapping) { + overlapping.add(hop.getHopID()); + o.setOverlapping(); } } else if(HopRewriteUtils.isTernary(hop, OpOp3.MINUS_MULT, OpOp3.PLUS_MULT, OpOp3.QUANTILE, OpOp3.CTABLE)) { @@ -505,7 +532,17 @@ else if(isCompressed(o2)) { setDecompressionOnAllInputs(hop, parent); } } - else if(hop instanceof ParameterizedBuiltinOp || hop instanceof NaryOp) { + else if(hop instanceof ParameterizedBuiltinOp) { + if(HopRewriteUtils.isParameterBuiltinOp(hop, ParamBuiltinOp.REPLACE, ParamBuiltinOp.TRANSFORMAPPLY)) { + o = new OpNormal(hop, true); + } + else { + LOG.warn("Unknown Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); + setDecompressionOnAllInputs(hop, parent); + return; + } + } + else if(hop instanceof NaryOp) { setDecompressionOnAllInputs(hop, parent); return; } @@ -522,7 +559,50 @@ else if(hop instanceof ParameterizedBuiltinOp || hop instanceof NaryOp) { if(o.isCompressedOutput()) compressed.add(hop.getHopID()); } + else if(hop.getDataType().isFrame()) { + Op o = null; + if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD)) + return; + else if(HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE, OpOpData.PERSISTENTWRITE)) { + transientCompressed.put(hop.getName(), hop.getInput(0).getHopID()); + compressed.add(hop.getHopID()); + o = new OpMetadata(hop, hop.getInput(0)); + if(isOverlapping(hop.getInput(0))) + o.setOverlapping(); + } + else if(HopRewriteUtils.isUnary(hop, OpOp1.DETECTSCHEMA)) { + o = new OpNormal(hop, false); + } + else if(HopRewriteUtils.isBinary(hop, OpOp2.APPLY_SCHEMA)) { + o = new OpNormal(hop, true); + } + else if(hop instanceof AggUnaryOp) { + o = new OpNormal(hop, false); + } + else { + LOG.warn("Unknown Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); + setDecompressionOnAllInputs(hop, parent); + return; + } + + o = o != null ? o : new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)); + treeLookup.put(hop.getHopID(), o); + parent.addOp(o); + if(o.isCompressedOutput()) + compressed.add(hop.getHopID()); + } + else if(HopRewriteUtils.isTransformEncode(hop)) { + Hop matrix = ((FunctionOp) hop).getOutputs().get(0); + compressed.add(matrix.getHopID()); + transientCompressed.put(matrix.getName(), matrix.getHopID()); + parent.addOp(new OpNormal(hop, true)); + } + else if(hop instanceof FunctionOp && ((FunctionOp) hop).getFunctionNamespace().equals(".builtinNS")) { + parent.addOp(new OpNormal(hop, false)); + } else { + LOG.warn( + "Unknown Hop:" + hop.getClass().getSimpleName() + "\n" + hop.getDataType() + "\n" + Explain.explain(hop)); parent.addOp(new OpNormal(hop, false)); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 06a548a7539..2702013c3ee 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -635,7 +635,7 @@ public void execute(ExecutionContext ec) if( _monitor ) StatisticMonitor.putPFStat(_ID, Stat.PARFOR_INIT_DATA_T, time.stop()); - // initialize iter var to form value + // initialize iter var to from value IntObject iterVar = new IntObject(from.getLongValue()); /////// diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 3efafbb30bc..312f88ca7db 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -106,6 +106,7 @@ public class FrameBlock implements CacheBlock, Externalizable { /** Locks on the columns not tied to the columns objects. */ private SoftReference _columnLocks = null; + /** Materialized number of rows in this FrameBlock */ private int _nRow = 0; /** Cached size in memory to avoid repeated scans of string columns */ @@ -756,7 +757,8 @@ else if(column != null && column.size() != _nRow) public void write(DataOutput out) throws IOException { final boolean isDefaultMeta = isColNamesDefault() && isColumnMetadataDefault(); // write header (rows, cols, default) - out.writeInt(getNumRows()); + final int nRow = getNumRows(); + out.writeInt(nRow); out.writeInt(getNumColumns()); out.writeBoolean(isDefaultMeta); // write columns (value type, data) @@ -767,7 +769,7 @@ public void write(DataOutput out) throws IOException { out.writeUTF(getColumnName(j)); _colmeta[j].write(out); } - if(type >= 0) // if allocated write column data + if(type >= 0 && nRow > 0) // if allocated write column data _coldata[j].write(out); } } @@ -796,6 +798,8 @@ public void readFields(DataInput in) throws IOException { isDefaultMeta ? null : new String[numCols]; // if meta is default allocate on demand _colmeta = (_colmeta != null && _colmeta.length == numCols) ? _colmeta : new ColumnMetadata[numCols]; _coldata = (_coldata != null && _coldata.length == numCols) ? _coldata : new Array[numCols]; + if(_nRow == 0) + _coldata = null; // read columns (value type, meta, data) for(int j = 0; j < numCols; j++) { byte type = in.readByte(); @@ -807,7 +811,7 @@ public void readFields(DataInput in) throws IOException { else _colmeta[j] = new ColumnMetadata(); // must be allocated. - if(type >= 0) // if in allocated column data then read it + if(type >= 0 && _nRow > 0) // if in allocated column data then read it _coldata[j] = ArrayFactory.read(in, _nRow); } _msize = -1; @@ -815,30 +819,12 @@ public void readFields(DataInput in) throws IOException { @Override public void writeExternal(ObjectOutput out) throws IOException { - - // if((out instanceof ObjectOutputStream)){ - // ObjectOutputStream oos = (ObjectOutputStream)out; - // FastBufferedDataOutputStream fos = new FastBufferedDataOutputStream(oos); - // write(fos); //note: cannot close fos as this would close oos - // fos.flush(); - // } - // else{ - write(out); - // } + write(out); } @Override public void readExternal(ObjectInput in) throws IOException { - // if(in instanceof ObjectInputStream) { - // // fast deserialize of dense/sparse blocks - // ObjectInputStream ois = (ObjectInputStream) in; - // FastBufferedDataInputStream fis = new FastBufferedDataInputStream(ois); - // readFields(fis); // note: cannot close fos as this would close oos - // } - // else { - // redirect deserialization to writable impl - readFields(in); - // } + readFields(in); } @Override @@ -878,7 +864,7 @@ private double arraysSizeInMemory() { for(int j = 0; j < clen; j++) size += ArrayFactory.getInMemorySize(_schema[j], rlen, true); else {// allocated - if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) { + if((rlen > 1000 || clen > 10 )&& ConfigurationManager.isParallelIOEnabled()) { final ExecutorService pool = CommonThreadPool.get(); try { List> f = new ArrayList<>(clen); @@ -893,6 +879,7 @@ private double arraysSizeInMemory() { } catch(InterruptedException | ExecutionException e) { LOG.error(e); + size = 0; for(Array aa : _coldata) size += aa.getInMemorySize(); } @@ -937,10 +924,10 @@ public boolean isShallowSerialize() { public boolean isShallowSerialize(boolean inclConvert) { // shallow serialize if non-string schema because a frame block // is always dense but strings have large array overhead per cell - boolean ret = true; - for(int j = 0; j < _schema.length && ret; j++) - ret &= _coldata[j].isShallowSerialize(); - return ret; + for(int j = 0; j < _schema.length; j++) + if(!_coldata[j].isShallowSerialize()) + return false; + return true; } @Override @@ -1217,6 +1204,22 @@ public void copy(FrameBlock src) { _msize = -1; } + public FrameBlock copyShallow(){ + FrameBlock ret = new FrameBlock(); + ret._nRow = _nRow; + ret._msize = _msize; + final int nCol = getNumColumns(); + if(_coldata != null) + ret._coldata = Arrays.copyOf(_coldata, nCol); + if(_colnames != null) + ret._colnames = Arrays.copyOf(_colnames, nCol); + if(_colmeta != null) + ret._colmeta = Arrays.copyOf(_colmeta, nCol); + if(_schema != null) + ret._schema = Arrays.copyOf(_schema, nCol); + return ret; + } + /** * Copy src matrix into the index range of the existing current matrix. * @@ -1358,7 +1361,7 @@ public FrameBlock getSchemaTypeOf() { } public final FrameBlock detectSchema(int k) { - return FrameLibDetectSchema.detectSchema(this, k); + return FrameLibDetectSchema.detectSchema(this, 0.01, k); } public final FrameBlock detectSchema(double sampleFraction, int k) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index 6d2f28d3dd4..cb8867e9473 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -43,10 +43,34 @@ public ABooleanArray(int size) { public abstract ABooleanArray select(boolean[] select, int nTrue); @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return false; } + public void setNullsFromString(int rl, int ru, Array value) { + + final int remainder = rl % 64; + if(remainder == 0) { + final int ru64 = (ru / 64) * 64; + for(int i = rl; i < ru64; i++) { + unsafeSet(i, value.get(i) != null); + } + for(int i = ru64 ; i <= ru ; i++) { + set(i, value.get(i) != null); + } + } + else { + for(int i = rl; i <= ru; i++) { + set(i, value.get(i) != null); + } + } + + } + + protected void unsafeSet(int index, boolean value) { + set(index, value); + } + @Override protected Map createRecodeMap() { Map map = new HashMap<>(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java index a04fae7a2bc..ed774b275d6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.frame.data.columns; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; @@ -121,4 +122,51 @@ public ArrayCompressionStatistics statistics(int nSamples) { return null; } + @Override + public abstract Array changeType(ValueType t); + + @Override + protected Array changeTypeBitSet(Array ret, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeBoolean(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeDouble(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeFloat(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeInteger(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeLong(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeString(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeCharacter(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } + + @Override + protected Array changeTypeHash64(Array retA, int l, int u) { + throw new DMLCompressionException("Invalid to change sub compressed array"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index d2021872ba1..f600fc1a2d4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -444,53 +444,50 @@ public boolean containsNull() { public abstract boolean possiblyContainsNaN(); - public Array safeChangeType(ValueType t, boolean containsNull){ - try{ - return changeType(t, containsNull); - } - catch(Exception e){ - Pair ct = analyzeValueType(); // full analysis - return changeType(ct.getKey(), ct.getValue()); - } - } + // public Array safeChangeType(ValueType t, boolean containsNull) { + // try { + // return changeType(t, containsNull); + // } + // catch(Exception e) { + // Pair ct = analyzeValueType(); // full analysis + // return changeType(ct.getKey(), ct.getValue()); + // } + // } public Array changeType(ValueType t, boolean containsNull) { return containsNull ? changeTypeWithNulls(t) : changeType(t); } public Array changeTypeWithNulls(ValueType t) { - + if(t == getValueType()) + return this; final ABooleanArray nulls = getNulls(); - if(nulls == null) + + if(nulls == null || t == ValueType.STRING) // String can contain null. return changeType(t); + return changeTypeWithNulls(ArrayFactory.allocateOptional(t, size())); + } - switch(t) { - case BOOLEAN: - if(size() > ArrayFactory.bitSetSwitchPoint) - return new OptionalArray<>(changeTypeBitSet(), nulls); - else - return new OptionalArray<>(changeTypeBoolean(), nulls); - case FP32: - return new OptionalArray<>(changeTypeFloat(), nulls); - case FP64: - return new OptionalArray<>(changeTypeDouble(), nulls); - case UINT4: - case UINT8: - throw new NotImplementedException(); - case HASH64: - return new OptionalArray<>(changeTypeHash64(), nulls); - case INT32: - return new OptionalArray<>(changeTypeInteger(), nulls); - case INT64: - return new OptionalArray<>(changeTypeLong(), nulls); - case CHARACTER: - return new OptionalArray<>(changeTypeCharacter(), nulls); - case STRING: - case UNKNOWN: - default: - return changeTypeString(); // String can contain null - } + public Array changeTypeWithNulls(Array ret) { + return changeTypeWithNulls((OptionalArray) ret, 0, ret.size()); + } + + public Array changeTypeWithNulls(Array ret, int l, int u) { + if(ret instanceof OptionalArray) + return changeTypeWithNulls((OptionalArray) ret, l, u); + else + return changeType(ret, l, u); + } + @SuppressWarnings("unchecked") + private OptionalArray changeTypeWithNulls(OptionalArray ret, int l, int u) { + if(this.getValueType() == ValueType.STRING) + ret._n.setNullsFromString(l, u - 1, (Array) this); + else + ret._n.set(l, u - 1, getNulls()); + + changeType(ret._a, l, u); + return ret; } /** @@ -499,98 +496,310 @@ public Array changeTypeWithNulls(ValueType t) { * @param t The type to change to * @return A new column array. */ - public final Array changeType(ValueType t) { - switch(t) { + public Array changeType(ValueType t) { + if(t == getValueType()) + return this; + else + return changeType(ArrayFactory.allocate(t, size())); + } + + public final Array changeType(Array ret) { + return changeType(ret, 0, ret.size()); + } + + @SuppressWarnings("unchecked") + public final Array changeType(Array ret, int rl, int ru) { + switch(ret.getValueType()) { case BOOLEAN: - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBitSet(); + if(ret instanceof BitSetArray || // + (ret instanceof OptionalArray && ((OptionalArray) ret)._a instanceof BitSetArray)) + return changeTypeBitSet((Array) ret, rl, ru); else - return changeTypeBoolean(); + return changeTypeBoolean((Array) ret, rl, ru); case FP32: - return changeTypeFloat(); + return changeTypeFloat((Array) ret, rl, ru); case FP64: - return changeTypeDouble(); + return changeTypeDouble((Array) ret, rl, ru); case UINT4: case UINT8: throw new NotImplementedException(); case HASH64: - return changeTypeHash64(); + return changeTypeHash64((Array) ret, rl, ru); case INT32: - return changeTypeInteger(); + return changeTypeInteger((Array) ret, rl, ru); case INT64: - return changeTypeLong(); - case STRING: - return changeTypeString(); + return changeTypeLong((Array) ret, rl, ru); case CHARACTER: - return changeTypeCharacter(); + return changeTypeCharacter((Array) ret, rl, ru); case UNKNOWN: + case STRING: default: - return changeTypeString(); + return changeTypeString((Array) ret, rl, ru); } } /** - * Change type to a bitSet, of underlying longs to store the individual values + * Change type to a bitSet, of underlying longs to store the individual values. + * + * This method should be overwritten by subclasses if no change is needed. * * @return A Boolean type of array */ - protected abstract Array changeTypeBitSet(); + protected Array changeTypeBitSet() { + return changeTypeBitSet(new BitSetArray(size())); + } + + /** + * Change type to a bitSet, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Boolean type of array that is pointing the ret argument + */ + protected final Array changeTypeBitSet(Array ret) { + return changeTypeBitSet(ret, 0, size()); + } + + /** + * Change type to a bitSet, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Boolean type of array that is pointing the ret argument + */ + protected abstract Array changeTypeBitSet(Array ret, int l, int u); /** * Change type to a boolean array * * @returnA Boolean type of array */ - protected abstract Array changeTypeBoolean(); + protected Array changeTypeBoolean() { + return changeTypeBoolean(new BooleanArray(new boolean[size()])); + } + + /** + * Change type to a boolean array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Boolean type of array that is pointing the ret argument + */ + protected Array changeTypeBoolean(Array ret) { + return changeTypeBoolean(ret, 0, size()); + } + + /** + * Change type to a boolean array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Boolean type of array that is pointing the ret argument + */ + protected abstract Array changeTypeBoolean(Array ret, int l, int u); /** * Change type to a Double array type * * @return Double type of array */ - protected abstract Array changeTypeDouble(); + protected Array changeTypeDouble() { + return changeTypeDouble(new DoubleArray(new double[size()])); + } + + /** + * Change type to a Double array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Double type of array that is pointing the ret argument + */ + protected Array changeTypeDouble(Array ret) { + return changeTypeDouble(ret, 0, size()); + } + + /** + * Change type to a Double array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Double type of array that is pointing the ret argument + */ + protected abstract Array changeTypeDouble(Array ret, int l, int u); /** * Change type to a Float array type * * @return Float type of array */ - protected abstract Array changeTypeFloat(); + protected Array changeTypeFloat() { + return changeTypeFloat(new FloatArray(new float[size()])); + } + + /** + * Change type to a Float array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Float type of array that is pointing the ret argument + */ + protected Array changeTypeFloat(Array ret) { + return changeTypeFloat(ret, 0, size()); + } + + /** + * Change type to a Float array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Float type of array that is pointing the ret argument + */ + protected abstract Array changeTypeFloat(Array ret, int l, int u); /** * Change type to a Integer array type * * @return Integer type of array */ - protected abstract Array changeTypeInteger(); + protected Array changeTypeInteger() { + return changeTypeInteger(new IntegerArray(new int[size()])); + } + + /** + * Change type to a Integer array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Integer type of array that is pointing the ret argument + */ + protected Array changeTypeInteger(Array ret) { + return changeTypeInteger(ret, 0, size()); + } + + /** + * Change type to a Integer array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Integer type of array that is pointing the ret argument + */ + protected abstract Array changeTypeInteger(Array ret, int l, int u); /** * Change type to a Long array type * * @return Long type of array */ - protected abstract Array changeTypeLong(); + protected Array changeTypeLong() { + return changeTypeLong(new LongArray(new long[size()])); + } + + /** + * Change type to a Long array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Long type of array that is pointing the ret argument + */ + protected Array changeTypeLong(Array ret) { + return changeTypeLong(ret, 0, size()); + } + + /** + * Change type to a Long array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Long type of array that is pointing the ret argument + */ + protected abstract Array changeTypeLong(Array ret, int l, int u); /** * Change type to a Hash46 array type * * @return A Hash64 array */ - protected abstract Array changeTypeHash64(); + protected Array changeTypeHash64() { + return changeTypeHash64(new HashLongArray(new long[size()])); + } + + /** + * Change type to a Hash64 array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Hash64 type of array that is pointing the ret argument + */ + protected Array changeTypeHash64(Array ret) { + return changeTypeHash64(ret, 0, size()); + } + + /** + * Change type to a Hash64 array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Hash64 type of array that is pointing the ret argument + */ + protected abstract Array changeTypeHash64(Array ret, int l, int u); /** * Change type to a String array type * * @return String type of array */ - protected abstract Array changeTypeString(); + protected Array changeTypeString() { + return changeTypeString(new StringArray(new String[size()])); + } + + /** + * Change type to a String array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A String type of array that is pointing the ret argument + */ + protected Array changeTypeString(Array ret) { + return changeTypeString(ret, 0, size()); + } + + /** + * Change type to a String array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A String type of array that is pointing the ret argument + */ + protected abstract Array changeTypeString(Array ret, int l, int u); /** * Change type to a Character array type * * @return Character type of array */ - protected abstract Array changeTypeCharacter(); + protected Array changeTypeCharacter() { + return changeTypeCharacter(new CharArray(new char[size()])); + } + + /** + * Change type to a Character array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @return A Character type of array that is pointing the ret argument + */ + protected Array changeTypeCharacter(Array ret) { + return changeTypeCharacter(ret, 0, size()); + } + + /** + * Change type to a Character array, of underlying longs to store the individual values + * + * @param ret The array to insert the result into + * @param l lower index to convert from (inclusive) + * @param u upper index to convert to (exclusive) + * @return A Character type of array that is pointing the ret argument + */ + protected abstract Array changeTypeCharacter(Array ret, int l, int u); /** * Get the minimum and maximum length of the contained values as string type. @@ -774,7 +983,7 @@ public ArrayCompressionStatistics statistics(int nSamples) { if(ddcSize < memSize) return new ArrayCompressionStatistics(memSizePerElement, // estDistinct, true, vt.getKey(), vt.getValue(), FrameArrayType.DDC, getInMemorySize(), ddcSize); - else if(vt.getKey() != getValueType() ) + else if(vt.getKey() != getValueType()) return new ArrayCompressionStatistics(memSizePerElement, // estDistinct, true, vt.getKey(), vt.getValue(), null, getInMemorySize(), memSize); return null; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java index 4ea341313fd..1dea2ac8179 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java @@ -37,26 +37,24 @@ public interface ArrayFactory { public final static int bitSetSwitchPoint = 64; public enum FrameArrayType { - STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64, - CHARACTER, RAGGED, OPTIONAL, DDC, - HASH64; + STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64, CHARACTER, RAGGED, OPTIONAL, DDC, HASH64; } public static StringArray create(String[] col) { return new StringArray(col); } - public static HashLongArray createHash64(String[] col){ + public static HashLongArray createHash64(String[] col) { return new HashLongArray(col); - } + } - public static OptionalArray createHash64Opt(String[] col){ + public static OptionalArray createHash64Opt(String[] col) { return new OptionalArray<>(col, ValueType.HASH64); - } + } - public static HashLongArray createHash64(long[] col){ + public static HashLongArray createHash64(long[] col) { return new HashLongArray(col); - } + } public static BooleanArray create(boolean[] col) { return new BooleanArray(col); @@ -157,24 +155,15 @@ public static Array allocate(ValueType v, int nRow, String val) { public static Array allocateOptional(ValueType v, int nRow) { switch(v) { case BOOLEAN: - if(nRow > bitSetSwitchPoint) - return new OptionalArray<>(new BitSetArray(nRow), true); - else - return new OptionalArray<>(new BooleanArray(new boolean[nRow]), true); case UINT4: case UINT8: case INT32: - return new OptionalArray<>(new IntegerArray(new int[nRow]), true); case INT64: - return new OptionalArray<>(new LongArray(new long[nRow]), true); case FP32: - return new OptionalArray<>(new FloatArray(new float[nRow]), true); case FP64: - return new OptionalArray<>(new DoubleArray(new double[nRow]), true); case CHARACTER: - return new OptionalArray<>(new CharArray(new char[nRow]), true); case HASH64: - return new OptionalArray<>(new HashLongArray(new long[nRow]), true); + return new OptionalArray<>(allocate(v, nRow), true); case UNKNOWN: case STRING: default: @@ -195,6 +184,7 @@ public static Array allocate(ValueType v, int nRow) { return allocateBoolean(nRow); case UINT4: case UINT8: + LOG.warn("Not supported allocation of UInt 4 or 8 array: defaulting to Int32"); case INT32: return new IntegerArray(new int[nRow]); case INT64: @@ -251,7 +241,7 @@ public static Array read(DataInput in, int nRow) throws IOException { case HASH64: arr = new HashLongArray(new long[nRow]); break; - default: + default: throw new NotImplementedException(v + ""); } arr.readFields(in); @@ -294,7 +284,7 @@ public static Array append(Array a, Array b) { */ @SuppressWarnings("unchecked") public static Array set(Array target, Array src, int rl, int ru, int rlen) { - + if(rlen <= ru) throw new DMLRuntimeException("Invalid range ru: " + ru + " should be less than rlen: " + rlen); else if(rl < 0) @@ -312,7 +302,7 @@ else if(target != null && target.size() < rlen) else if(src.getFrameArrayType() == FrameArrayType.DDC) { final DDCArray ddcA = ((DDCArray) src); final Array ddcDict = ddcA.getDict(); - if(ddcDict == null){ // read empty dict. + if(ddcDict == null) { // read empty dict. target = new DDCArray<>(null, MapToFactory.create(rlen, ddcA.getMap().getUnique())); } else if(ddcDict.getFrameArrayType() == FrameArrayType.OPTIONAL) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayWrapper.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayWrapper.java new file mode 100644 index 00000000000..bee4b9965b1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayWrapper.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.columns; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; + +public class ArrayWrapper implements Writable { + + public Array _a; + + public ArrayWrapper(Array a){ + _a = a; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(_a.size()); + _a.write(out); + } + + @Override + public void readFields(DataInput in) throws IOException { + int s = in.readInt(); + _a = ArrayFactory.read(in, s); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index cd23ce60b6b..fbe33a50f36 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -104,6 +104,15 @@ public synchronized void set(int index, boolean value) { _data[wIdx] &= ~(1L << index); } + @Override + public void unsafeSet(int index, boolean value){ + int wIdx = index >> 6; // same as divide by 64 bit faster + if(value) + _data[wIdx] |= (1L << index); + else + _data[wIdx] &= ~(1L << index); + } + @Override public void set(int index, double value) { set(index, Math.round(value) == 1.0); @@ -137,7 +146,7 @@ private static long[] toLongArrayPadded(BitSet data, int minLength) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - if(useVectorizedKernel && value instanceof BitSetArray && (ru - rl >= 64)){ + if(useVectorizedKernel && value instanceof BitSetArray && (ru - rl >= 64)) { try { // try system array copy. // but if it does not work, default to get. @@ -149,7 +158,7 @@ public void set(int rl, int ru, Array value, int rlSrc) { } } else // default - super.set(rl,ru,value, rlSrc); + super.set(rl, ru, value, rlSrc); } private void setVectorized(int rl, int ru, BitSetArray value, int rlSrc) { @@ -163,7 +172,7 @@ private void setVectorizedLongs(int rl, int ru, long[] ov) { setVectorizedLongs(rl, ru, _data, ov); } - public static void setVectorizedLongs(int rl, int ru, long[] ret, long[] ov) { + public static void setVectorizedLongs(int rl, int ru, long[] ret, long[] ov) { final long remainder = rl % 64L; if(remainder == 0) @@ -326,12 +335,12 @@ private BitSetArray sliceSimple(int rl, int ru) { return new BitSetArray(ret); } - private BitSetArray sliceVectorized(int rl, int ru){ + private BitSetArray sliceVectorized(int rl, int ru) { return new BitSetArray(sliceVectorized(_data, rl, ru), ru - rl); } - public static long[] sliceVectorized(long[] _data,int rl, int ru) { + public static long[] sliceVectorized(long[] _data, int rl, int ru) { final long[] ret = new long[(ru - rl) / 64 + 1]; @@ -423,70 +432,83 @@ protected Array changeTypeBitSet() { } @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeBitSet(Array ret, int l, int u) { + for(int i = l; i < u; i++) + ret.set(i, get(i)); + return ret; + } + + @Override + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) // if ever relevant use next set bit instead. // to increase speed, but it should not be the case in general ret[i] = get(i); - - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i) ? 1.0 : 0.0; - return new DoubleArray(ret); + return retA; + } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i) ? 1.0f : 0.0f; - return new FloatArray(ret); + return retA; + } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i) ? 1 : 0; - return new IntegerArray(ret); + return retA; + } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i) ? 1L : 0L; return new LongArray(ret); } @Override - protected Array changeTypeHash64(){ - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + + for(int i = l; i < u; i++) ret[i] = get(i) ? 1L : 0L; - return new HashLongArray(ret); + return retA; + } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i).toString(); - return new StringArray(ret); + return retA; + } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = (char) (get(i) ? 1 : 0); - return new CharArray(ret); + return retA; + } @Override @@ -521,10 +543,10 @@ public boolean isEmpty() { @Override public boolean isAllTrue() { if(allTrue != -1) - return allTrue ==1; - + return allTrue == 1; + for(int i = 0; i < _data.length; i++) - if(_data[i] != -1L){ + if(_data[i] != -1L) { allTrue = 0; return false; } @@ -583,12 +605,11 @@ public ArrayCompressionStatistics statistics(int nSamples) { return null; } - @Override - public boolean equals(Array other){ + public boolean equals(Array other) { if(other instanceof BitSetArray) - return Arrays.equals(_data, ((BitSetArray)other)._data); - else + return Arrays.equals(_data, ((BitSetArray) other)._data); + else return false; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index ae0307ba41d..3b89db3b70e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -79,7 +79,7 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - if(value instanceof BooleanArray){ + if(value instanceof BooleanArray) { try { // try system array copy. // but if it does not work, default to get. @@ -228,65 +228,80 @@ protected ABooleanArray changeTypeBitSet() { return new BitSetArray(_data); } + @Override + protected Array changeTypeBitSet(Array ret, int l, int u) { + for(int i = l; i < u; i++) + ret.set(i, get(i)); + return ret; + } + @Override protected ABooleanArray changeTypeBoolean() { return this; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + + @Override + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i] ? 1.0 : 0.0; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i] ? 1.0f : 0.0f; - return new FloatArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i] ? 1 : 0; - return new IntegerArray(ret); + return retA; } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i] ? 1L : 0L; - return new LongArray(ret); + return retA; } @Override - protected Array changeTypeHash64(){ - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) - ret[i] = _data[i] ? 1L : 0L; - return new HashLongArray(ret); + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) + ret[i] = _data[i] ? 1L : 0L; + return retA; } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i).toString(); - return new StringArray(ret); + return retA; } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = (char) (_data[i] ? 1 : 0); - return new CharArray(ret); + return retA; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index f597b8ec621..c0369cca68b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -24,7 +24,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.BitSet; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -196,82 +195,84 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - final BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u){ + for(int i = l; i < u; i++) { final int di = _data[i]; if(di != 0 && di != 1) throw new DMLRuntimeException("Unable to change to boolean from char array because of value:" // + _data[i] + " (as int: " + di + ")"); ret.set(i, di != 0); } - return new BitSetArray(ret, size()); + return ret; } @Override - protected Array changeTypeBoolean() { - final boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { final int di = _data[i]; if(di != 0 && di != 1) throw new DMLRuntimeException("Unable to change to boolean from char array because of value:" // + _data[i] + " (as int: " + di + ")"); ret[i] = di != 0; } - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new FloatArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new IntegerArray(ret); + return retA; } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new LongArray(ret); + return retA; } @Override - protected Array changeTypeHash64(){ - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new HashLongArray(ret); + return retA; } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = String.valueOf(_data[i]); - return new StringArray(ret); + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = ""+_data[i]; + return retA; } @Override - public Array changeTypeCharacter() { - return this; + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 3b7200c7beb..61af1174c3d 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -160,18 +160,8 @@ protected Map createRecodeMap() { * @param containsNull If the array contains null. * @return a compressed column group. */ - public static Array compressToDDC(Array arr, ValueType vt, boolean containsNull) { - Array arrT; - try { - arrT = containsNull ? arr.changeTypeWithNulls(vt) : arr.changeType(vt); - } - catch(Exception e) { - // fall back to full analysis. - Pair ct = arr.analyzeValueType(); - arrT = ct.getValue() ? arr.changeTypeWithNulls(ct.getKey()) : arr.changeType(ct.getKey()); - } - - return compressToDDC(arrT); + public static Array compressToDDC(Array arr, boolean containsNull) { + return compressToDDC(arr); } @Override @@ -293,6 +283,11 @@ public long getExactSerializedSize() { return 1L +1L+ map.getExactSizeOnDisk() + dict.getExactSerializedSize(); } + @Override + public Array changeType(ValueType t){ + return new DDCArray<>(dict.changeType(t), map); + } + @Override protected Array changeTypeBitSet() { return new DDCArray<>(dict.changeTypeBitSet(), map); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 68672c5d73a..0f7e3204d3e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -25,7 +25,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.BitSet; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -33,9 +32,9 @@ import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.utils.DoubleParser; import org.apache.sysds.utils.MemoryEstimates; -import ch.randelshofer.fastdoubleparser.JavaDoubleParser; public class DoubleArray extends Array { private double[] _data; @@ -254,27 +253,34 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u){ + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); ret.set(i, _data[i] == 0 ? false : true); } - return new BitSetArray(ret, size()); + return ret; } @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); ret[i] = _data[i] == 0 ? false : true; } - return new BooleanArray(ret); + return retA; + } + + @Override + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; } @Override @@ -283,60 +289,51 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) - ret[i] = (float) _data[i]; - return new FloatArray(ret); + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (float)_data[i]; + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] != (int) _data[i]) - throw new DMLRuntimeException("Unable to change to Integer from Double array because of value:" + _data[i]); - ret[i] = (int) _data[i]; - } - return new IntegerArray(ret); + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (int)_data[i]; + return retA; } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] != (long) _data[i]) - throw new DMLRuntimeException("Unable to change to Long from Double array because of value:" + _data[i]); - ret[i] = (long) _data[i]; - } - return new LongArray(ret); + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (long)_data[i]; + return retA; } @Override - protected Array changeTypeHash64() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] != (long) _data[i]) - throw new DMLRuntimeException("Unable to change to Long from Double array because of value:" + _data[i]); - ret[i] = (long) _data[i]; - } - return new HashLongArray(ret); + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) + ret[i] = (long)_data[i]; + return retA; } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString(); - return new StringArray(ret); + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Double.toString(_data[i]); + return retA; } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) - ret[i] = CharArray.parseChar(get(i).toString()); - return new CharArray(ret); + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Double.toString(_data[i]).charAt(0); + return retA; } @Override @@ -359,7 +356,7 @@ public static double parseDouble(String value) { try { if(value == null || value.isEmpty()) return 0.0; - return JavaDoubleParser.parseDouble(value); + return DoubleParser.parseFloatingPointLiteral(value, 0, value.length()); } catch(NumberFormatException e) { final int len = value.length(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 03709fd14ac..07555d095d5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -25,7 +25,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.BitSet; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -205,83 +204,85 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u) { + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) - throw new DMLRuntimeException( - "Unable to change to Boolean from Integer array because of value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from Float array because of value:" + _data[i]); ret.set(i, _data[i] == 0 ? false : true); } - return new BitSetArray(ret, size()); + return ret; } @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) - throw new DMLRuntimeException( - "Unable to change to Boolean from Integer array because of value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from Float array because of value:" + _data[i]); ret[i] = _data[i] == 0 ? false : true; } - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] != (int) _data[i]) - throw new DMLRuntimeException("Unable to change to integer from float array because of value:" + _data[i]); - ret[i] = (int) _data[i]; - } - return new IntegerArray(ret); + protected Array changeTypeFloat() { + return this; } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) - ret[i] = (int) _data[i]; - return new LongArray(ret); + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; } @Override - protected Array changeTypeHash64() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) - ret[i] = (int) _data[i]; - return new HashLongArray(ret); + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (int)_data[i]; + return retA; } @Override - protected Array changeTypeFloat() { - return this; + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (long)_data[i]; + return retA; } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString(); - return new StringArray(ret); + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) + ret[i] = (long)_data[i]; + return retA; } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) - ret[i] = CharArray.parseChar(get(i).toString()); - return new CharArray(ret); + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Float.toString(_data[i]); + return retA; + } + + @Override + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Float.toString(_data[i]).charAt(0); + return retA; } @Override @@ -303,7 +304,7 @@ public double getAsDouble(int i) { public static float parseFloat(String value) { if(value == null) return 0.0f; - + final int len = value.length(); if(len == 0) return 0.0f; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java index 459164b21b0..a6cbb69d4aa 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java @@ -23,7 +23,6 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import java.util.BitSet; import java.util.HashMap; import java.util.Map; @@ -66,6 +65,10 @@ public long getLong(int index) { return _data[index]; } + protected long[] getLongs(){ + return _data; + } + @Override public void set(int index, Object value) { if(value instanceof String) @@ -269,59 +272,57 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u){ + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( - "Unable to change to Boolean from Integer array because of value:" + _data[i]); + "Unable to change to Boolean from Hash array because of value:" + _data[i]); ret.set(i, _data[i] == 0 ? false : true); } - return new BitSetArray(ret, size()); + return ret; } @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] < 0 || _data[i] > 1) + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { + if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( - "Unable to change to Boolean from Integer array because of value:" + _data[i]); + "Unable to change to Boolean from Hash array because of value:" + _data[i]); ret[i] = _data[i] == 0 ? false : true; } - return new BooleanArray(ret); + return retA; } - @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new FloatArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { - if(Math.abs(_data[i]) > Integer.MAX_VALUE) - throw new DMLRuntimeException("Unable to change to integer from long array because of value:" + _data[i]); - ret[i] = (int) _data[i]; - } - return new IntegerArray(ret); + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = (int)_data[i]; + return retA; } @Override - protected Array changeTypeLong() { - return new LongArray(_data); + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; } @Override @@ -330,11 +331,27 @@ protected Array changeTypeHash64() { } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + + @Override + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = get(i).toString(); - return new StringArray(ret); + return retA; + } + + @Override + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = get(i).toString().charAt(0); + return retA; } @Override @@ -373,14 +390,6 @@ public static long parseHashLong(String s) { return Long.parseUnsignedLong(s, 16); } - @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString().charAt(0); - return new CharArray(ret); - } - @Override public boolean isShallowSerialize() { return true; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index a07e499f9e9..a96be6f7480 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -25,7 +25,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.BitSet; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -203,43 +202,41 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u){ + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); ret.set(i, _data[i] == 0 ? false : true); } - return new BitSetArray(ret, size()); + return ret; } - @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { if(_data[i] < 0 || _data[i] > 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); ret[i] = _data[i] == 0 ? false : true; } - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new FloatArray(ret); + return retA; } @Override @@ -248,35 +245,43 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + + @Override + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new LongArray(ret); + return retA; } @Override - protected Array changeTypeHash64() { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new HashLongArray(ret); + return retA; } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString(); - return new StringArray(ret); + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Integer.toString(_data[i]); + return retA; } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString().charAt(0); - return new CharArray(ret); + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Integer.toString(_data[i]).charAt(0); + return retA; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index ddf724ecf85..ce8da11b77e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -25,7 +25,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.BitSet; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -203,54 +202,52 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBitSet(Array ret, int l, int u) { + for(int i = l; i < u; i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); ret.set(i, _data[i] == 0 ? false : true); } - return new BitSetArray(ret, size()); + return ret; } @Override - protected Array changeTypeBoolean() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBoolean(Array retA, int l, int u) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = l; i < u; i++) { if(_data[i] < 0 || _data[i] > 1) - throw new DMLRuntimeException( - "Unable to change to Boolean from Integer array because of value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from Long array because of value:" + _data[i]); ret[i] = _data[i] == 0 ? false : true; } - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = _data[i]; - return new FloatArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { - if(Math.abs(_data[i]) > Integer.MAX_VALUE ) + protected Array changeTypeInteger(Array retA, int l, int u) { + int[] ret = (int[]) retA.get(); + for(int i = l; i < u; i++) { + if(Math.abs(_data[i]) > Integer.MAX_VALUE) throw new DMLRuntimeException("Unable to change to integer from long array because of value:" + _data[i]); ret[i] = (int) _data[i]; } - return new IntegerArray(ret); + return retA; } @Override @@ -258,17 +255,41 @@ protected Array changeTypeLong() { return this; } + @Override + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + @Override protected Array changeTypeHash64() { return new HashLongArray(_data); } @Override - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString(); - return new StringArray(ret); + protected Array changeTypeHash64(Array retA, int l, int u) { + long[] ret = ((HashLongArray) retA).getLongs(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + + @Override + protected Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Long.toString(_data[i]); + return retA; + } + + @Override + public Array changeTypeCharacter(Array retA, int l, int u) { + char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = Long.toString(_data[i]).charAt(0); + return retA; } @Override @@ -301,14 +322,6 @@ public static long parseLong(String s) { } } - @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString().charAt(0); - return new CharArray(ret); - } - @Override public boolean isShallowSerialize() { return true; @@ -359,7 +372,7 @@ public boolean equals(Array other) { } @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return false; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index f653b7f321b..1c1af84a793 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -25,7 +25,6 @@ import java.util.HashMap; import java.util.Map; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -242,7 +241,6 @@ public void setNz(int rl, int ru, Array value) { T v = value.get(i); if(v != null) set(i, v); - } } @@ -252,7 +250,6 @@ public void setFromOtherTypeNz(int rl, int ru, Array value) { String v = UtilFunctions.objectToString(value.get(i)); if(v != null) set(i, v); - } } @@ -311,6 +308,13 @@ public ValueType getValueType() { return _a.getValueType(); } + @Override + public Array changeType(ValueType t) { + if (t == ValueType.STRING) // String can contain null. + return changeType(ArrayFactory.allocate(t, size())); + return changeTypeWithNulls(t); + } + @Override public Pair analyzeValueType(int maxCells) { return new Pair<>(getValueType(), true); @@ -322,61 +326,55 @@ public FrameArrayType getFrameArrayType() { } @Override - protected Array changeTypeBitSet() { - Array a = _a.changeTypeBitSet(); - return new OptionalArray<>(a, _n); + protected Array changeTypeBitSet(Array ret, int l, int u) { + return _a.changeTypeBitSet(ret, l, u); } @Override - protected Array changeTypeBoolean() { - Array a = _a.changeTypeBoolean(); - return new OptionalArray<>(a, _n); + protected Array changeTypeBoolean(Array retA, int l, int u) { + return _a.changeTypeBoolean(retA, l, u); } @Override - protected Array changeTypeDouble() { - Array a = _a.changeTypeDouble(); - return new OptionalArray<>(a, _n); + protected Array changeTypeDouble(Array retA, int l, int u) { + return _a.changeTypeDouble(retA, l, u); } @Override - protected Array changeTypeFloat() { - Array a = _a.changeTypeFloat(); - return new OptionalArray<>(a, _n); + protected Array changeTypeFloat(Array retA, int l, int u) { + return _a.changeTypeFloat(retA, l, u); } @Override - protected Array changeTypeInteger() { - Array a = _a.changeTypeInteger(); - return new OptionalArray<>(a, _n); + protected Array changeTypeInteger(Array retA, int l, int u) { + return _a.changeTypeInteger(retA, l, u); } @Override - protected Array changeTypeLong() { - Array a = _a.changeTypeLong(); - return new OptionalArray<>(a, _n); + protected Array changeTypeLong(Array retA, int l, int u) { + + return _a.changeTypeLong(retA, l, u); } @Override - protected Array changeTypeHash64() { - Array a = _a.changeTypeHash64(); - return new OptionalArray<>(a, _n); + protected Array changeTypeHash64(Array retA, int l, int u) { + return _a.changeTypeHash64(retA, l, u); } @Override - protected Array changeTypeCharacter() { - Array a = _a.changeTypeCharacter(); - return new OptionalArray<>(a, _n); + protected Array changeTypeCharacter(Array retA, int l, int u) { + return _a.changeTypeCharacter(retA, l, u); } @Override - protected Array changeTypeString() { - StringArray a = (StringArray) _a.changeTypeString(); - String[] d = a.get(); + protected Array changeTypeString(Array retA, int l, int u) { + String[] d = (String[]) retA.get(); for(int i = 0; i < _size; i++) - if(!_n.get(i)) + if(_n.get(i)) + d[i] = _a.get(i).toString(); + else d[i] = null; - return a; + return retA; } @Override @@ -426,34 +424,6 @@ public final boolean isNotEmpty(int i) { return _n.isNotEmpty(i) && _a.isNotEmpty(i); } - @Override - public Array changeTypeWithNulls(ValueType t) { - - switch(t) { - case BOOLEAN: - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBitSet(); - else - return changeTypeBoolean(); - case FP32: - return changeTypeFloat(); - case FP64: - return changeTypeDouble(); - case UINT8: - throw new NotImplementedException(); - case INT32: - return changeTypeInteger(); - case INT64: - return changeTypeLong(); - case CHARACTER: - return changeTypeCharacter(); - case STRING: - case UNKNOWN: - default: - return changeTypeString(); // String can contain null - } - } - @Override public boolean containsNull() { return !_n.isAllTrue(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index b97ee68d550..35cb4ee5d68 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -263,46 +263,96 @@ protected Array changeTypeBitSet() { return _a.changeTypeBitSet(); } + @Override + protected Array changeTypeBitSet(Array ret, int l, int u) { + return _a.changeTypeBitSet(ret, l, u); + } + @Override protected Array changeTypeBoolean() { return _a.changeTypeBoolean(); } + @Override + protected Array changeTypeBoolean(Array retA, int l, int u) { + return _a.changeTypeBoolean(retA, l, u); + } + @Override protected Array changeTypeDouble() { return _a.changeTypeDouble(); } + @Override + protected Array changeTypeDouble(Array retA, int l, int u) { + return _a.changeTypeDouble(retA, l, u); + } + @Override protected Array changeTypeFloat() { return _a.changeTypeFloat(); } + @Override + protected Array changeTypeFloat(Array retA, int l, int u) { + return _a.changeTypeFloat(retA, l, u); + } + @Override protected Array changeTypeInteger() { return _a.changeTypeInteger(); } + @Override + protected Array changeTypeInteger(Array retA, int l, int u) { + return _a.changeTypeInteger(retA, l, u); + } + @Override protected Array changeTypeLong() { return _a.changeTypeLong(); } + @Override + protected Array changeTypeLong(Array retA, int l, int u) { + return _a.changeTypeLong(retA, l, u); + } + @Override protected Array changeTypeHash64() { return _a.changeTypeHash64(); } + @Override + protected Array changeTypeHash64(Array retA, int l, int u) { + return _a.changeTypeHash64(retA, l, u); + } + @Override protected Array changeTypeString() { return _a.changeTypeString(); } + @Override + protected Array changeTypeString(Array retA, int l, int u) { + return _a.changeTypeString(retA, l, u); + } + @Override protected Array changeTypeCharacter() { return _a.changeTypeCharacter(); } + @Override + protected Array changeTypeCharacter(Array retA, int l, int u) { + return _a.changeTypeCharacter(retA, l, u); + } + + @Override + public Array changeTypeWithNulls(ValueType t) { + throw new NotImplementedException("Not Implemented ragged array with nulls"); + } + @Override public void fill(String val) { _a.reset(super.size()); @@ -376,7 +426,7 @@ public boolean equals(Array other) { if(other._size == this._size && // other.getValueType() == this.getValueType() && // other instanceof RaggedArray) { - if(other == this){// same pointer + if(other == this) {// same pointer return true; } RaggedArray ot = (RaggedArray) other; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 46a90505389..4e48f25b2cd 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -23,10 +23,8 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import java.util.BitSet; import java.util.HashMap; import java.util.Map; -import java.util.regex.Pattern; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; @@ -309,208 +307,221 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { - return changeTypeBoolean(); + public Array changeTypeWithNulls(ValueType t) { + if(t == getValueType()) + return this; + ABooleanArray nulls = getNulls(); + Array a = changeType(t); + if(a instanceof StringArray) + return a; + + return new OptionalArray<>(a, nulls); } @Override - protected Array changeTypeBoolean() { - String firstNN = _data[0]; - int i = 1; - while(firstNN == null && i < size()) { + protected Array changeTypeBitSet(Array ret, int rl, int ru) { + String firstNN = _data[rl]; + int i = rl + 1; + while(firstNN == null && i < ru) { firstNN = _data[i++]; } if(firstNN == null) - // this check is similar to saying i == size(); - // this means all values were null. therefore we have an easy time retuning an empty boolean array. - return ArrayFactory.allocateBoolean(size()); + return ret;// all values were null. therefore we have an easy time retuning an empty boolean array. else if(firstNN.toLowerCase().equals("true") || firstNN.toLowerCase().equals("false")) - return changeTypeBooleanStandard(); + return changeTypeBooleanStandardBitSet(ret, rl, ru); else if(firstNN.equals("0") || firstNN.equals("1") || firstNN.equals("1.0") || firstNN.equals("0.0")) - return changeTypeBooleanNumeric(); - // else if(firstNN.equals("0.0") || firstNN.equals("1.0")) - // return changeTypeBooleanFloat(); + return changeTypeBooleanNumericBitSet(ret, rl, ru); else if(firstNN.toLowerCase().equals("t") || firstNN.toLowerCase().equals("f")) - return changeTypeBooleanCharacter(); + return changeTypeBooleanCharacterBitSet(ret, rl, ru); else throw new DMLRuntimeException("Not supported type of Strings to change to Booleans value: " + firstNN); } - protected Array changeTypeBooleanStandard() { - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBooleanStandardBitSet(); + @Override + protected Array changeTypeBoolean(Array ret, int rl, int ru) { + String firstNN = _data[rl]; + int i = rl + 1; + while(firstNN == null && i < ru) { + firstNN = _data[i++]; + } + + if(firstNN == null) + return ret;// all values were null. therefore we have an easy time retuning an empty boolean array. + else if(firstNN.toLowerCase().equals("true") || firstNN.toLowerCase().equals("false")) + return changeTypeBooleanStandardArray(ret, rl, ru); + else if(firstNN.equals("0") || firstNN.equals("1") || firstNN.equals("1.0") || firstNN.equals("0.0")) + return changeTypeBooleanNumericArray(ret, rl, ru); + else if(firstNN.toLowerCase().equals("t") || firstNN.toLowerCase().equals("f")) + return changeTypeBooleanCharacterArray(ret, rl, ru); else - return changeTypeBooleanStandardArray(); + throw new DMLRuntimeException("Not supported type of Strings to change to Booleans value: " + firstNN); } - protected Array changeTypeBooleanStandardBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanStandardBitSet(Array ret, int rl, int ru) { + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) - ret.set(i, Boolean.parseBoolean(_data[i])); + ret.set(i, Boolean.parseBoolean(s)); } - - return new BitSetArray(ret, size()); + return ret; } - protected Array changeTypeBooleanStandardArray() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanStandardArray(Array retA, int rl, int ru) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) - ret[i] = Boolean.parseBoolean(_data[i]); + ret[i] = Boolean.parseBoolean(s); } - return new BooleanArray(ret); - } - - protected Array changeTypeBooleanCharacter() { - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBooleanCharacterBitSet(); - else - return changeTypeBooleanCharacterArray(); + return retA; } - protected Array changeTypeBooleanCharacterBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanCharacterBitSet(Array ret, int rl, int ru) { + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) - ret.set(i, isTrueCharacter(_data[i].charAt(0))); + ret.set(i, isTrueCharacter(s.charAt(0))); } - return new BitSetArray(ret, size()); + return ret; } - protected Array changeTypeBooleanCharacterArray() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanCharacterArray(Array retA, int rl, int ru) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) - ret[i] = isTrueCharacter(_data[i].charAt(0)); + ret[i] = isTrueCharacter(s.charAt(0)); } - return new BooleanArray(ret); + return retA; } private boolean isTrueCharacter(char a) { return a == 'T' || a == 't'; } - protected Array changeTypeBooleanNumeric() { - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBooleanNumericBitSet(); - else - return changeTypeBooleanNumericArray(); - } - - protected Array changeTypeBooleanNumericBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanNumericBitSet(Array ret, int rl, int ru) { + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) { if(s.length() > 1) { - final boolean zero = _data[i].equals("0.0"); - final boolean one = _data[i].equals("1.0"); + final boolean zero = s.equals("0.0"); + final boolean one = s.equals("1.0"); if(zero | one) ret.set(i, one); else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from String array, value: " + s); } else { - final boolean zero = _data[i].charAt(0) == '0'; - final boolean one = _data[i].charAt(0) == '1'; + final boolean zero = s.charAt(0) == '0'; + final boolean one = s.charAt(0) == '1'; if(zero | one) ret.set(i, one); else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from String array, value: " + s); } } } - return new BitSetArray(ret, size()); + return ret; } - protected Array changeTypeBooleanNumericArray() { - boolean[] ret = new boolean[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeBooleanNumericArray(Array retA, int rl, int ru) { + boolean[] ret = (boolean[]) retA.get(); + for(int i = rl; i < ru; i++) { final String s = _data[i]; if(s != null) { if(s.length() > 1) { - final boolean zero = _data[i].equals("0.0"); - final boolean one = _data[i].equals("1.0"); + final boolean zero = s.equals("0.0"); + final boolean one = s.equals("1.0"); if(zero | one) ret[i] = one; else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from String array, value: " + s); } else { - final boolean zero = _data[i].charAt(0) == '0'; - final boolean one = _data[i].charAt(0) == '1'; + final boolean zero = s.charAt(0) == '0'; + final boolean one = s.charAt(0) == '1'; if(zero | one) ret[i] = one; else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + throw new DMLRuntimeException("Unable to change to Boolean from String array, value: " + s); } } } - return new BooleanArray(ret); + return retA; } @Override - protected Array changeTypeDouble() { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeDouble(Array retA, int l, int u) { + final double[] ret = (double[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = DoubleArray.parseDouble(_data[i]); - return new DoubleArray(ret); + return retA; } @Override - protected Array changeTypeFloat() { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) + protected Array changeTypeFloat(Array retA, int l, int u) { + final float[] ret = (float[]) retA.get(); + for(int i = l; i < u; i++) ret[i] = FloatArray.parseFloat(_data[i]); - return new FloatArray(ret); + return retA; } @Override - protected Array changeTypeInteger() { - String firstNN = _data[0]; - int i = 1; - while(firstNN == null && i < size()) { + protected Array changeTypeInteger(Array retA, int l, int u) { + String firstNN = _data[l]; + int i = l + 1; + while(firstNN == null && i < u) { firstNN = _data[i++]; } + if(firstNN == null) - throw new DMLRuntimeException("Invalid change to int on all null"); - else if(firstNN.contains(".")) - return changeTypeIntegerFloatString(); + return retA; // no non zero values. + final int[] ret = (int[]) retA.get(); + if(firstNN.contains(".")) + return changeTypeIntegerFloatString(ret, l, u); else - return changeTypeIntegerNormal(); + return changeTypeIntegerNormal(ret, l, u); } - protected Array changeTypeIntegerFloatString() { - int[] ret = new int[size()]; - Pattern p = Pattern.compile("\\."); - for(int i = 0; i < size(); i++) { + protected Array changeTypeIntegerFloatString(int[] ret, int l, int u) { + for(int i = l; i < u; i++) { final String s = _data[i]; - try { - if(s != null) - ret[i] = Integer.parseInt(p.split(s, 2)[0]); - } - catch(NumberFormatException e) { - - throw new DMLRuntimeException("Unable to change to Integer from String array", e); + // we do this to avoid allocating substrings. + ret[i] = parseSignificant(s); + } + return new IntegerArray(ret); + } + protected int parseSignificant(String s) { + final int len = s.length(); + int v = 0; + int c = 0; + for(; c < len; c++) { + char ch = s.charAt(c); + if(c == ',') + break; + else if((ch - '0') < 10) { + v = 10 * v + ch - '0'; } + else + throw new NumberFormatException(s); } - return new IntegerArray(ret); + c++; + for(; c < len; c++) { + char ch = s.charAt(c); + if((ch - '0') > 10) + throw new NumberFormatException(s); + } + return v; } - protected Array changeTypeIntegerNormal() { + protected Array changeTypeIntegerNormal(int[] ret, int l, int u) { try { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { + for(int i = l; i < u; i++) { final String s = _data[i]; if(s != null) ret[i] = Integer.parseInt(s); @@ -520,15 +531,14 @@ protected Array changeTypeIntegerNormal() { catch(NumberFormatException e) { if(e.getMessage().contains("For input string: \"\"")) { LOG.warn("inefficient safe cast"); - return changeTypeIntegerSafe(); + return changeTypeIntegerSafe(ret, l, u); } throw new DMLRuntimeException("Unable to change to Integer from String array", e); } } - protected Array changeTypeIntegerSafe() { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { + protected Array changeTypeIntegerSafe(int[] ret, int l, int u) { + for(int i = l; i < u; i++) { final String s = _data[i]; if(s != null && s.length() > 0) ret[i] = Integer.parseInt(s); @@ -537,60 +547,30 @@ protected Array changeTypeIntegerSafe() { } @Override - protected Array changeTypeLong() { - try { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; - if(s != null) - ret[i] = Long.parseLong(s); - } - return new LongArray(ret); - } - catch(NumberFormatException e) { - throw new DMLRuntimeException("Unable to change to Long from String array", e); + protected Array changeTypeLong(Array retA, int l, int u) { + long[] ret = (long[]) retA.get(); + for(int i = l; i < u; i++) { + final String s = _data[i]; + if(s != null) + ret[i] = Long.parseLong(s); } + return new LongArray(ret); } @Override - protected Array changeTypeHash64() { - try { - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; - if(s != null) - ret[i] = HashLongArray.parseHashLong(s); - } - return new HashLongArray(ret); - } - catch(NumberFormatException e) { - if(e.getMessage().contains("For input string: \"\"")) { - LOG.warn("inefficient safe cast"); - return changeTypeHash64Safe(); - } - throw new DMLRuntimeException("Unable to change to Hash64 from String array", e); - } - } - - protected Array changeTypeHash64Safe() { - - long[] ret = new long[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; - if(s != null && s.length() > 0) - ret[i] = HashLongArray.parseHashLong(s); - } - return new HashLongArray(ret); - + protected Array changeTypeHash64(Array retA, int l, int u) { + for(int i = l; i < u; i++) + retA.set(i, _data[i]); + return retA; } @Override - public Array changeTypeCharacter() { - char[] ret = new char[size()]; - for(int i = 0; i < size(); i++) { - if(_data[i] == null) - continue; - ret[i] = _data[i].charAt(0); + public Array changeTypeCharacter(Array retA, int l, int u) { + final char[] ret = (char[]) retA.get(); + for(int i = l; i < u; i++) { + final String s = _data[i]; + if(s != null) + ret[i] = s.charAt(0); } return new CharArray(ret); } @@ -600,6 +580,14 @@ public Array changeTypeString() { return this; } + @Override + public Array changeTypeString(Array retA, int l, int u) { + String[] ret = (String[]) retA.get(); + for(int i = l; i < u; i++) + ret[i] = _data[i]; + return retA; + } + @Override public Pair getMinMaxLength() { int minLength = Integer.MAX_VALUE; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java index 8323060f810..c9d5dc71e87 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java @@ -20,6 +20,8 @@ package org.apache.sysds.runtime.frame.data.compress; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; public class ArrayCompressionStatistics { @@ -48,8 +50,12 @@ public ArrayCompressionStatistics(int bytePerValue, int nUnique, boolean shouldC @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(String.format("Compressed Stats: size:%8d->%8d, Use:%10s, Unique:%6d, ValueType:%7s", originalSize, - compressedSizeEstimate, bestType == null ? "None" : bestType.toString(), nUnique, valueType)); + if(ConfigurationManager.getDMLConfig().getDoubleValue(DMLConfig.COMPRESSED_SAMPLING_RATIO) < 1) + sb.append(String.format("Compressed Stats: size:%8d->%8d, Use:%10s, EstUnique:%6d, ValueType:%7s", + originalSize, compressedSizeEstimate, bestType == null ? "None" : bestType.toString(), nUnique, valueType)); + else + sb.append(String.format("Compressed Stats: size:%8d->%8d, Use:%10s, Unique:%6d, ValueType:%7s", originalSize, + compressedSizeEstimate, bestType == null ? "None" : bestType.toString(), nUnique, valueType)); return sb.toString(); } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index 869f97919a4..8a014f24454 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -29,6 +30,7 @@ import org.apache.sysds.runtime.compress.workload.WTreeRoot; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.DDCArray; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -48,7 +50,7 @@ private CompressedFrameBlockFactory(FrameBlock fb, FrameCompressionSettings cs) this.cs = cs; this.stats = new ArrayCompressionStatistics[in.getNumColumns()]; this.compressedColumns = new Array[in.getNumColumns()]; - this.nSamples = Math.min(in.getNumRows(), (int) Math.ceil(in.getNumRows() * cs.sampleRatio)); + this.nSamples = Math.min(in.getNumRows(), Math.max(1000, (int) Math.ceil(in.getNumRows() * cs.sampleRatio))); } public static FrameBlock compress(FrameBlock fb) { @@ -91,12 +93,15 @@ private void encodeSingleThread() { } private void encodeParallel() { - ExecutorService pool = CommonThreadPool.get(cs.k); + final ExecutorService pool = CommonThreadPool.get(cs.k); try { List> tasks = new ArrayList<>(); - for(int i = 0; i < compressedColumns.length; i++) { - final int l = i; - tasks.add(pool.submit(() -> compressCol(l))); + for(int j = 0; j < compressedColumns.length; j++) { + final int i = j; + final Future stats = pool.submit(() -> (getStatistics(i))); + final Future> tmp = pool.submit(() -> allocateCorrectedType(i, stats)); + final Future> tmp2 = changeTypeFuture(i, tmp, pool, cs.k); + tasks.add(pool.submit(() -> compressColFinally(i, tmp2))); } for(Future t : tasks) @@ -112,28 +117,115 @@ private void encodeParallel() { } private void compressCol(int i) { - stats[i] = in.getColumn(i).statistics(nSamples); - if(stats[i] != null) { - if(stats[i].bestType == null){ - // just cast to other value type. - compressedColumns[i] = in.getColumn(i).safeChangeType(stats[i].valueType, stats[i].containsNull); - } - else{ - // commented out because no other encodings are supported yet - switch(stats[i].bestType) { - case DDC: - compressedColumns[i] = DDCArray.compressToDDC(in.getColumn(i), stats[i].valueType, - stats[i].containsNull); - break; - default: - LOG.error("Unsupported encoding default to do nothing: " + stats[i].bestType); - compressedColumns[i] = in.getColumn(i); - break; + final ArrayCompressionStatistics s = getStatistics(i); + if(s != null) + compressCol(i, s); + else + compressedColumns[i] = in.getColumn(i); + } + + private ArrayCompressionStatistics getStatistics(int i) { + return stats[i] = in.getColumn(i).statistics(nSamples); + } + + private Array allocateCorrectedType(int i, Future f) { + try { + f.get(); + return allocateCorrectedType(i); + } + catch(InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + private void compressColFinally(int i, Future> f) { + try { + final Array a = f.get(); + compressColFinally(i, a, stats[i]); + } + catch(InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + private Array allocateCorrectedType(int i) { + final ArrayCompressionStatistics s = stats[i]; + final Array a = in.getColumn(i); + if(s.valueType != null && s.valueType != a.getValueType()) + return s.containsNull ? // + ArrayFactory.allocateOptional(s.valueType, a.size()) : // + ArrayFactory.allocate(s.valueType, a.size());// + else + return a; + } + + private Future> changeTypeFuture(int i, Future> f, ExecutorService pool, int k) { + try { + final Array tmp = f.get(); + final Array a = in.getColumn(i); + final ArrayCompressionStatistics s = stats[i]; + if(s.valueType != null && s.valueType != a.getValueType()) { + + final int nRow = in.getNumRows(); + final int block = Math.max(((nRow / k) / 64) * 64, 1024); + + final List> t = new ArrayList<>(); + for(int r = 0; r < nRow; r += block) { + + final int start = r; + final int end = Math.min(r + block, nRow); + t.add(pool.submit(() -> (a.changeTypeWithNulls(tmp, start, end)))); } + + return pool.submit(() -> { + try { + for(Future tt : t) + tt.get(); + return tmp; + } + catch(Exception e) { + + throw new RuntimeException(e); + } + }); } + else + return pool.submit(() -> tmp); + } + + catch(Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private void compressCol(int i, final ArrayCompressionStatistics s) { + final Array b = in.getColumn(i); + final Array a; + if(s.valueType != null && s.valueType != b.getValueType()) + a = b.changeType(s.valueType, s.containsNull); else - compressedColumns[i] = in.getColumn(i); + a = b; + + compressColFinally(i, a, s); + } + + private void compressColFinally(int i, final Array a, final ArrayCompressionStatistics s) { + + if(s.bestType != null) { + switch(s.bestType) { + case DDC: + compressedColumns[i] = DDCArray.compressToDDC(a, s.containsNull); + break; + default: + LOG.error("Unsupported encoding default to do nothing: " + s.bestType); + compressedColumns[i] = a; + break; + } + } + else + compressedColumns[i] = a; } private void logStatistics() { @@ -152,6 +244,8 @@ private void logRet(FrameBlock ret) { if(LOG.isDebugEnabled()) { final long before = in.getInMemorySize(); final long after = ret.getInMemorySize(); + LOG.debug(String.format("nRows %15d", in.getNumRows())); + LOG.debug(String.format("SampleSize %15d", nSamples)); LOG.debug(String.format("Uncompressed Size: %15d", before)); LOG.debug(String.format("compressed Size: %15d", after)); LOG.debug(String.format("ratio: %15.3f", (double) before / (double) after)); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java index 84a23bf6480..3c0c0072d07 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java @@ -23,11 +23,11 @@ public class FrameCompressionSettings { - public final float sampleRatio; + public final double sampleRatio; public final int k; public final WTreeRoot wt; - protected FrameCompressionSettings(float sampleRatio, int k, WTreeRoot wt) { + protected FrameCompressionSettings(double sampleRatio, int k, WTreeRoot wt) { this.sampleRatio = sampleRatio; this.k = k; this.wt = wt; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java index 936cd42898d..39c1fca07d6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java @@ -19,16 +19,18 @@ package org.apache.sysds.runtime.frame.data.compress; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.workload.WTreeRoot; public class FrameCompressionSettingsBuilder { - private float sampleRatio; + private double sampleRatio; private int k; private WTreeRoot wt; public FrameCompressionSettingsBuilder() { - this.sampleRatio = 0.1f; + this.sampleRatio = ConfigurationManager.getDMLConfig().getDoubleValue(DMLConfig.COMPRESSED_SAMPLING_RATIO); this.k = 1; this.wt = null; } @@ -43,7 +45,7 @@ public FrameCompressionSettingsBuilder threads(int k) { return this; } - public FrameCompressionSettingsBuilder sampleRatio(float sampleRatio) { + public FrameCompressionSettingsBuilder sampleRatio(double sampleRatio) { this.sampleRatio = sampleRatio; return this; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java index 0c8ceb9d874..4b703224a5b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -38,10 +39,13 @@ public class FrameLibApplySchema { protected static final Log LOG = LogFactory.getLog(FrameLibApplySchema.class.getName()); + public static int PAR_ROW_THRESHOLD = 1024; + private final FrameBlock fb; private final ValueType[] schema; private final boolean[] nulls; private final int nCol; + private final int nRow; private final Array[] columnsIn; private final Array[] columnsOut; @@ -102,6 +106,7 @@ private FrameLibApplySchema(FrameBlock fb, ValueType[] schema, boolean[] nulls, this.k = k; verifySize(); nCol = fb.getNumColumns(); + nRow = fb.getNumRows(); columnsIn = fb.getColumns(); columnsOut = new Array[nCol]; } @@ -123,14 +128,14 @@ private FrameBlock apply() { final String[] colNames = fb.getColumnNames(false); final ColumnMetadata[] meta = fb.getColumnMetadata(); - FrameBlock out = new FrameBlock(schema, colNames, meta, columnsOut); - if(LOG.isDebugEnabled()){ + FrameBlock out = new FrameBlock(schema, colNames, meta, columnsOut); + if(LOG.isDebugEnabled()) { long inMem = fb.getInMemorySize(); long outMem = out.getInMemorySize(); - LOG.debug(String.format("Schema Apply Input Size: %16d" , inMem)); - LOG.debug(String.format(" Output Size: %16d" , outMem)); - LOG.debug(String.format(" Ratio : %4.3f" , ((double) inMem / outMem))); + LOG.debug(String.format("Schema Apply Input Size: %16d", inMem)); + LOG.debug(String.format(" Output Size: %16d", outMem)); + LOG.debug(String.format(" Ratio : %4.3f", ((double) inMem / outMem))); } return out; } @@ -152,19 +157,42 @@ private void apply(int i) { private void applyMultiThread() { final ExecutorService pool = CommonThreadPool.get(k); try { - List> f = new ArrayList<>(nCol ); - for(int i = 0; i < nCol ; i ++){ - final int j = i; - f.add(pool.submit(() -> apply(j))); + List> f = new ArrayList<>(nCol); + + final int rowThreads = Math.max(1, (k * 2) / nCol); + final int block = Math.max(((nRow / rowThreads) / 64) * 64, PAR_ROW_THRESHOLD); + for(int i = 0; i < nCol; i++) { + final int j = i; // final col variable for task + if(schema[i] == columnsIn[i].getValueType()) { + apply(i); + } + else { + + if(nulls != null && nulls[i]) { + columnsOut[j] = ArrayFactory.allocateOptional(schema[i], nRow); + for(int r = 0; r < nRow; r += block) { + final int start = r; + final int end = Math.min(nRow, r + block); + f.add(pool.submit(() -> columnsIn[j].changeTypeWithNulls(columnsOut[j], start, end))); + } + } + else { + columnsOut[j] = ArrayFactory.allocate(schema[i], nRow); + for(int r = 0; r < nRow; r += block) { + final int start = r; + final int end = Math.min(nRow, r + block); + f.add(pool.submit(() -> columnsIn[j].changeType(columnsOut[j], start, end))); + } + } // + } } - - for( Future e : f) + for(Future e : f) e.get(); } catch(InterruptedException | ExecutionException e) { throw new DMLRuntimeException("Failed to combine column groups", e); } - finally{ + finally { pool.shutdown(); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java index 71e3788a1c2..d225a0ed6cc 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.estim.ComEstFactory; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.matrix.data.Pair; @@ -39,7 +40,8 @@ public final class FrameLibDetectSchema { protected static final Log LOG = LogFactory.getLog(FrameLibDetectSchema.class.getName()); /** Default minium sample size */ - private static final int DEFAULT_MIN_CELLS = 100000; + private static final int DEFAULT_MIN_CELLS = 10000; + private static final int DEFAULT_MAX_CELLS = 1000000; /** Frame block to sample from */ private final FrameBlock in; @@ -52,7 +54,8 @@ private FrameLibDetectSchema(FrameBlock in, double sampleFraction, int k) { this.in = in; this.k = k; final int inRows = in.getNumRows(); - this.sampleSize = Math.min(inRows, Math.max((int) (inRows * sampleFraction), DEFAULT_MIN_CELLS)); + this.sampleSize = Math.min(Math.max((int) (inRows * sampleFraction), DEFAULT_MIN_CELLS), + ComEstFactory.getSampleSize(0.65, inRows, in.getNumColumns(), 1.0, DEFAULT_MIN_CELLS, DEFAULT_MAX_CELLS)); } public static FrameBlock detectSchema(FrameBlock in, int k) { @@ -85,6 +88,7 @@ private String[] parallelApply() { final ExecutorService pool = CommonThreadPool.get(k); try { final int cols = in.getNumColumns(); + LOG.error(sampleSize + " " + in.getNumRows()); final ArrayList tasks = new ArrayList<>(cols); for(int i = 0; i < cols; i++) tasks.add(new DetectValueTypeTask(in.getColumn(i), sampleSize)); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 06e68a63d51..f331406e1c9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -39,6 +39,7 @@ import org.apache.sysds.lops.WeightedUnaryMM; import org.apache.sysds.lops.WeightedUnaryMMR; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction; import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction; @@ -194,6 +195,7 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "freplicate", SPType.Binary); String2SPInstructionType.put( "mapdropInvalidLength", SPType.Binary); String2SPInstructionType.put( "valueSwap", SPType.Binary); + String2SPInstructionType.put( "applySchema" , SPType.Binary); String2SPInstructionType.put( "_map", SPType.Ternary); // _map refers to the operation map // Relational Instruction Opcodes String2SPInstructionType.put( "==" , SPType.Binary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java index cff0650235e..2ec23037385 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java @@ -93,8 +93,10 @@ public void processInstruction(ExecutionContext ec) { // Release the memory occupied by input matrices ec.releaseMatrixInput(input1.getName(), input2.getName()); // Ensure right dense/sparse output representation (guarded by released input memory) - if(checkGuardedRepresentationChange(inBlock1, inBlock2, retBlock)) - retBlock.examSparsity(); + if(checkGuardedRepresentationChange(inBlock1, inBlock2, retBlock)){ + int k = (_optr instanceof BinaryOperator) ? ((BinaryOperator) _optr).getNumThreads() : 1; + retBlock.examSparsity(k); + } } // Attach result matrix with MatrixObject associated with output_name diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java index da7a4adec0e..4ca859a88e0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java @@ -22,7 +22,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; -import org.apache.sysds.runtime.compress.lib.CLALibAppend; +import org.apache.sysds.runtime.compress.lib.CLALibCBind; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.lineage.LineageItem; @@ -46,8 +46,8 @@ public void processInstruction(ExecutionContext ec) { validateInput(matBlock1, matBlock2); MatrixBlock ret; - if(matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock) - ret = CLALibAppend.append(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism()); + if(_type == AppendType.CBIND && (matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock) ) + ret = CLALibCBind.cbind(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism()); else ret = matBlock1.append(matBlock2, new MatrixBlock(), _type == AppendType.CBIND); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java index 6f6232e71af..f283f9f7f41 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.instructions.spark; +import org.apache.commons.lang.NotImplementedException; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.broadcast.Broadcast; @@ -45,6 +46,7 @@ public void processInstruction(ExecutionContext ec) { JavaPairRDD in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName()); JavaPairRDD out = null; + LOG.error(getOpcode()); if(getOpcode().equals("dropInvalidType")) { // get schema frame-block Broadcast fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName())); @@ -59,6 +61,11 @@ else if(getOpcode().equals("valueSwap")) { // Attach result frame with FrameBlock associated with output_name sec.releaseFrameInput(input2.getName()); } + else if(getOpcode().equals("applySchema")){ + Broadcast fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName())); + out = in1.mapValues(new applySchema(fb.getValue())); + sec.releaseFrameInput(input2.getName()); + } else { JavaPairRDD in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName()); // create output frame @@ -70,7 +77,9 @@ else if(getOpcode().equals("valueSwap")) { //set output RDD and maintain dependencies sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); - if( !getOpcode().equals("dropInvalidType") && !getOpcode().equals("valueSwap")) + if(!getOpcode().equals("dropInvalidType") && // + !getOpcode().equals("valueSwap") && // + !getOpcode().equals("applySchema")) sec.addLineageRDD(output.getName(), input2.getName()); } @@ -116,4 +125,20 @@ public FrameBlock call(FrameBlock arg0) throws Exception { return arg0.valueSwap(schema_frame); } } + + + private static class applySchema implements Function{ + private static final long serialVersionUID = 58504021316402L; + + private FrameBlock schema; + + public applySchema(FrameBlock schema ) { + this.schema = schema; + } + + @Override + public FrameBlock call(FrameBlock arg0) throws Exception { + return arg0.applySchema(schema); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/WriteSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/WriteSPInstruction.java index bb97b9a4ca2..06e972c91d5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/WriteSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/WriteSPInstruction.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Random; +import org.apache.commons.lang.NotImplementedException; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.io.LongWritable; @@ -352,16 +353,13 @@ private static void customSaveTextFile(JavaRDD rdd, String fname, boolea } rdd.saveAsTextFile(randFName); - HDFSTool.mergeIntoSingleFile(randFName, fname); // Faster version :) - - // rdd.coalesce(1, true).saveAsTextFile(randFName); - // MapReduceTool.copyFileOnHDFS(randFName + "/part-00000", fname); + HDFSTool.mergeIntoSingleFile(randFName, fname); } catch (IOException e) { throw new DMLRuntimeException("Cannot merge the output into single file: " + e.getMessage()); } finally { try { - // This is to make sure that we donot create random files on HDFS + // This is to make sure that we do not create random files on HDFS HDFSTool.deleteFileIfExistOnHDFS( randFName ); } catch (IOException e) { throw new DMLRuntimeException("Cannot merge the output into single file: " + e.getMessage()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java index 9371d43094c..a5974640cc5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java @@ -90,10 +90,7 @@ public static JavaPairRDD csvToBinaryBlock(JavaSparkContext sc JavaRDD tmp = input.values() .map(new TextToStringFunction()); String tmpStr = tmp.first(); - boolean metaHeader = tmpStr.startsWith(TfUtils.TXMTD_MVPREFIX) - || tmpStr.startsWith(TfUtils.TXMTD_NDPREFIX); - tmpStr = (metaHeader) ? tmpStr.substring(tmpStr.indexOf(delim)+1) : tmpStr; - long rlen = tmp.count() - (hasHeader ? 1 : 0) - (metaHeader ? 2 : 0); + long rlen = tmp.count() ; long clen = IOUtilFunctions.splitCSV(tmpStr, delim).length; mc.set(rlen, clen, mc.getBlocksize(), -1); } @@ -582,14 +579,14 @@ public Iterator> call(Iterator> arg0) _colnames = row.split(_delim); continue; } - if( row.startsWith(TfUtils.TXMTD_MVPREFIX) ) { - _mvMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); - continue; - } - else if( row.startsWith(TfUtils.TXMTD_NDPREFIX) ) { - _ndMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); - continue; - } + // if( row.startsWith(TfUtils.TXMTD_MVPREFIX) ) { + // _mvMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); + // continue; + // } + // else if( row.startsWith(TfUtils.TXMTD_NDPREFIX) ) { + // _ndMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); + // continue; + // } //adjust row index for header and meta data rowix += (_hasHeader ? 0 : 1) - ((_mvMeta == null) ? 0 : 2); @@ -670,18 +667,18 @@ public Iterator call(Tuple2 arg0) ret.add(sb.toString()); sb.setLength(0); //reset } - if( !blk.isColumnMetadataDefault() ) { - sb.append(TfUtils.TXMTD_MVPREFIX + _props.getDelim()); - for( int j=0; j a = (DDCArray) ret.getColumn(colId); + ret.setColumn(colId, a.setDict(value._a)); + } + } + finally{ + IOUtilFunctions.closeSilently(reader); + } + } + } + catch(IOException e){ + throw new DMLRuntimeException("Failed to read Frame Dictionaries", e); + } + } + /** * Specific functionality of FrameReaderBinaryBlock, mostly used for testing. * @@ -143,4 +171,7 @@ public FrameBlock readFirstBlock(String fname) throws IOException { return value; } + + + } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java index cfe4a5e45ba..9b5faaec140 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java @@ -21,7 +21,11 @@ import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -36,10 +40,12 @@ import org.apache.hadoop.mapred.TextInputFormat; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.matrix.data.Pair; -import org.apache.sysds.runtime.transform.TfUtils; +import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.InputStreamInputFormat; /** @@ -131,7 +137,6 @@ protected final int readCSVFrameFromInputSplit(InputSplit split, InputFormat 1 ? CommonThreadPool.get(k) : null; + List> tasks = new ArrayList<>(); + // Read the data try { - String[] parts = null; // cache array for line reading. + Array[] destA = dest.getColumns(); while(reader.next(key, value)) // foreach line { - boolean emptyValuesFound = false; String cellStr = IOUtilFunctions.trim(value.toString()); - parts = IOUtilFunctions.splitCSV(cellStr, delim, parts); - // sanity checks for empty values and number of columns - - final boolean mtdP = parts[0].equals(TfUtils.TXMTD_MVPREFIX); - final boolean mtdx = parts[0].equals(TfUtils.TXMTD_NDPREFIX); - // parse frame meta data (missing values / num distinct) - if(mtdP || mtdx) { - if(parts.length != dest.getNumColumns() + 1){ - LOG.warn("Invalid metadata "); - parts = null; - continue; - } - else if(mtdP) - for(int j = 0; j < dest.getNumColumns(); j++) - dest.getColumnMetadata(j).setMvValue(parts[j + 1]); - else if(mtdx) - for(int j = 0; j < dest.getNumColumns(); j++) - dest.getColumnMetadata(j).setNumDistinct(Long.parseLong(parts[j + 1])); - parts = null; - continue; + if(pool != null){ + final int r = row; + tasks.add(pool.submit( () -> + parseLine(cellStr, delim, destA, r, (int) clen, dfillValue, sfillValue, isFill, naValues))); + } + else{ + parseLine(cellStr, delim, destA, row, (int) clen, dfillValue, sfillValue, isFill, naValues); } - assignColumns(row, nCol, dest, parts, naValues, isFill, dfillValue, sfillValue); - IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, isFill, emptyValuesFound); - IOUtilFunctions.checkAndRaiseErrorCSVNumColumns("", cellStr, parts, clen); row++; } } @@ -178,38 +173,55 @@ else if(mtdx) throw new DMLRuntimeException("Failed parsing string: \"" + value +"\"", e); } finally { + if(pool != null) + pool.shutdown(); IOUtilFunctions.closeSilently(reader); } return row; } - private boolean assignColumns(int row, int nCol, FrameBlock dest, String[] parts, Set naValues, + private void parseLine(String cellStr, String delim, Array[] destA , int row, + int clen, double dfillValue, String sfillValue, boolean isFill, + Set naValues) { + try{ + String[] parts = IOUtilFunctions.splitCSV(cellStr, delim, clen); + + assignColumns(row, (int)clen, destA, parts, naValues, isFill, dfillValue, sfillValue); + + IOUtilFunctions.checkAndRaiseErrorCSVNumColumns("", cellStr, parts, clen); + } + catch(Exception e){ + throw new RuntimeException(e); + } + } + + private boolean assignColumns(int row, int nCol, Array[] destA, String[] parts, Set naValues, boolean isFill, double dfillValue, String sfillValue) { if(!isFill && naValues == null) - return assignColumnsNoFillNoNan(row, nCol, dest, parts); + return assignColumnsNoFillNoNan(row, nCol, destA, parts); else - return assignColumnsGeneric(row, nCol, dest, parts, naValues, isFill, dfillValue, sfillValue); + return assignColumnsGeneric(row, nCol, destA, parts, naValues, isFill, dfillValue, sfillValue); } - private boolean assignColumnsGeneric(int row, int nCol, FrameBlock dest, String[] parts, Set naValues, + private boolean assignColumnsGeneric(int row, int nCol, Array[] destA, String[] parts, Set naValues, boolean isFill, double dfillValue, String sfillValue) { boolean emptyValuesFound = false; for(int col = 0; col < nCol; col++) { String part = IOUtilFunctions.trim(parts[col]); if(part.isEmpty() || (naValues != null && naValues.contains(part))) { if(isFill && dfillValue != 0) - dest.set(row, col, sfillValue); + destA[col].set(row, sfillValue); emptyValuesFound = true; } else - dest.set(row, col, part); + destA[col].set(row, part); } return emptyValuesFound; } - private boolean assignColumnsNoFillNoNan(int row, int nCol, FrameBlock dest, String[] parts){ + private boolean assignColumnsNoFillNoNan(int row, int nCol, Array[] destA, String[] parts){ boolean emptyValuesFound = false; for(int col = 0; col < nCol; col++) { @@ -217,7 +229,7 @@ private boolean assignColumnsNoFillNoNan(int row, int nCol, FrameBlock dest, Str if(part.isEmpty()) emptyValuesFound = true; else - dest.set(row, col, part); + destA[col].set(row, part); } return emptyValuesFound; @@ -255,32 +267,18 @@ protected static int countLinesInReader(InputSplit split, TextInputFormat inForm } } - protected static int countLinesInReader(RecordReader reader, long ncol, boolean header) + private static int countLinesInReader(RecordReader reader, long ncol, boolean header) throws IOException { final LongWritable key = new LongWritable(); final Text value = new Text(); int nrow = 0; - try { - // ignore header of first split - if(header) - reader.next(key, value); - while(reader.next(key, value)) { - // note the metadata can be located at any row when spark. - nrow += containsMetaTag(value) ? 0 : 1; - } - return nrow; + // ignore header of first split + if(header) + reader.next(key, value); + while(reader.next(key, value)) { + nrow ++; } - finally { - IOUtilFunctions.closeSilently(reader); - } - } - - private final static boolean containsMetaTag(Text val) { - if(val.charAt(0) == '#') - return val.find(TfUtils.TXMTD_MVPREFIX) > -1// - || val.find(TfUtils.TXMTD_NDPREFIX) > -1; - else - return false; + return nrow; } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java index b2c9538f8de..d04ff1cc037 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java @@ -144,7 +144,6 @@ public CountRowsTask(InputSplit split, TextInputFormat informat, JobConf job, bo @Override public Integer call() throws Exception { return countLinesInReader(_split, _informat, _job, _nCol, _hasHeader); - } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java index 859cbe028c2..b72661ba3ba 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java @@ -20,6 +20,8 @@ package org.apache.sysds.runtime.io; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -29,6 +31,10 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayWrapper; +import org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.HDFSTool; /** @@ -43,30 +49,67 @@ public final void writeFrameToHDFS(FrameBlock src, String fname, long rlen, long // prepare file access JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(fname); - + // if the file already exists on HDFS, remove it. HDFSTool.deleteFileIfExistOnHDFS(fname); - + HDFSTool.deleteFileIfExistOnHDFS(fname + ".dict"); + // bound check for src block if(src.getNumRows() > rlen || src.getNumColumns() > clen) { throw new IOException("Frame block [1:" + src.getNumRows() + ",1:" + src.getNumColumns() + "] " + "out of overall frame range [1:" + rlen + ",1:" + clen + "]."); } + Pair>>, FrameBlock> prep = extractDictionaries(src); + src = prep.getValue(); + // write binary block to hdfs (sequential/parallel) - writeBinaryBlockFrameToHDFS(path, job, src, rlen, clen); + writeBinaryBlockFrameToHDFS(path, job, prep.getValue(), rlen, clen); + + if(prep.getKey().size() > 0) + writeBinaryBlockDictsToSequenceFile(new Path(fname + ".dict"), job, prep.getKey()); + + } + + protected Pair>>, FrameBlock> extractDictionaries(FrameBlock src){ + List>> dicts = new ArrayList<>(); + int blen = ConfigurationManager.getBlocksize(); + if(src.getNumRows() < blen ) + return new Pair<>(dicts, src); + boolean modified = false; + for(int i = 0; i < src.getNumColumns(); i++){ + Array a = src.getColumn(i); + if(a instanceof DDCArray){ + DDCArray d = (DDCArray)a; + dicts.add(new Pair<>(i, d.getDict())); + if(modified == false){ + modified = true; + // make sure other users of this frame does not get effected + src = src.copyShallow(); + } + src.setColumn(i, d.nullDict()); + } + } + return new Pair<>(dicts, src); } protected void writeBinaryBlockFrameToHDFS(Path path, JobConf job, FrameBlock src, long rlen, long clen) throws IOException, DMLRuntimeException { FileSystem fs = IOUtilFunctions.getFileSystem(path); int blen = ConfigurationManager.getBlocksize(); - + // sequential write to single file writeBinaryBlockFrameToSequenceFile(path, job, fs, src, blen, 0, (int) rlen); IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); } + protected void writeBinaryBlockDictsToSequenceFile(Path path, JobConf job, List>> dicts) + throws IOException, DMLRuntimeException { + FileSystem fs = IOUtilFunctions.getFileSystem(path); + writeBinaryBlockDictsToSequenceFile(path, job, fs, dicts); + IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); + } + /** * Internal primitive to write a block-aligned row range of a frame to a single sequence file, which is used for both * single- and multi-threaded writers (for consistency). @@ -111,4 +154,20 @@ protected static void writeBinaryBlockFrameToSequenceFile(Path path, JobConf job IOUtilFunctions.closeSilently(writer); } } + + protected static void writeBinaryBlockDictsToSequenceFile(Path path, JobConf job, FileSystem fs, List>> dicts) throws IOException{ + final Writer writer = IOUtilFunctions.getSeqWriterArray(path, job, 1); + try{ + LongWritable index = new LongWritable(); + + for(int i = 0; i < dicts.size(); i++){ + Pair> p = dicts.get(i); + index.set(p.getKey()); + writer.append(index, new ArrayWrapper(p.getValue())); + } + } + finally { + IOUtilFunctions.closeSilently(writer); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java index 82c5a08e2c0..2e4c3d5ac3f 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java @@ -19,14 +19,13 @@ package org.apache.sysds.runtime.io; -import java.io.IOException; +import java.util.List; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; +import org.apache.sysds.runtime.matrix.data.Pair; public class FrameWriterCompressed extends FrameWriterBinaryBlockParallel { @@ -37,11 +36,10 @@ public FrameWriterCompressed(boolean parallel) { } @Override - protected void writeBinaryBlockFrameToHDFS(Path path, JobConf job, FrameBlock src, long rlen, long clen) - throws IOException, DMLRuntimeException { + protected Pair>>, FrameBlock> extractDictionaries(FrameBlock src) { int k = parallel ? OptimizerUtils.getParallelBinaryWriteParallelism() : 1; FrameBlock compressed = FrameLibCompress.compress(src, k); - super.writeBinaryBlockFrameToHDFS(path, job, compressed, rlen, clen); + return super.extractDictionaries(compressed); } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java index 5815ff231ea..f14cdf7ae28 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java @@ -31,7 +31,6 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; -import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.HDFSTool; /** @@ -107,17 +106,7 @@ protected static void writeCSVFrameToFile( Path path, JobConf job, FileSystem fs } sb.append('\n'); } - //append meta data - if( !src.isColumnMetadataDefault() ) { - sb.append(TfUtils.TXMTD_MVPREFIX + delim); - for( int j=0; j 0 ? replication : 1))); } + public static Writer getSeqWriterArray(Path path, Configuration job, int replication) throws IOException { + return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), + Writer.keyClass(LongWritable.class), Writer.valueClass(ArrayWrapper.class), + Writer.compression(getCompressionEncodingType(), getCompressionCodec()), + Writer.replication((short) (replication > 0 ? replication : 1))); + } + public static Writer getSeqWriterTensor(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.replication((short) (replication > 0 ? replication : 1)), diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java new file mode 100644 index 00000000000..92f1689690a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + +public class LibAggregateUnarySpecialization { + protected static final Log LOG = LogFactory.getLog(LibAggregateUnarySpecialization.class.getName()); + + public static void aggregateUnary(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + LOG.error(op); + if(op.sparseSafe) + sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + else + denseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + } + + private static void sparseAggregateUnaryHelp(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, + int blen, MatrixIndexes indexesIn) { + // initialize result + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + + if(mb.sparse && mb.sparseBlock != null) { + SparseBlock a = mb.sparseBlock; + for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) { + if(a.isEmpty(r)) + continue; + int apos = a.pos(r); + int alen = a.size(r); + int[] aix = a.indexes(r); + double[] aval = a.values(r); + for(int i = apos; i < apos + alen; i++) { + tempCellIndex.set(r, aix[i]); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, aval[i], + buffer); + } + } + } + else if(!mb.sparse && mb.denseBlock != null) { + DenseBlock a = mb.getDenseBlock(); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, a.get(i, j), + buffer); + } + } + } + + private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, + mb.quickGetValue(i, j), buffer); + } + } + + private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, + double newvalue, KahanObject buffer) { + if(aggOp.existsCorrection()) { + if(aggOp.correction == CorrectionLocationType.LASTROW || + aggOp.correction == CorrectionLocationType.LASTCOLUMN) { + int corRow = row, corCol = column; + if(aggOp.correction == CorrectionLocationType.LASTROW)// extra row + corRow++; + else if(aggOp.correction == CorrectionLocationType.LASTCOLUMN) + corCol++; + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + + buffer._sum = result.quickGetValue(row, column); + buffer._correction = result.quickGetValue(corRow, corCol); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); + result.quickSetValue(row, column, buffer._sum); + result.quickSetValue(corRow, corCol, buffer._correction); + } + else if(aggOp.correction == CorrectionLocationType.NONE) { + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + } + else// for mean + { + int corRow = row, corCol = column; + int countRow = row, countCol = column; + if(aggOp.correction == CorrectionLocationType.LASTTWOROWS) { + countRow++; + corRow += 2; + } + else if(aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS) { + countCol++; + corCol += 2; + } + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + buffer._sum = result.quickGetValue(row, column); + buffer._correction = result.quickGetValue(corRow, corCol); + double count = result.quickGetValue(countRow, countCol) + 1.0; + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); + result.quickSetValue(row, column, buffer._sum); + result.quickSetValue(corRow, corCol, buffer._correction); + result.quickSetValue(countRow, countCol, count); + } + + } + else { + newvalue = aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue); + result.quickSetValue(row, column, newvalue); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java index e5ec7a00209..a2156f9001b 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java @@ -145,127 +145,131 @@ public static MatrixBlock uncellOp(MatrixBlock m1, MatrixBlock ret, UnaryOperato return ret; } - /** - * matrix-scalar, scalar-matrix binary operations. - * - * @param m1 input matrix - * @param ret result matrix - * @param op scalar operator - */ - public static void bincellOp(MatrixBlock m1, MatrixBlock ret, ScalarOperator op) { + public static MatrixBlock bincellOpScalar(MatrixBlock m1, MatrixBlock ret, ScalarOperator op, int k) { + // estimate the sparsity structure of result matrix + boolean sp = m1.sparse; // by default, we guess result.sparsity=input.sparsity + if (!op.sparseSafe) + sp = false; // if the operation is not sparse safe, then result will be in dense format + + //allocate the output matrix block + if( ret==null ) + ret = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), sp, m1.nonZeros); + else + ret.reset(m1.getNumRows(), m1.getNumColumns(), sp, m1.nonZeros); + //check internal assumptions if( (op.sparseSafe && m1.isInSparseFormat()!=ret.isInSparseFormat()) ||(!op.sparseSafe && ret.isInSparseFormat()) ) { - throw new DMLRuntimeException("Wrong output representation for safe="+op.sparseSafe+": "+m1.isInSparseFormat()+", "+ret.isInSparseFormat()); + throw new DMLRuntimeException("Wrong output representation for safe=" + op.sparseSafe + ": " + + m1.isInSparseFormat() + ", " + ret.isInSparseFormat()); } + + if((op.fn instanceof Multiply && op.getConstant() == 0.0)) + return ret; // no op + // fallback to singlet-threaded for special cases + if( k <= 1 || m1.isEmpty() || !op.sparseSafe + || ret.getLength() < PAR_NUMCELL_THRESHOLD2 ) { + bincellOpScalarSingleThread(m1, ret, op); + } + else{ + bincellOpScalarParallel(m1, ret, op, k); + } + // ensure empty results sparse representation + // (no additional memory requirements) + if(ret.isEmptyBlock(false)) + ret.examSparsity(k); + return ret; + } + + private static void bincellOpScalarSingleThread(MatrixBlock m1, MatrixBlock ret, ScalarOperator op) { //execute binary cell operations + long nnz = 0; if(op.sparseSafe) - safeBinaryScalar(m1, ret, op, 0, m1.rlen); + nnz = safeBinaryScalar(m1, ret, op, 0, m1.rlen); else - unsafeBinaryScalar(m1, ret, op); + nnz = unsafeBinaryScalar(m1, ret, op); + + ret.nonZeros = nnz; //ensure empty results sparse representation //(no additional memory requirements) if( ret.isEmptyBlock(false) ) ret.examSparsity(); + } - public static void bincellOp(MatrixBlock m1, MatrixBlock ret, ScalarOperator op, int k) { - //check internal assumptions - if( (op.sparseSafe && m1.isInSparseFormat()!=ret.isInSparseFormat()) - ||(!op.sparseSafe && ret.isInSparseFormat()) ) { - throw new DMLRuntimeException("Wrong output representation for safe="+op.sparseSafe+": "+m1.isInSparseFormat()+", "+ret.isInSparseFormat()); - } - - //fallback to singlet-threaded for special cases - if( m1.isEmpty() || !op.sparseSafe - || ret.getLength() < PAR_NUMCELL_THRESHOLD2 ) { - bincellOp(m1, ret, op); - return; - } - - //preallocate dense/sparse block for multi-threaded operations + private static void bincellOpScalarParallel(MatrixBlock m1, MatrixBlock ret, ScalarOperator op, int k) { + + // preallocate dense/sparse block for multi-threaded operations ret.allocateBlock(); - + + final ExecutorService pool = CommonThreadPool.get(k); try { - //execute binary cell operations - ExecutorService pool = CommonThreadPool.get(k); - ArrayList tasks = new ArrayList<>(); - ArrayList blklens = UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false); - for( int i=0, lb=0; i> taskret = pool.invokeAll(tasks); - - //aggregate non-zeros - ret.nonZeros = 0; //reset after execute - for( Future task : taskret ) - ret.nonZeros += task.get(); - pool.shutdown(); + // execute binary cell operations + final ArrayList tasks = new ArrayList<>(); + + // ArrayList blklens = UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false); + final int rMax = m1.getNumRows(); + final int blkLen = Math.max(Math.max(rMax / k, 1000 / ret.getNumColumns()), 1); + for(int i = 0; i < rMax; i += blkLen) + tasks.add(new BincellScalarTask(m1, ret, op, i, Math.min(rMax, i + blkLen))); + + // aggregate non-zeros + long nnz = 0; + for(Future task : pool.invokeAll(tasks)) + nnz += task.get(); + ret.nonZeros = nnz; } catch(InterruptedException | ExecutionException ex) { throw new DMLRuntimeException(ex); } - - //ensure empty results sparse representation - //(no additional memory requirements) - if( ret.isEmptyBlock(false) ) - ret.examSparsity(); - } - - /** - * matrix-matrix binary operations, MM, MV - * - * @param m1 input matrix 1 - * @param m2 input matrix 2 - * @param ret result matrix - * @param op binary operator - */ - public static void bincellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { - BinaryAccessType atype = getBinaryAccessType(m1, m2); - - // preallocate for consistency (but be careful - // not to allocate if empty inputs might allow early abort) - if( atype == BinaryAccessType.MATRIX_MATRIX - && !(m1.isEmpty() || m2.isEmpty()) ) - { - ret.allocateBlock(); //chosen outside + finally { + pool.shutdown(); } - //execute binary cell operations - long nnz = 0; - if(op.sparseSafe || isSparseSafeDivide(op, m2)) - nnz = safeBinary(m1, m2, ret, op, atype, 0, m1.rlen); - else - nnz = unsafeBinary(m1, m2, ret, op, 0, m1.rlen); - ret.setNonZeros(nnz); - - //ensure empty results sparse representation - //(no additional memory requirements) - if( ret.isEmptyBlock(false) ) - ret.examSparsity(); } - + public static void bincellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int k) { - BinaryAccessType atype = getBinaryAccessType(m1, m2); + // Timing time = new Timing(true); + final BinaryAccessType atype = getBinaryAccessType(m1, m2); + if(!ret.sparse && op.fn instanceof Divide){ + double s1 = m1.getSparsity(); + double s2 = m2.getSparsity(); + if(s1 < 0.4 && s2 > 0.99){ + ret.sparse = true; + } + } + // fallback to sequential computation for specialized operations - if( m1.isEmpty() || m2.isEmpty() + if(k <= 1 || m1.isEmpty() || m2.isEmpty() || ret.getLength() < PAR_NUMCELL_THRESHOLD2 || ((op.sparseSafe || isSparseSafeDivide(op, m2)) && !(atype == BinaryAccessType.MATRIX_MATRIX || atype.isMatrixVector() && isAllDense(m1, m2, ret)))) { - bincellOp(m1, m2, ret, op); - return; + bincellOpMatrixSingle(m1, m2, ret, op,atype); + } + else { + bincellOpMatrixParallel(m1, m2, ret, op, atype, k); } + if(ret.isEmptyBlock(false) ) + ret.examSparsity(k); + + // System.out.println("BinCell " + op + " " + m1.getNumRows() + ", " + m1.getNumColumns() + ", " + m1.getNonZeros() + // + " -- " + m2.getNumRows() + ", " + m2.getNumColumns() + " " + m2.getNonZeros() + "\t\t" + time.stop()); + } + + private static void bincellOpMatrixParallel(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, BinaryAccessType atype, int k) { //preallocate dense/sparse block for multi-threaded operations + ret.allocateBlock(); //chosen outside + final ExecutorService pool = CommonThreadPool.get(k); try { //execute binary cell operations - ExecutorService pool = CommonThreadPool.get(k); ArrayList tasks = new ArrayList<>(); ArrayList blklens = UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false); for( int i=0, lb=0; i> taskret = pool.invokeAll(tasks); //aggregate non-zeros - ret.nonZeros = 0; //reset after execute + long nnz = 0; //reset after execute for( Future task : taskret ) - ret.nonZeros += task.get(); - pool.shutdown(); + nnz += task.get(); + + ret.nonZeros = nnz; } catch(InterruptedException | ExecutionException ex) { throw new DMLRuntimeException(ex); } - - //ensure empty results sparse representation - //(no additional memory requirements) - if( ret.isEmptyBlock(false) ) - ret.examSparsity(); + finally{ + pool.shutdown(); + } } + + private static void bincellOpMatrixSingle(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, BinaryAccessType atype) { + // preallocate for consistency (but be careful + // not to allocate if empty inputs might allow early abort) + if(atype == BinaryAccessType.MATRIX_MATRIX && !(m1.isEmpty() || m2.isEmpty())) { + ret.allocateBlock(); // chosen outside + } + // execute binary cell operations + long nnz = 0; + if(op.sparseSafe || isSparseSafeDivide(op, m2)) + nnz = safeBinary(m1, m2, ret, op, atype, 0, m1.rlen); + else + nnz = unsafeBinary(m1, m2, ret, op, 0, m1.rlen); + ret.setNonZeros(nnz); + } + /** * NOTE: operations in place always require m1 and m2 to be of equal dimensions @@ -348,7 +367,7 @@ public static MatrixBlock bincellOpInPlaceLeft(MatrixBlock m1ret, MatrixBlock m2 MatrixBlock right = new MatrixBlock(nRows, nCols, true); right.copyShallow(m1ret); m1ret.cleanupBlock(true, true); - bincellOp(m2, right, m1ret, op); + bincellOp(m2, right, m1ret, op, 1); return m1ret; } @@ -596,7 +615,6 @@ private static long safeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryAccessType atype, int rl, int ru) { //NOTE: multi-threaded over rl-ru only applied for matrix-matrix, non-empty - boolean skipEmpty = (op.fn instanceof Multiply || isSparseSafeDivide(op, m2) ); boolean copyLeftRightEmpty = (op.fn instanceof Plus || op.fn instanceof Minus @@ -620,7 +638,7 @@ else if( m1.sparse && !m2.sparse && !ret.sparse && atype == BinaryAccessType.MATRIX_ROW_VECTOR) safeBinaryMVSparseDenseRow(m1, m2, ret, op); else if( m1.sparse ) //SPARSE m1 - safeBinaryMVSparse(m1, m2, ret, op); + safeBinaryMVSparseLeft(m1, m2, ret, op); else if( !m1.sparse && !m2.sparse && ret.sparse && op.fn instanceof Multiply && atype == BinaryAccessType.MATRIX_COL_VECTOR && (long)m1.rlen * m2.clen < Integer.MAX_VALUE ) @@ -668,11 +686,9 @@ else if( skipEmpty && (m1.sparse || m2.sparse) ) { private static long safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) { - final boolean isMultiply = (op.fn instanceof Multiply); - final boolean skipEmpty = (isMultiply); - // early abort on skip and empy - if(skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))) + // early abort on skip and empty + if(op.fn instanceof Multiply && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))) return 0; // skip entire empty block // guard for postponed allocation in single-threaded exec @@ -689,73 +705,120 @@ private static long safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, MatrixBloc private static long safeBinaryMVDenseColVector(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) { - final boolean multiply = (op.fn instanceof Multiply); final int clen = m1.clen; - final DenseBlock da = m1.getDenseBlock(); - if(da.values(0) == null) - throw new RuntimeException("Invalid input with empty input"); final DenseBlock dc = ret.getDenseBlock(); + final double[] b = m2.getDenseBlockValues(); + + if(op.fn instanceof Multiply) + return safeBinaryMVDenseColVectorMultiply(da, b, dc, clen, rl, ru); + else if(op.fn instanceof Divide) + return safeBinaryMVDenseColVectorDivide(da, b, dc, clen, rl, ru); + else + return safeBinaryMVDenseColVectorGeneric(da, b, dc, clen, op, rl, ru); + } + + private static long safeBinaryMVDenseColVectorGeneric(DenseBlock da, double[] b, DenseBlock dc, int clen, + BinaryOperator op, int rl, int ru) { + if(b == null) + return safeBinaryMVDenseColVectorGenericEmptyVector(da, dc, clen, op, rl, ru); + else + return safeBinaryMVDenseColVectorGenericDenseVector(da, b, dc, clen, op, rl, ru); + } + + private static long safeBinaryMVDenseColVectorGenericEmptyVector(DenseBlock da, DenseBlock dc, int clen, + BinaryOperator op, int rl, int ru) { long nnz = 0; - final double[] b = m2.getDenseBlockValues(); // always single block + for(int i = rl; i < ru; i++) { + final double[] a = da.values(i); + final double[] c = dc.values(i); + final int ix = da.pos(i); + for(int j = 0; j < clen; j++) { + double val = op.fn.execute(a[ix + j], 0); + nnz += ((c[ix + j] = val) != 0) ? 1 : 0; + } + } + return nnz; + } - if(b == null) { - if(multiply) - return 0; - else { - for(int i = rl; i < ru; i++) { - final double[] a = da.values(i); - final double[] c = dc.values(i); - final int ix = da.pos(i); - // GENERAL CASE - for(int j = 0; j < clen; j++) { - double val = op.fn.execute(a[ix + j], 0); - nnz += ((c[ix + j] = val) != 0) ? 1 : 0; - } - } + private static long safeBinaryMVDenseColVectorGenericDenseVector(DenseBlock da, double[] b, DenseBlock dc, int clen, + BinaryOperator op, int rl, int ru) { + long nnz = 0; + for(int i = rl; i < ru; i++) { + final double[] a = da.values(i); + final double[] c = dc.values(i); + final int ix = da.pos(i); + final double v2 = b[i]; + + for(int j = 0; j < clen; j++) { + double val = op.fn.execute(a[ix + j], v2); + nnz += ((c[ix + j] = val) != 0) ? 1 : 0; } } - else if(multiply){ + return nnz; + } + + private static long safeBinaryMVDenseColVectorMultiply(DenseBlock da, double[] b, DenseBlock dc, int clen, int rl, + int ru) { + if(b == null) + return 0; + else { + long nnz = 0; for(int i = rl; i < ru; i++) { final double[] a = da.values(i); final double[] c = dc.values(i); final int ix = da.pos(i); // replicate vector value - double v2 = b[i]; + final double v2 = b[i]; if(v2 == 0) // skip empty rows continue; else if(v2 == 1) { // ROW COPY - // a guaranteed to be non-null (see early abort) - System.arraycopy(a, ix, c, ix, clen); - nnz += m1.recomputeNonZeros(i, i, 0, clen - 1); + for(int j = ix; j < clen + ix; j++) + nnz += (c[j] != 0) ? 1 : 0; } - else { - // GENERAL CASE - for(int j = 0; j < clen; j++) { - double val = op.fn.execute(a[ix + j], v2); - nnz += ((c[ix + j] = val) != 0) ? 1 : 0; - } + else {// GENERAL CASE + for(int j = ix; j < clen + ix; j++) + nnz += ((c[j] = a[j] * v2) != 0) ? 1 : 0; } - } + return nnz; } - else{ + } + + private static long safeBinaryMVDenseColVectorDivide(DenseBlock da, double[] b, DenseBlock dc, int clen, int rl, + int ru) { + + if(b == null){ + dc.fill(Double.NaN); + return (long)dc.getDim(0) * dc.getDim(1); + } + else { + long nnz = 0; for(int i = rl; i < ru; i++) { final double[] a = da.values(i); final double[] c = dc.values(i); final int ix = da.pos(i); - - // replicate vector value - double v2 = b[i]; - - // GENERAL CASE - for(int j = 0; j < clen; j++) { - double val = op.fn.execute(a[ix + j], v2); - nnz += ((c[ix + j] = val) != 0) ? 1 : 0; - } - + final double v2 = b[i]; + processRowMVDenseDivide(a,c, ix, clen, v2); } + return nnz; + } + } + + private static long processRowMVDenseDivide(double[] a, double[] c, int ix, int clen, double v2) { + long nnz = 0; + if(v2 == 0) {// divide by zero. + Arrays.fill(c, ix, clen, Double.NaN); + nnz += clen; + } + else if(v2 == 1) { // ROW COPY + for(int j = ix; j < clen + ix; j++) + nnz += ((c[j] = a[j]) != 0) ? 1 : 0; + } + else { // GENERAL CASE + for(int j = ix; j < clen + ix; j++) + nnz += ((c[j] = a[j] / v2) != 0) ? 1 : 0; } return nnz; } @@ -814,6 +877,10 @@ private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, M //early abort on skip and empty if( skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) ) return; // skip entire empty block + else if( !skipEmpty && m2.isEmptyBlock(false) && (op.fn instanceof Minus || op.fn instanceof Plus)){ + ret.copy(m1); + return; + } //prepare op(0, m2) vector once for all rows double[] tmp = new double[clen]; @@ -848,7 +915,7 @@ private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, M ret.nonZeros = nnz; } - private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + private static void safeBinaryMVSparseLeft(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { boolean isMultiply = (op.fn instanceof Multiply); boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); BinaryAccessType atype = getBinaryAccessType(m1, m2); @@ -858,34 +925,176 @@ private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlo return; // skip entire empty block // allocate once in order to prevent repeated reallocation - if(ret.sparse) - ret.allocateSparseRowsBlock(); + if(!ret.isAllocated()) + ret.allocateBlock(); + if(atype == BinaryAccessType.MATRIX_COL_VECTOR) - safeBinaryMVSparseColVector(m1, m2, ret, op); + safeBinaryMVSparseLeftColVector(m1, m2, ret, op); else if(atype == BinaryAccessType.MATRIX_ROW_VECTOR) - safeBinaryMVSparseRowVector(m1, m2, ret, op); + safeBinaryMVSparseLeftRowVector(m1, m2, ret, op); + + ret.recomputeNonZeros(); + } - private static void safeBinaryMVSparseColVector(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + private static void safeBinaryMVSparseLeftColVector(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + final boolean isMultiply = (op.fn instanceof Multiply); + final boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); + + final int rlen = m1.rlen; + final int clen = m1.clen; + final SparseBlock a = m1.sparseBlock; + final boolean aNull = a == null; + if(skipEmpty && a == null) + return; + if(ret.isInSparseFormat()){ + final SparseBlockMCSR rb = (SparseBlockMCSR) ret.getSparseBlock(); + for(int i = 0; i < rlen; i++) { + final double v2 = m2.quickGetValue(i, 0); + final boolean emptyRow = !aNull ? a.isEmpty(i) : true; + if((skipEmpty && (emptyRow || v2 == 0)) // skip empty one side zero + || (emptyRow && v2 == 0)){ // both sides zero + continue; // skip empty rows + } + final double vz = op.fn.execute(0, v2); + final boolean fill = vz != 0; + + if(isMultiply && v2 == 1) // ROW COPY + ret.appendRow(i, a.get(i)); + else if(!fill) + safeBinaryMVSparseColVectorRowNoFill(a, i, rb, v2, emptyRow, op); + else // GENERAL CASE + safeBinaryMVSparseColVectorRowWithFill(a, i, rb, vz, v2, clen, emptyRow, op); + } + } + else{ + final DenseBlock db = ret.getDenseBlock(); + for(int i = 0; i < rlen; i++) { + final double v2 = m2.quickGetValue(i, 0); + + final boolean emptyRow = !aNull ? a.isEmpty(i) : true; + if((skipEmpty && (emptyRow || v2 == 0)) // skip empty one side zero + || (emptyRow && v2 == 0)){ // both sides zero + continue; // skip empty rows + } + final double vz = op.fn.execute(0, v2); + final boolean fill = vz != 0; + if(isMultiply && v2 == 1) // ROW COPY + ret.appendRow(i, a.get(i)); + else if(!fill) + safeBinaryMVSparseColVectorRowNoFill(a, i, db, v2, emptyRow, op); + else // GENERAL CASE + safeBinaryMVSparseColVectorRowWithFill(a, i, db, vz, v2, clen, emptyRow, op); + + } + } + } + + private static final void safeBinaryMVSparseColVectorRowNoFill(SparseBlock a, int i, SparseBlockMCSR rb, double v2, + boolean emptyRow, BinaryOperator op) { + if(!emptyRow) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + rb.allocate(i, alen); // likely alen allocation + for(int j = apos; j < apos + alen; j++) { + double v = op.fn.execute(avals[j], v2); + rb.append(i, aix[j], v); + } + } + } + + private static final void safeBinaryMVSparseColVectorRowNoFill(SparseBlock a, int i, DenseBlock rb, double v2, + boolean emptyRow, BinaryOperator op) { + if(!emptyRow) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + for(int j = apos; j < apos + alen; j++) { + double v = op.fn.execute(avals[j], v2); + rb.set(i, aix[j], v); + } + } + } + + private static final void safeBinaryMVSparseColVectorRowWithFill(SparseBlock a, int i, SparseBlockMCSR rb, double vz, + double v2, int clen, boolean emptyRow, BinaryOperator op) { + int lastIx = -1; + if(!emptyRow) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + rb.allocate(i, clen); // likely clen allocation + for(int j = apos; j < apos + alen; j++) { + + fillZeroValuesScalar(vz, rb, i, lastIx + 1, aix[j]); + // actual value + double v = op.fn.execute(avals[j], v2); + rb.append(i, aix[j], v); + lastIx = aix[j]; + } + fillZeroValuesScalar(vz, rb, i, lastIx + 1, clen); + } + else{ + rb.allocate(i, clen); + fillZeroValuesScalar(vz, rb, i, lastIx + 1, clen); + } + } + + private static final void safeBinaryMVSparseColVectorRowWithFill(SparseBlock a, int i, DenseBlock rb, double vz, + double v2, int clen, boolean emptyRow, BinaryOperator op) { + int lastIx = -1; + if(!emptyRow) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + for(int j = apos; j < apos + alen; j++) { + + fillZeroValuesScalar(vz, rb, i, lastIx + 1, aix[j]); + // actual value + double v = op.fn.execute(avals[j], v2); + rb.set(i, aix[j], v); + lastIx = aix[j]; + } + fillZeroValuesScalar(vz, rb, i, lastIx + 1, clen); + } + else{ + fillZeroValuesScalar(vz, rb, i, lastIx + 1, clen); + } + } + + private static final void fillZeroValuesScalar( double v, SparseBlock ret, + int rpos, int cpos, int len) { + + for(int k = cpos; k < len; k++) + ret.append(rpos, k, v); + + } + + + private static final void fillZeroValuesScalar( double v, DenseBlock ret, + int rpos, int cpos, int len) { + ret.set(rpos, rpos + 1, cpos, len, v); + } + + private static void safeBinaryMVSparseLeftRowVector(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { boolean isMultiply = (op.fn instanceof Multiply); boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); int rlen = m1.rlen; int clen = m1.clen; SparseBlock a = m1.sparseBlock; - for(int i = 0; i < rlen; i++) { - double v2 = m2.quickGetValue(i, 0); - - if((skipEmpty && (a == null || a.isEmpty(i) || v2 == 0)) || ((a == null || a.isEmpty(i)) && v2 == 0)) { - continue; // skip empty rows - } - - if(isMultiply && v2 == 1) { // ROW COPY - if(a != null && !a.isEmpty(i)) - ret.appendRow(i, a.get(i)); - } - else { // GENERAL CASE + if(ret.isInSparseFormat()){ + for(int i = 0; i < rlen; i++) { + if(skipEmpty && (a == null || a.isEmpty(i))) + continue; // skip empty rows + if(skipEmpty && ret.sparse) + ret.sparseBlock.allocate(i, a.size(i)); int lastIx = -1; if(a != null && !a.isEmpty(i)) { int apos = a.pos(i); @@ -894,65 +1103,61 @@ private static void safeBinaryMVSparseColVector(MatrixBlock m1, MatrixBlock m2, double[] avals = a.values(i); for(int j = apos; j < apos + alen; j++) { // empty left - fillZeroValues(op, v2, ret, skipEmpty, i, lastIx + 1, aix[j]); + fillZeroValues(op, m2, ret, skipEmpty, i, lastIx + 1, aix[j]); // actual value + double v2 = m2.quickGetValue(0, aix[j]); double v = op.fn.execute(avals[j], v2); ret.appendValue(i, aix[j], v); lastIx = aix[j]; } } // empty left - fillZeroValues(op, v2, ret, skipEmpty, i, lastIx + 1, clen); + fillZeroValues(op, m2, ret, skipEmpty, i, lastIx + 1, clen); } } - } - - private static void safeBinaryMVSparseRowVector(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { - boolean isMultiply = (op.fn instanceof Multiply); - boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); - - int rlen = m1.rlen; - int clen = m1.clen; - SparseBlock a = m1.sparseBlock; - for(int i = 0; i < rlen; i++) { - if(skipEmpty && (a == null || a.isEmpty(i))) - continue; // skip empty rows - if(skipEmpty && ret.sparse) - ret.sparseBlock.allocate(i, a.size(i)); - int lastIx = -1; - if(a != null && !a.isEmpty(i)) { - int apos = a.pos(i); - int alen = a.size(i); - int[] aix = a.indexes(i); - double[] avals = a.values(i); - for(int j = apos; j < apos + alen; j++) { - // empty left - fillZeroValues(op, m2, ret, skipEmpty, i, lastIx + 1, aix[j]); - // actual value - double v2 = m2.quickGetValue(0, aix[j]); - double v = op.fn.execute(avals[j], v2); - ret.appendValue(i, aix[j], v); - lastIx = aix[j]; + else{ + DenseBlock db = ret.getDenseBlock(); + for(int i = 0; i < rlen; i++){ + if(skipEmpty && (a == null || a.isEmpty(i))) + continue; // skip empty rows + if(skipEmpty && ret.sparse) + ret.sparseBlock.allocate(i, a.size(i)); + int lastIx = -1; + if(a != null && !a.isEmpty(i)) { + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int j = apos; j < apos + alen; j++) { + // empty left + fillZeroValues(op, m2, db, skipEmpty, i, lastIx + 1, aix[j]); + // actual value + double v2 = m2.quickGetValue(0, aix[j]); + double v = op.fn.execute(avals[j], v2); + db.set(i, aix[j], v); + lastIx = aix[j]; + } } + // empty left + fillZeroValues(op, m2, db, skipEmpty, i, lastIx + 1, clen); } - // empty left - fillZeroValues(op, m2, ret, skipEmpty, i, lastIx + 1, clen); } } - private static final void fillZeroValues(BinaryOperator op, double v2, MatrixBlock ret, boolean skipEmpty, int rpos, int cpos, int len) { + + private static void fillZeroValues(BinaryOperator op, MatrixBlock m2, MatrixBlock ret, boolean skipEmpty, int rpos, + int cpos, int len) { if(skipEmpty) return; - - final double v = op.fn.execute(0, v2); - if(v != 0){ - for( int k=cpos; k aix[apos]) { + apos++; + } + // for each point in the sparse range + for(; apos < alen && aix[apos] < len; apos++) { + if(!zeroIsZero) { + while(cpos < len && cpos < aix[apos]) { + ret.appendValue(rpos, cpos++, zero); + } + } + cpos = aix[apos]; + final double v = op.fn.execute(0, vals[apos]); + ret.appendValue(rpos, aix[apos], v); + // cpos++; + } + // process tail. + if(!zeroIsZero) { + while(cpos < len) { + ret.appendValue(rpos, cpos++, zero); + } } } } - private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, MatrixBlock ret, boolean skipEmpty, + + private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, DenseBlock ret, boolean skipEmpty, int rpos, int cpos, int len) { final double zero = op.fn.execute(0.0, 0.0); @@ -1005,7 +1273,7 @@ private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, Matr if(sb.isEmpty(0)) { if(!zeroIsZero) { while(cpos < len) - ret.appendValue(rpos, cpos++, zero); + ret.set(rpos, cpos++, zero); } } else { @@ -1021,19 +1289,17 @@ private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, Matr for(; apos < alen && aix[apos] < len; apos++) { if(!zeroIsZero) { while(cpos < len && cpos < aix[apos]) { - ret.appendValue(rpos, cpos++, zero); + ret.set(rpos, cpos++, zero); } } cpos = aix[apos]; final double v = op.fn.execute(0, vals[apos]); - ret.appendValue(rpos, aix[apos], v); - // cpos++; + ret.set(rpos, aix[apos], v); } // process tail. if(!zeroIsZero) { - while(cpos < len) { - ret.appendValue(rpos, cpos++, zero); - } + while(cpos < len) + ret.set(rpos, cpos++, zero); } } } @@ -1181,79 +1447,100 @@ private static void safeBinaryVVGeneric(MatrixBlock m1, MatrixBlock m2, MatrixBl //no need to recomputeNonZeros since maintained in append value } - private static long safeBinaryMMSparseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, - BinaryOperator op, int rl, int ru) - { - //guard for postponed allocation in single-threaded exec - if( ret.sparse && !ret.isAllocated() ) + private static long safeBinaryMMSparseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, + int rl, int ru) { + // guard for postponed allocation in single-threaded exec + if(ret.sparse && !ret.isAllocated()) ret.allocateSparseRowsBlock(); - - //both sparse blocks existing - long lnnz = 0; - if(m1.sparseBlock!=null && m2.sparseBlock!=null) - { + + // both sparse blocks existing + if(m1.sparseBlock != null && m2.sparseBlock != null) { SparseBlock lsblock = m1.sparseBlock; SparseBlock rsblock = m2.sparseBlock; - - if( ret.sparse && lsblock.isAligned(rsblock) ) - { - SparseBlock c = ret.sparseBlock; - for(int r=rl; r= colPos2) ? 1 : 0; - } - result.nonZeros += sblock.size(resultRow); + // skip empty + SparseBlock sblock = result.getSparseBlock(); + while(p1 < size1 && p2 < size2) { + int colPos1 = cols1[pos1 + p1]; + int colPos2 = cols2[pos2 + p2]; + if(colPos1 == colPos2) + sblock.append(resultRow, colPos1, op.fn.execute(values1[pos1 + p1], values2[pos2 + p2])); + p1 += (colPos1 <= colPos2) ? 1 : 0; + p2 += (colPos1 >= colPos2) ? 1 : 0; + } + result.nonZeros += sblock.size(resultRow); + } + + private static void mergeForSparseBinaryGeneric(BinaryOperator op, double[] values1, int[] cols1, int pos1, + int size1, double[] values2, int[] cols2, int pos2, int size2, int resultRow, MatrixBlock result) { + SparseBlock c = result.getSparseBlock(); + // general case: merge-join (with outer join semantics) + while(pos1 < size1 && pos2 < size2) { + if(cols1[pos1] < cols2[pos2]) { + c.append(resultRow, cols1[pos1], op.fn.execute(values1[pos1], 0)); + pos1++; + } + else if(cols1[pos1 ] == cols2[pos2 ]) { + c.append(resultRow, cols1[pos1], op.fn.execute(values1[pos1], values2[pos2])); + pos1++; + pos2++; + } + else { + c.append(resultRow, cols2[pos2], op.fn.execute(0, values2[pos2])); + pos2++; + } } - else { - //general case: merge-join (with outer join semantics) - while( p1 < size1 && p2 < size2 ) { - if(cols1[pos1+p1]> tasks = new ArrayList<>(); for(int i = 0; i < m; i += blockSize) { final int start = i; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index bafdeba18b5..1696f7a9357 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -41,6 +41,7 @@ import org.apache.sysds.lops.WeightedUnaryMM.WUMMType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP; import org.apache.sysds.runtime.data.DenseBlockFactory; @@ -232,8 +233,8 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock else parallelMatrixMult(m1, m2, ret, k, ultraSparse, sparse, tm2, m1Perm); - //System.out.println("MM "+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + - // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); + // System.out.println("MM "+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); return ret; } @@ -247,10 +248,15 @@ private static void singleThreadedMatrixMult(MatrixBlock m1, MatrixBlock m2, Mat // core matrix mult computation if(ultraSparse && !fixedRet) matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2); + else if( ret.sparse ) //ultra-sparse + matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2); else if(!m1.sparse && !m2.sparse) - matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen); + if(m1.denseBlock instanceof DenseBlockFP64DEDUP && m2.denseBlock.isContiguous(0,m1.clen)) + matrixMultDenseDenseMMDedup(m1.denseBlock, m2.denseBlock, ret.denseBlock, m2.clen, m1.clen, 0, ru2, new ConcurrentHashMap<>()); + else + matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen); else if(m1.sparse && m2.sparse) - matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, ru2); + matrixMultSparseSparse(m1, m2, ret, pm2, ret.sparse, 0, ru2); else if(m1.sparse) matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2); else @@ -756,10 +762,10 @@ public static void matrixMultWSigmoid(MatrixBlock mW, MatrixBlock mU, MatrixBloc */ public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt) { //check for empty result - if( mW.isEmptyBlock(false) - || (wt.isLeft() && mU.isEmptyBlock(false)) - || (wt.isRight() && mV.isEmptyBlock(false)) - || (wt.isBasic() && mW.isEmptyBlock(false))) { + if( mW.isEmptyBlock(true) + || (wt.isLeft() && mU.isEmptyBlock(true)) + || (wt.isRight() && mV.isEmptyBlock(true)) + || (wt.isBasic() && mW.isEmptyBlock(true))) { ret.examSparsity(); //turn empty dense into sparse return; } @@ -804,10 +810,10 @@ else if( mW.sparse && !mU.sparse && !mV.sparse && (mX==null || mX.sparse || scal */ public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int k) { //check for empty result - if( mW.isEmptyBlock(false) - || (wt.isLeft() && mU.isEmptyBlock(false)) - || (wt.isRight() && mV.isEmptyBlock(false)) - || (wt.isBasic() && mW.isEmptyBlock(false))) { + if( mW.isEmptyBlock(true) + || (wt.isLeft() && mU.isEmptyBlock(true)) + || (wt.isRight() && mV.isEmptyBlock(true)) + || (wt.isBasic() && mW.isEmptyBlock(true))) { ret.examSparsity(); //turn empty dense into sparse return; } @@ -1001,7 +1007,7 @@ private static void matrixMultDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixB final int m = m1.rlen; final int n = m2.clen; final int cd = m1.clen; - + if( LOW_LEVEL_OPTIMIZATION ) { if( m==1 && n==1 ) { //DOT PRODUCT double[] avals = a.valuesAt(0); @@ -1244,70 +1250,98 @@ public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, DenseBlock } private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) { + + if(ret.isInSparseFormat()){ + matrixMultDenseSparseOutSparse(m1, m2, ret, pm2, rl, ru); + } + else + matrixMultDenseSparseOutDense(m1, m2, ret, pm2, rl, ru); + } + + private static void matrixMultDenseSparseOutSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, + int rl, int ru) { + final DenseBlock a = m1.getDenseBlock(); + final SparseBlock b = m2.getSparseBlock(); + final SparseBlock c = ret.getSparseBlock(); + final int m = m1.rlen; // rows left + final int cd = m1.clen; // common dim + + final int rl1 = pm2 ? 0 : rl; + final int ru1 = pm2 ? m : ru; + final int rl2 = pm2 ? rl : 0; + final int ru2 = pm2 ? ru : cd; + + final int blocksizeK = 32; + final int blocksizeI = 32; + + for(int bi = rl1; bi < ru1; bi += blocksizeI) { + for(int bk = rl2, bimin = Math.min(ru1, bi + blocksizeI); bk < ru2; bk += blocksizeK) { + final int bkmin = Math.min(ru2, bk + blocksizeK); + // core sub block matrix multiplication + for(int i = bi; i < bimin; i++) { // rows left + final double[] avals = a.values(i); + final int aix = a.pos(i); + for(int k = bk; k < bkmin; k++) { // common dimension + final double aval = avals[aix + k]; + if(aval == 0 || b.isEmpty(k)) + continue; + final int[] bIdx = b.indexes(k); + final double[] bVals = b.values(k); + final int bPos = b.pos(k); + final int bEnd = bPos + b.size(k); + for(int j = bPos; j < bEnd ; j++){ + c.add(i, bIdx[j], aval * bVals[j]); + } + } + } + } + } + } + + private static void matrixMultDenseSparseOutDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, + int ru) { DenseBlock a = m1.getDenseBlock(); DenseBlock c = ret.getDenseBlock(); int m = m1.rlen; int cd = m1.clen; - - // MATRIX-MATRIX (VV, MV not applicable here because V always dense) - if( LOW_LEVEL_OPTIMIZATION ) - { - SparseBlock b = m2.sparseBlock; - - if( pm2 && m==1 ) { //VECTOR-MATRIX - //parallelization over rows in rhs matrix - double[] avals = a.valuesAt(0); //vector - double[] cvals = c.valuesAt(0); //vector - for( int k=rl; k 1 && m2.clen > 1) ) return false; @@ -4687,9 +4752,8 @@ else if(!_m1.sparse && !_m2.sparse) matrixMultDenseDenseMMDedup(_m1.denseBlock, _m2.denseBlock, _ret.denseBlock, _m2.clen, _m1.clen, rl, ru, _cache); else matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2r, rl, ru, cl, cu); - else if(_m1.sparse && _m2.sparse) - matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _sparse, rl, ru); + matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _ret.sparse , rl, ru); else if(_m1.sparse) matrixMultSparseDense(_m1, _m2, _ret, _pm2r, rl, ru); else diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 84ff9b7c526..b23b46d5db2 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -1231,7 +1231,7 @@ public boolean evalSparseFormatOnDisk() { public final void examSparsity() { examSparsity(true, 1); } - + /** * Evaluates if this matrix block should be in sparse format in * memory. Depending on the current representation, the state of the @@ -1286,7 +1286,7 @@ public void examSparsity(boolean allowCSR, int k) { else if( !sparse && sparseDst ) denseToSparse(allowCSR, k); } - + public static boolean evalSparseFormatInMemory(DataCharacteristics dc) { return evalSparseFormatInMemory(dc.getRows(), dc.getCols(), dc.getNonZeros()); } @@ -1358,12 +1358,13 @@ public void denseToSparse(boolean allowCSR, int k){ LibMatrixDenseToSparse.denseToSparse(this, allowCSR, k); } - public final void sparseToDense() { - sparseToDense(1); + public final MatrixBlock sparseToDense() { + return sparseToDense(1); } - public void sparseToDense(int k) { + public MatrixBlock sparseToDense(int k) { LibMatrixSparseToDense.sparseToDense(this, k); + return this; } /** @@ -1423,6 +1424,10 @@ else if(!sparse && denseBlock!=null){ nnz += e.get(); nonZeros = nnz; + if(nonZeros < 0) + throw new DMLRuntimeException("Invalid count of non zero values: " + nonZeros); + return nonZeros; + } catch(Exception e) { LOG.warn("Failed Parallel non zero count fallback to singlethread"); @@ -1438,6 +1443,12 @@ else if(!sparse && denseBlock!=null){ return nonZeros; } + /** + * Recompute the number of non-zero values + * @param rl row lower index, 0-based, inclusive + * @param ru row upper index, 0-based, inclusive + * @return the number of non-zero values + */ public long recomputeNonZeros(int rl, int ru) { return recomputeNonZeros(rl, ru, 0, clen-1); } @@ -2862,6 +2873,10 @@ private static SparsityEstimate estimateSparsityOnBinary(MatrixBlock m1, MatrixB est.sparse = false; return est; } + else if(op.fn instanceof Divide && m2.getSparsity() == 1.0){ + est.sparse = m1.sparse; + return est; + } BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2); boolean outer = (atype == BinaryAccessType.OUTER_VECTOR_VECTOR); @@ -2870,6 +2885,11 @@ private static SparsityEstimate estimateSparsityOnBinary(MatrixBlock m1, MatrixB long nz1 = m1.getNonZeros(); long nz2 = m2.getNonZeros(); + if(nz1 <= 0) + nz1 = m1.recomputeNonZeros(); + if(nz2 <= 0) + nz2 = m2.recomputeNonZeros(); + //account for matrix vector and vector/vector long estnnz = 0; if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) @@ -2957,13 +2977,14 @@ public boolean isShallowSerialize(boolean inclConvert) { boolean sparseDst = evalSparseFormatOnDisk(); return !sparse || !sparseDst || (sparse && sparseBlock instanceof SparseBlockCSR) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD - <= getExactSerializedSize()) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && nonZeros < Integer.MAX_VALUE //CSR constraint - && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE - && !isUltraSparseSerialize(sparseDst)); + || (sparse && sparseBlock instanceof SparseBlockMCSR); + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD + // <= getExactSerializedSize()) + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && nonZeros < Integer.MAX_VALUE //CSR constraint + // && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE + // && !isUltraSparseSerialize(sparseDst)); } @Override @@ -2985,26 +3006,7 @@ public void compactEmptyBlock() { @Override public MatrixBlock scalarOperations(ScalarOperator op, MatrixValue result) { - MatrixBlock ret = checkType(result); - - // estimate the sparsity structure of result matrix - boolean sp = this.sparse; // by default, we guess result.sparsity=input.sparsity - if (!op.sparseSafe) - sp = false; // if the operation is not sparse safe, then result will be in dense format - - //allocate the output matrix block - if( ret==null ) - ret = new MatrixBlock(rlen, clen, sp, this.nonZeros); - else - ret.reset(rlen, clen, sp, this.nonZeros); - - //core scalar operations - if( op.getNumThreads() > 1 ) - LibMatrixBincell.bincellOp(this, ret, op, op.getNumThreads()); - else - LibMatrixBincell.bincellOp(this, ret, op); - - return ret; + return LibMatrixBincell.bincellOpScalar(this, checkType(result), op, op.getNumThreads()); } public final MatrixBlock unaryOperations(UnaryOperator op){ @@ -3076,12 +3078,8 @@ public MatrixBlock binaryOperations(BinaryOperator op, MatrixValue thatValue, Ma else ret.reset(rows, cols, resultSparse.sparse, resultSparse.estimatedNonZeros); - //core binary cell operation - if( op.getNumThreads() > 1 ) - LibMatrixBincell.bincellOp( this, that, ret, op, op.getNumThreads() ); - else - LibMatrixBincell.bincellOp( this, that, ret, op ); - + LibMatrixBincell.bincellOp( this, that, ret, op, op.getNumThreads() ); + return ret; } @@ -3099,6 +3097,7 @@ else if(!resultSparse.sparse && this.sparse) //core binary cell operation LibMatrixBincell.bincellOpInPlace(this, that, op); + return this; } @@ -3171,10 +3170,7 @@ public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixB if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) { //SPECIAL CASE for sparse-dense combinations of common +* and -* BinaryOperator bop = ((ValueFunctionWithConstant)op.fn).setOp2Constant(s2 ? d2 : d3); - if( op.getNumThreads() > 1 ) - LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop, op.getNumThreads()); - else - LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop); + LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop, op.getNumThreads()); } else { //DEFAULT CASE @@ -3824,80 +3820,90 @@ public MatrixBlock append(MatrixBlock[] that, MatrixBlock result, boolean cbind) else result.reset(m, n, sp, nnz); - //core append operation - //copy left and right input into output - if( !result.sparse && nnz!=0 ) //DENSE - { - if( cbind ) { - DenseBlock resd = result.allocateBlock().getDenseBlock(); - MatrixBlock[] in = ArrayUtils.addAll(new MatrixBlock[]{this}, that); - - for( int i=0; i rlen && !shallowCopy && result.getSparseBlock() instanceof SparseBlockMCSR ) { - final SparseBlock sblock = result.getSparseBlock(); - // for each row calculate how many non zeros are pressent. - for( int i=0; i rlen && !shallowCopy && result.getSparseBlock() instanceof SparseBlockMCSR) { + final SparseBlock sblock = result.getSparseBlock(); + // for each row calculate how many non zeros are pressent. + for(int i = 0; i < result.rlen; i++) + sblock.allocate(i, computeNNzRow(that, i)); + + } + + // core append operation + // we can always append this directly to offset 0.0 in both cbind and rbind. + result.appendToSparse(this, 0, 0, !shallowCopy); + if(cbind) { + for(int i = 0, off = clen; i < that.length; i++) { + result.appendToSparse(that[i], 0, off); + off += that[i].clen; } - else { //rbind - for(int i=0, off=rlen; i _sparseRowsWZeros = null; + // protected ArrayList _sparseRowsWZeros = null; + + protected boolean containsZeroOut = false; protected long _estMetaSize = 0; protected int _estNumDistincts = 0; protected int _nBuildPartitions = 0; @@ -142,8 +142,7 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int r protected abstract double[] getCodeCol(CacheBlock in, int startInd, int rowEnd, double[] tmp); protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; //force CSR for transformencode + boolean mcsr = out.getSparseBlock() instanceof SparseBlockMCSR; int index = _colID - 1; // Apply loop tiling to exploit CPU caches int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); @@ -411,20 +410,24 @@ public List> getApplyTasks(CacheBlock in, MatrixBlock out, return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk); } - public Set getSparseRowsWZeros(){ - if (_sparseRowsWZeros != null) { - return new HashSet<>(_sparseRowsWZeros); - } - else - return null; - } - - protected void addSparseRowsWZeros(ArrayList sparseRowsWZeros){ - synchronized (this){ - if(_sparseRowsWZeros == null) - _sparseRowsWZeros = new ArrayList<>(); - _sparseRowsWZeros.addAll(sparseRowsWZeros); - } + // public Set getSparseRowsWZeros(){ + // if (_sparseRowsWZeros != null) { + // return new HashSet<>(_sparseRowsWZeros); + // } + // else + // return null; + // } + + // protected void addSparseRowsWZeros(ArrayList sparseRowsWZeros){ + // synchronized (this){ + // if(_sparseRowsWZeros == null) + // _sparseRowsWZeros = new ArrayList<>(); + // _sparseRowsWZeros.addAll(sparseRowsWZeros); + // } + // } + + protected boolean containsZeroOut(){ + return containsZeroOut; } protected void setBuildRowBlocksPerColumn(int nPart) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java index 8e1055e41d2..3e500bd310f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.transform.encode; +import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; + import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; @@ -28,7 +30,6 @@ import java.util.Random; import java.util.concurrent.Callable; -import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; import org.apache.commons.lang3.tuple.MutableTriple; import org.apache.sysds.api.DMLScript; import org.apache.sysds.lops.Lop; @@ -274,17 +275,12 @@ private static double[] prepareDataForEqualHeightBins(CacheBlock in, int colI private static double[] extractDoubleColumn(CacheBlock in, int colID, int startRow, int blockSize) { int endRow = getEndIndex(in.getNumRows(), startRow, blockSize); - double[] vals = new double[endRow - startRow]; final int cid = colID -1; + double[] vals = new double[endRow - startRow]; if(in instanceof FrameBlock) { // FrameBlock optimization Array a = ((FrameBlock) in).getColumn(cid); - for(int i = startRow; i < endRow; i++) { - double inVal = a.getAsNaNDouble(i); - if(Double.isNaN(inVal)) - continue; - vals[i - startRow] = inVal; - } + return a.extractDouble(vals, startRow, endRow); } else { for(int i = startRow; i < endRow; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 6fda66113dd..6e7f4f8002d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -27,10 +27,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.Objects; import java.util.concurrent.Callable; -import java.util.stream.Collectors; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; @@ -408,13 +406,20 @@ public void shiftCol(int columnOffset) { _columnEncoders.forEach(e -> e.shiftCol(columnOffset)); } - @Override - public Set getSparseRowsWZeros(){ - return _columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet()); + // @Override + // public Set getSparseRowsWZeros(){ + // return _columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> { + // if(l == null) + // return null; + // return l.stream(); + // }).collect(Collectors.toSet()); + // } + + protected boolean containsZeroOut(){ + for(int i = 0; i < _columnEncoders.size(); i++) + if(_columnEncoders.get(i).containsZeroOut()) + return true; + return false; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java index a9b00a0767a..395895ff458 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java @@ -24,15 +24,14 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; -import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DependencyTask; @@ -115,17 +114,14 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int throw new DMLRuntimeException( "ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() + " and not MatrixBlock"); } - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; // force CSR for transformencode - ArrayList sparseRowsWZeros = null; + boolean mcsr = out.getSparseBlock() instanceof SparseBlockMCSR; + // ArrayList sparseRowsWZeros = null; int index = _colID - 1; for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) { if(mcsr) { double val = out.getSparseBlock().get(r).values()[index]; if(Double.isNaN(val)) { - if(sparseRowsWZeros == null) - sparseRowsWZeros = new ArrayList<>(); - sparseRowsWZeros.add(r); + containsZeroOut = true; out.getSparseBlock().get(r).values()[index] = 0; continue; } @@ -138,9 +134,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int int rptr[] = csrblock.rowPointers(); double val = csrblock.values()[rptr[r] + index]; if(Double.isNaN(val)) { - if(sparseRowsWZeros == null) - sparseRowsWZeros = new ArrayList<>(); - sparseRowsWZeros.add(r); + containsZeroOut = true; csrblock.values()[rptr[r] + index] = 0; // test continue; } @@ -150,9 +144,6 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int csrblock.values()[rptr[r] + index] = 1; } } - if(sparseRowsWZeros != null) { - addSparseRowsWZeros(sparseRowsWZeros); - } } protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index c57c72f459d..eb505769f6d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -25,6 +25,7 @@ import java.util.List; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; @@ -67,7 +68,7 @@ protected TransformType getTransformType() { @Override protected double getCode(CacheBlock in, int row) { if(in instanceof FrameBlock){ - Array a = ((FrameBlock)in).getColumn(_colID -1); + Array a = ((FrameBlock)in).getColumn(_colID - 1); return getCode(a, row); } else{ // default @@ -80,16 +81,24 @@ protected double getCode(CacheBlock in, int row) { } protected double getCode(Array a, int row){ - return Math.abs(a.hashDouble(row) % _K + 1); + return Math.abs(a.hashDouble(row)) % _K + 1; + } + + protected static double getCode(Array a, int k , int row){ + return Math.abs(a.hashDouble(row)) % k ; } protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double[] tmp) { final int endLength = endInd - startInd; final double[] codes = tmp != null && tmp.length == endLength ? tmp : new double[endLength]; - if( in instanceof FrameBlock) { + if(in instanceof FrameBlock) { Array a = ((FrameBlock) in).getColumn(_colID-1); - for(int i = startInd; i < endInd; i++) - codes[i - startInd] = getCode(a, i); + for(int i = startInd; i < endInd; i++){ + double code = getCode(a, i); + if(code <= 0) + throw new DMLRuntimeException("Bad Code"); + codes[i - startInd] = code; + } } else {// default for(int i = startInd; i < endInd; i++) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java index 12077221c05..7406d366468 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java @@ -21,13 +21,13 @@ import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; -import java.util.ArrayList; import java.util.List; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -82,40 +82,48 @@ protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double } protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - //Set sparseRowsWZeros = null; - ArrayList sparseRowsWZeros = null; - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; //force CSR for transformencode - int index = _colID - 1; - // Apply loop tiling to exploit CPU caches - int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); - double[] codes = getCodeCol(in, rowStart, rowEnd, null); - int B = 32; //tile size - for(int i = rowStart; i < rowEnd; i+=B) { - int lim = Math.min(i+B, rowEnd); - for (int ii=i; ii(); - sparseRowsWZeros.add(ii); - } - if (mcsr) { - SparseRowVector row = (SparseRowVector) out.getSparseBlock().get(ii); - row.values()[index] = v; - row.indexes()[index] = outputCol; - } - else { //csr - // Manually fill the column-indexes and values array - SparseBlockCSR csrblock = (SparseBlockCSR)out.getSparseBlock(); - int rptr[] = csrblock.rowPointers(); - csrblock.indexes()[rptr[ii]+index] = outputCol; - csrblock.values()[rptr[ii]+index] = codes[ii-rowStart]; - } - } + final SparseBlock sb = out.getSparseBlock(); + final boolean mcsr = sb instanceof SparseBlockMCSR; + final int index = _colID - 1; + final int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + final int bs = 32; + double[] tmp = null; + for(int i = rowStart; i < rowEnd; i+= bs) { + int end = Math.min(i + bs , rowEnd); + tmp = getCodeCol(in, i, end,tmp); + if(mcsr) + applySparseBlockMCSR(in, (SparseBlockMCSR) sb, index, outputCol, i, end, tmp); + else + applySparseBlockCSR(in, (SparseBlockCSR) sb, index, outputCol, i, end, tmp); + + } + } + + private void applySparseBlockMCSR(CacheBlock in, final SparseBlockMCSR sb, final int index, + final int outputCol, int rl, int ru, double[] tmpCodes) { + for(int i = rl; i < ru; i ++) { + final double v = tmpCodes[i - rl]; + SparseRowVector row = (SparseRowVector) sb.get(i); + row.indexes()[index] = outputCol; + if(v == 0) + containsZeroOut = true; + else + row.values()[index] = v; } - if(sparseRowsWZeros != null){ - addSparseRowsWZeros(sparseRowsWZeros); + } + + private void applySparseBlockCSR(CacheBlock in, final SparseBlockCSR sb, final int index, final int outputCol, + int rl, int ru, double[] tmpCodes) { + final int[] rptr = sb.rowPointers(); + final int[] idx = sb.indexes(); + final double[] val = sb.values(); + for(int i = rl; i < ru; i++) { + final double v = tmpCodes[i - rl]; + idx[rptr[i] + index] = outputCol; + if(v == 0) + containsZeroOut = true; + else + val[rptr[i] + index] = v; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java index 9569aa69d91..ed8700005e1 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java @@ -373,7 +373,12 @@ public String toString() { sb.append(": "); sb.append(_colID); sb.append(" --- map: "); - sb.append(_rcdMap); + if(_rcdMap.size() < 1000){ + sb.append(_rcdMap); + } + else{ + sb.append("Map to big to print but size is : " + _rcdMap.size()); + } return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 7fbdb1ea3c8..6c794950b34 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -158,6 +158,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { Array a = in.getColumn(colId - 1); boolean containsNull = a.containsNull(); Map map = a.getRecodeMap(); + List r = c.getEncoders(); r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); int domain = map.size(); @@ -169,6 +170,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); + return ColGroupDDC.create(colIndexes, d, m, null); } @@ -291,7 +293,7 @@ private AColGroup passThrough(ColumnEncoderComposite c) { IColIndex colIndexes = ColIndexFactory.create(1); int colId = c._colID; Array a = in.getColumn(colId - 1); - if(a instanceof ACompressedArray){ + if(a instanceof ACompressedArray) { switch(a.getFrameArrayType()) { case DDC: DDCArray aDDC = (DDCArray) a; @@ -322,7 +324,7 @@ private AColGroup passThrough(ColumnEncoderComposite c) { if(containsNull) vals[map.size()] = Double.NaN; ValueType t = a.getValueType(); - map.forEach((k, v) -> vals[v.intValue()-1] = UtilFunctions.objectToDouble(t, k)); + map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k)); ADictionary d = Dictionary.create(vals); AMapToData m = createMappingAMapToData(a, map, containsNull); return ColGroupDDC.create(colIndexes, d, m, null); @@ -337,30 +339,29 @@ private AMapToData createMappingAMapToData(Array a, Map map, boolean AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 1 : 0)); Array.ArrayIterator it = a.getIterator(); if(containsNull) { - while(it.hasNext()) { Object v = it.next(); - try{ + try { if(v != null) - m.set(it.getIndex(), map.get(v).intValue() -1); + m.set(it.getIndex(), map.get(v).intValue() - 1); else m.set(it.getIndex(), si); } - catch(Exception e){ - throw new RuntimeException("failed on " + v +" " + a.getValueType(), e); + catch(Exception e) { + throw new RuntimeException("failed on " + v + " " + a.getValueType(), e); } } } else { while(it.hasNext()) { Object v = it.next(); - m.set(it.getIndex(), map.get(v).intValue() -1); + m.set(it.getIndex(), map.get(v).intValue() - 1); } } return m; } catch(Exception e) { - throw new RuntimeException("failed constructing map: " + map, e); + throw new RuntimeException("failed constructing map: " + map, e); } } @@ -368,19 +369,17 @@ private AMapToData createHashMappingAMapToData(Array a, int k, boolean nulls) AMapToData m = MapToFactory.create(a.size(), k + (nulls ? 1 : 0)); if(nulls) { for(int i = 0; i < a.size(); i++) { - double h = Math.abs(a.hashDouble(i)); - if(Double.isNaN(h)) { + double h = Math.abs(a.hashDouble(i)) % k; + if(Double.isNaN(h)) m.set(i, k); - } - else { - m.set(i, (int) h % k); - } + else + m.set(i, (int)h); } } else { for(int i = 0; i < a.size(); i++) { - double h = Math.abs(a.hashDouble(i)); - m.set(i, (int) h % k); + double h = Math.abs(a.hashDouble(i)) % k; + m.set(i, (int)h); } } return m; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index be0680379f1..c60f03b7001 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -129,7 +129,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, i //rcIDs = unionDistinct(rcIDs, weIDs); // Error out if the first level encoders have overlaps if (intersect(rcIDs, binIDs, haIDs, weIDs)) - throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding) on one column is not allowed"); + throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding) on one column is not allowed:\n" + spec); List ptIDs = except(except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, haIDs)), binIDs), weIDs); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 59c5f2c973c..4fb86ae8d03 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -28,8 +28,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -53,6 +51,7 @@ import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -69,10 +68,10 @@ public class MultiColumnEncoder implements Encoder { // If true build and apply separately by placing a synchronization barrier public static boolean MULTI_THREADED_STAGES = ConfigurationManager.isStagedParallelTransform(); - // Only affects if MULTI_THREADED_STAGES is true + // Only affects if MULTI_THREADED_STAGES is true // if true apply tasks for each column will complete // before the next will start. - public static boolean APPLY_ENCODER_SEPARATE_STAGES = false; + public static boolean APPLY_ENCODER_SEPARATE_STAGES = false; private List _columnEncoders; // These encoders are deprecated and will be phased out soon. @@ -102,32 +101,36 @@ public MatrixBlock encode(CacheBlock in, boolean compressedOut) { return encode(in, 1, compressedOut); } - public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ + public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut) { try { if(isCompressedTransformEncode(in, compressedOut)) - return CompressedEncode.encode(this, (FrameBlock ) in, k); - + return CompressedEncode.encode(this, (FrameBlock) in, k); + deriveNumRowPartitions(in, k); + + MatrixBlock out; + if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) { - MatrixBlock out = new MatrixBlock(); + out = new MatrixBlock(); DependencyThreadPool pool = new DependencyThreadPool(k); LOG.debug("Encoding with full DAG on " + k + " Threads"); try { List> tasks = getEncodeTasks(in, out, pool); pool.submitAllAndWait(tasks); } - finally{ + finally { pool.shutdown(); } + out.setNonZeros((long)in.getNumRows() * in.getNumColumns()); outputMatrixPostProcessing(out, k); - return out; + } else { LOG.debug("Encoding with staged approach on: " + k + " Threads"); long t0 = System.nanoTime(); build(in, k); long t1 = System.nanoTime(); - LOG.debug("Elapsed time for build phase: "+ ((double) t1 - t0) / 1000000 + " ms"); + LOG.debug("Elapsed time for build phase: " + ((double) t1 - t0) / 1000000 + " ms"); if(_legacyMVImpute != null) { // These operations are redundant for every encoder excluding the legacyMVImpute, the workaround to // fix it for this encoder would be very dirty. This will only have a performance impact if there @@ -138,11 +141,29 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ } // apply meta data t0 = System.nanoTime(); - MatrixBlock out = apply(in, k); + out = apply(in, k); + if(out.getNonZeros() < 0) + throw new DMLRuntimeException( + "Invalid assigned non zeros of transform encode output: " + out.getNonZeros()); + t1 = System.nanoTime(); - LOG.debug("Elapsed time for apply phase: "+ ((double) t1 - t0) / 1000000 + " ms"); - return out; + LOG.debug("Elapsed time for apply phase: " + ((double) t1 - t0) / 1000000 + " ms"); } + if(LOG.isDebugEnabled()) { + LOG.debug("Transform Encode output mem size: " + out.getInMemorySize()); + LOG.debug(String.format("Transform Encode output rows : %10d", out.getNumRows())); + LOG.debug(String.format("Transform Encode output cols : %10d", out.getNumColumns())); + LOG.debug(String.format("Transform Encode output sparsity : %10.5f", out.getSparsity())); + LOG.debug(String.format("Transform Encode output nnz : %10d", out.getNonZeros())); + LOG.error(out.slice(0, 10)); + + } + + if(out.getNonZeros() > (long)in.getNumRows() * in.getNumColumns()){ + throw new DMLRuntimeException("Invalid transform output, contains to many non zeros" + out.getNonZeros() + + " Max: " + ((long) in.getNumRows() * in.getNumColumns())); + } + return out; } catch(Exception ex) { throw new DMLRuntimeException("Failed transform-encode frame with encoder:\n" + this, ex); @@ -153,14 +174,11 @@ protected List getEncoders() { return _columnEncoders; } - /* TASK DETAILS: - * InitOutputMatrixTask: Allocate output matrix - * AllocMetaTask: Allocate metadata frame - * BuildTask: Build an encoder - * ColumnCompositeUpdateDCTask: Update domain size of a DC encoder based on #distincts, #bins, K - * ColumnMetaDataTask: Fill up metadata of an encoder - * ApplyTasksWrapperTask: Wrapper task for an Apply task - * UpdateOutputColTask: Set starting offsets of the DC columns + /* + * TASK DETAILS: InitOutputMatrixTask: Allocate output matrix AllocMetaTask: Allocate metadata frame BuildTask: Build + * an encoder ColumnCompositeUpdateDCTask: Update domain size of a DC encoder based on #distincts, #bins, K + * ColumnMetaDataTask: Fill up metadata of an encoder ApplyTasksWrapperTask: Wrapper task for an Apply task + * UpdateOutputColTask: Set starting offsets of the DC columns */ private List> getEncodeTasks(CacheBlock in, MatrixBlock out, DependencyThreadPool pool) { List> tasks = new ArrayList<>(); @@ -180,63 +198,62 @@ private List> getEncodeTasks(CacheBlock in, MatrixBlock out tasks.addAll(buildTasks); if(buildTasks.size() > 0) { // Check if any Build independent UpdateDC task (Bin+DC, FH+DC) - if (e.hasEncoder(ColumnEncoderDummycode.class) - && buildTasks.size() > 1 //filter out FH - && !buildTasks.get(buildTasks.size()-2).hasDependency(buildTasks.get(buildTasks.size()-1))) - independentUpdateDC = true; - + if(e.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1 // filter out FH + && !buildTasks.get(buildTasks.size() - 2).hasDependency(buildTasks.get(buildTasks.size() - 1))) + independentUpdateDC = true; + // Independent UpdateDC task - if (independentUpdateDC) { + if(independentUpdateDC) { // Apply Task depends on task prior to UpdateDC (Build/MergePartialBuild) - depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, //ApplyTask - new Integer[] {tasks.size() - 2, tasks.size() - 1}); //BuildTask - // getMetaDataTask depends on task prior to UpdateDC - depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, //MetaDataTask - new Integer[] {tasks.size() - 2, tasks.size() - 1}); //BuildTask + depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, // ApplyTask + new Integer[] {tasks.size() - 2, tasks.size() - 1}); // BuildTask + // getMetaDataTask depends on task prior to UpdateDC + depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, // MetaDataTask + new Integer[] {tasks.size() - 2, tasks.size() - 1}); // BuildTask } - else { + else { // Apply Task depends on the last task (Build/MergePartial/UpdateDC) - depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, //ApplyTask - new Integer[] {tasks.size() - 1, tasks.size()}); //Build/UpdateDC + depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, // ApplyTask + new Integer[] {tasks.size() - 1, tasks.size()}); // Build/UpdateDC // getMetaDataTask depends on build completion - depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, //MetaDataTask - new Integer[] {tasks.size() - 1, tasks.size()}); //Build/UpdateDC + depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, // MetaDataTask + new Integer[] {tasks.size() - 1, tasks.size()}); // Build/UpdateDC } // AllocMetaTask never depends on the UpdateDC task - if (e.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1) - depMap.put(new Integer[] {1, 2}, //AllocMetaTask (2nd task) - new Integer[] {tasks.size() - 2, tasks.size()-1}); //BuildTask + if(e.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1) + depMap.put(new Integer[] {1, 2}, // AllocMetaTask (2nd task) + new Integer[] {tasks.size() - 2, tasks.size() - 1}); // BuildTask else - depMap.put(new Integer[] {1, 2}, //AllocMetaTask (2nd task) - new Integer[] {tasks.size() - 1, tasks.size()}); //BuildTask + depMap.put(new Integer[] {1, 2}, // AllocMetaTask (2nd task) + new Integer[] {tasks.size() - 1, tasks.size()}); // BuildTask } // getMetaDataTask depends on AllocMeta task - depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, //MetaDataTask - new Integer[] {1, 2}); //AllocMetaTask (2nd task) + depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, // MetaDataTask + new Integer[] {1, 2}); // AllocMetaTask (2nd task) // Apply Task depends on InitOutputMatrixTask (output allocation) - depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, //ApplyTask - new Integer[] {0, 1}); //Allocation task (1st task) + depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, // ApplyTask + new Integer[] {0, 1}); // Allocation task (1st task) ApplyTasksWrapperTask applyTaskWrapper = new ApplyTasksWrapperTask(e, in, out, pool); if(e.hasEncoder(ColumnEncoderDummycode.class)) { // Allocation depends on build if DC is in the list. // Note, DC is the only encoder that changes dimensionality - depMap.put(new Integer[] {0, 1}, //Allocation task (1st task) - new Integer[] {tasks.size() - 1, tasks.size()}); //BuildTask + depMap.put(new Integer[] {0, 1}, // Allocation task (1st task) + new Integer[] {tasks.size() - 1, tasks.size()}); // BuildTask // UpdateOutputColTask, that sets the starting offsets of the DC columns, // depends on the Build completion tasks - depMap.put(new Integer[] {-2, -1}, //UpdateOutputColTask (last task) - new Integer[] {tasks.size() - 1, tasks.size()}); //BuildTask + depMap.put(new Integer[] {-2, -1}, // UpdateOutputColTask (last task) + new Integer[] {tasks.size() - 1, tasks.size()}); // BuildTask buildTasks.forEach(t -> t.setPriority(5)); applyOffsetDep = true; } if(hasDC && applyOffsetDep) { // Apply tasks depend on UpdateOutputColTask - depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, //ApplyTask - new Integer[] {-2, -1}); //UpdateOutputColTask (last task) + depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, // ApplyTask + new Integer[] {-2, -1}); // UpdateOutputColTask (last task) applyTAgg = applyTAgg == null ? new ArrayList<>() : applyTAgg; applyTAgg.add(applyTaskWrapper); @@ -269,7 +286,7 @@ public void build(CacheBlock in, int k) { public void build(CacheBlock in, int k, Map equiHeightBinMaxs) { if(hasLegacyEncoder() && !(in instanceof FrameBlock)) throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs"); - if(!_partitionDone) //happens if this method is directly called + if(!_partitionDone) // happens if this method is directly called deriveNumRowPartitions(in, k); if(k > 1) { buildMT(in, k); @@ -300,7 +317,7 @@ private void buildMT(CacheBlock in, int k) { catch(ExecutionException | InterruptedException e) { throw new RuntimeException(e); } - finally{ + finally { pool.shutdown(); } } @@ -312,7 +329,6 @@ public void legacyBuild(FrameBlock in) { _legacyMVImpute.build(in); } - public MatrixBlock apply(CacheBlock in) { return apply(in, 1); } @@ -331,7 +347,7 @@ public MatrixBlock apply(CacheBlock in, int k) { return apply(in, out, 0, k); } - public void updateAllDCEncoders(){ + public void updateAllDCEncoders() { for(ColumnEncoderComposite columnEncoder : _columnEncoders) columnEncoder.updateAllDCEncoders(); } @@ -366,7 +382,7 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k } outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE); if(k > 1) { - if(!_partitionDone) //happens if this method is directly called + if(!_partitionDone) // happens if this method is directly called deriveNumRowPartitions(in, k); applyMT(in, out, outputCol, k); } @@ -411,24 +427,25 @@ private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) { try { if(APPLY_ENCODER_SEPARATE_STAGES) { int offset = outputCol; - for (ColumnEncoderComposite e : _columnEncoders) { + for(ColumnEncoderComposite e : _columnEncoders) { pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset)); offset = getOffset(offset, e); } - } else + } + else pool.submitAllAndWait(getApplyTasks(in, out, outputCol)); } catch(ExecutionException | InterruptedException e) { throw new DMLRuntimeException(e); } - finally{ + finally { pool.shutdown(); } } private void deriveNumRowPartitions(CacheBlock in, int k) { int[] numBlocks = new int[2]; - if (k == 1) { //single-threaded + if(k == 1) { // single-threaded numBlocks[0] = 1; numBlocks[1] = 1; _columnEncoders.forEach(e -> e.setNumPartitions(1, 1)); @@ -436,47 +453,47 @@ private void deriveNumRowPartitions(CacheBlock in, int k) { return; } // Read from global flags. These are set by the unit tests - if (ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN > 0) + if(ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN > 0) numBlocks[0] = ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN; - if (ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN > 0) + if(ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN > 0) numBlocks[1] = ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN; // Read from the config file if set. These overwrite the derived values. - if (numBlocks[0] == 0 && ConfigurationManager.getParallelBuildBlocks() > 0) + if(numBlocks[0] == 0 && ConfigurationManager.getParallelBuildBlocks() > 0) numBlocks[0] = ConfigurationManager.getParallelBuildBlocks(); - if (numBlocks[1] == 0 && ConfigurationManager.getParallelApplyBlocks() > 0) + if(numBlocks[1] == 0 && ConfigurationManager.getParallelApplyBlocks() > 0) numBlocks[1] = ConfigurationManager.getParallelApplyBlocks(); // Else, derive the optimum number of partitions int nRow = in.getNumRows(); - int nThread = OptimizerUtils.getTransformNumThreads(); //VCores - int minNumRows = 16000; //min rows per partition + int nThread = OptimizerUtils.getTransformNumThreads(); // VCores + int minNumRows = 16000; // min rows per partition List recodeEncoders = new ArrayList<>(); // Count #Builds and #Applies (= #Col) int nBuild = 0; - for (ColumnEncoderComposite e : _columnEncoders) - if (e.hasBuild()) { + for(ColumnEncoderComposite e : _columnEncoders) + if(e.hasBuild()) { nBuild++; - if (e.hasEncoder(ColumnEncoderRecode.class)) + if(e.hasEncoder(ColumnEncoderRecode.class)) recodeEncoders.add(e); } int nApply = in.getNumColumns(); // #BuildBlocks = (2 * #PhysicalCores)/#build - if (numBlocks[0] == 0 && nBuild > 0 && nBuild < nThread) - numBlocks[0] = Math.round(((float)nThread)/nBuild); + if(numBlocks[0] == 0 && nBuild > 0 && nBuild < nThread) + numBlocks[0] = Math.round(((float) nThread) / nBuild); // #ApplyBlocks = (4 * #PhysicalCores)/#apply - if (numBlocks[1] == 0 && nApply > 0 && nApply < nThread*2) - numBlocks[1] = Math.round(((float)nThread*2)/nApply); + if(numBlocks[1] == 0 && nApply > 0 && nApply < nThread * 2) + numBlocks[1] = Math.round(((float) nThread * 2) / nApply); // Reduce #blocks if #rows per partition is too small - while (numBlocks[0] > 1 && nRow/numBlocks[0] < minNumRows) + while(numBlocks[0] > 1 && nRow / numBlocks[0] < minNumRows) numBlocks[0]--; - while (numBlocks[1] > 1 && nRow/numBlocks[1] < minNumRows) + while(numBlocks[1] > 1 && nRow / numBlocks[1] < minNumRows) numBlocks[1]--; // Reduce #build blocks for the recoders if all don't fit in memory int rcdNumBuildBlks = numBlocks[0]; - if (numBlocks[0] > 1 && recodeEncoders.size() > 0) { + if(numBlocks[0] > 1 && recodeEncoders.size() > 0) { // Estimate recode map sizes estimateRCMapSize(in, recodeEncoders); // Memory budget for maps = 70% of heap - sizeof(input) @@ -484,7 +501,7 @@ private void deriveNumRowPartitions(CacheBlock in, int k) { // Worst case scenario: all partial maps contain all distinct values (if < #rows) long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders); // Reduce recode build blocks count till they fit in the memory budget - while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) { + while(rcdNumBuildBlks > 1 && totMemOverhead > memBudget) { rcdNumBuildBlks--; totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders); // TODO: Reduce only the ones with large maps @@ -493,18 +510,19 @@ private void deriveNumRowPartitions(CacheBlock in, int k) { // TODO: If still don't fit, serialize the column encoders // Set to 1 if not set by the above logics - for (int i=0; i<2; i++) - if (numBlocks[i] == 0) - numBlocks[i] = 1; //default 1 + for(int i = 0; i < 2; i++) + if(numBlocks[i] == 0) + numBlocks[i] = 1; // default 1 _partitionDone = true; // Materialize the partition counts in the encoders _columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], numBlocks[1])); - if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) { + if(rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) { int rcdNumBlocks = rcdNumBuildBlks; recodeEncoders.forEach(e -> e.setNumPartitions(rcdNumBlocks, numBlocks[1])); } - //System.out.println("Block count = ["+numBlocks[0]+", "+numBlocks[1]+"], Recode block count = "+rcdNumBuildBlks); + // System.out.println("Block count = ["+numBlocks[0]+", "+numBlocks[1]+"], Recode block count = + // "+rcdNumBuildBlks); } private void estimateRCMapSize(CacheBlock in, List rcList) { @@ -537,17 +555,17 @@ private void estimateRCMapSize(CacheBlock in, List rc // Estimate total memory overhead of the partial recode maps of all recoders private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List rcEncoders) { long totMemOverhead = 0; - if (nBuildpart == 1) { + if(nBuildpart == 1) { // Sum the estimated map sizes totMemOverhead = rcEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum(); return totMemOverhead; } // Estimate map size of each partition and sum - for (ColumnEncoderComposite rce : rcEncoders) { - long avgEntrySize = rce.getEstMetaSize()/ rce.getEstNumDistincts(); - int partSize = in.getNumRows()/nBuildpart; - int partNumDist = Math.min(partSize, rce.getEstNumDistincts()); //#distincts not more than #rows - long allMapsSize = partNumDist * avgEntrySize * nBuildpart; //worst-case scenario + for(ColumnEncoderComposite rce : rcEncoders) { + long avgEntrySize = rce.getEstMetaSize() / rce.getEstNumDistincts(); + int partSize = in.getNumRows() / nBuildpart; + int partNumDist = Math.min(partSize, rce.getEstNumDistincts()); // #distincts not more than #rows + long allMapsSize = partNumDist * avgEntrySize * nBuildpart; // worst-case scenario totMemOverhead += allMapsSize; } return totMemOverhead; @@ -556,21 +574,16 @@ private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List input, boolean hasDC, boolean hasWE, int distinctWE) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; if(output.isInSparseFormat()) { - if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR - && MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR) - throw new RuntimeException("Transformapply is only supported for MCSR and CSR output matrix"); - //boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - boolean mcsr = false; //force CSR for transformencode - if (mcsr) { + long nnz = (long) output.getNumRows() * input.getNumColumns(); + if(nnz > Integer.MAX_VALUE) { output.allocateBlock(); SparseBlock block = output.getSparseBlock(); - if (hasDC && OptimizerUtils.getTransformNumThreads()>1) { + if(hasDC && OptimizerUtils.getTransformNumThreads() > 1) { // DC forces a single threaded allocation after the build phase and // before the apply starts. Below code parallelizes sparse allocation. - IntStream.range(0, output.getNumRows()) - .parallel().forEach(r -> { + IntStream.range(0, output.getNumRows()).parallel().forEach(r -> { block.allocate(r, input.getNumColumns()); - ((SparseRowVector)block.get(r)).setSize(input.getNumColumns()); + ((SparseRowVector) block.get(r)).setSize(input.getNumColumns()); }); } else { @@ -581,19 +594,19 @@ private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock // Setting the size here makes it possible to run all sparse apply tasks without any sync // could become problematic if the input is very sparse since we allocate the same size as the input // should be fine in theory ;) - ((SparseRowVector)block.get(r)).setSize(input.getNumColumns()); + ((SparseRowVector) block.get(r)).setSize(input.getNumColumns()); } } } - else { //csr - int size = output.getNumRows() * input.getNumColumns(); + else { // csr + final int size = (int) nnz; SparseBlock csrblock = new SparseBlockCSR(output.getNumRows(), size, size); // Manually fill the row pointers based on nnzs/row (= #cols in the input) - // Not using the set() methods to 1) avoid binary search and shifting, + // Not using the set() methods to 1) avoid binary search and shifting, // 2) reduce thread contentions on the arrays - int[] rptr = ((SparseBlockCSR)csrblock).rowPointers(); - for (int i=0; i } if(DMLScript.STATISTICS) { - LOG.debug("Elapsed time for allocation: "+ ((double) System.nanoTime() - t0) / 1000000 + " ms"); - TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime()-t0); + LOG.debug("Elapsed time for allocation: " + ((double) System.nanoTime() - t0) / 1000000 + " ms"); + TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime() - t0); } } - private void outputMatrixPostProcessing(MatrixBlock output, int k){ + private void outputMatrixPostProcessing(MatrixBlock output, int k) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; - if(output.isInSparseFormat()){ - if (k == 1) + if(output.isInSparseFormat() && containsZeroOut()) { + if(k == 1) outputMatrixPostProcessingSingleThread(output); - else - outputMatrixPostProcessingParallel(output, k); - } - else { - output.recomputeNonZeros(k); + else + outputMatrixPostProcessingParallel(output, k); } - - if(DMLScript.STATISTICS) - TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0); - } - + output.recomputeNonZeros(k); - private void outputMatrixPostProcessingSingleThread(MatrixBlock output){ - Set indexSet = _columnEncoders.stream() - .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet()); + if(output.getNonZeros() < 0) + throw new DMLRuntimeException( + "Invalid assigned non zeros of transform encode output: " + output.getNonZeros()); - if(!indexSet.stream().allMatch(Objects::isNull)) { - for(Integer row : indexSet) - output.getSparseBlock().get(row).compact(); - } + if(DMLScript.STATISTICS) + TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime() - t0); + } - output.recomputeNonZeros(); + private void outputMatrixPostProcessingSingleThread(MatrixBlock output) { + final SparseBlock sb = output.getSparseBlock(); + if(sb instanceof SparseBlockMCSR) { + IntStream.range(0, output.getNumRows()).forEach(row -> { + sb.compact(row); + }); + } + else { + ((SparseBlockCSR) sb).compact(); + } } + private boolean containsZeroOut() { + for(ColumnEncoder e : _columnEncoders) + if(e.containsZeroOut()) + return true; + return false; + } private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) { - ExecutorService myPool = CommonThreadPool.get(k); + final ExecutorService myPool = CommonThreadPool.get(k); try { - // Collect the row indices that need compaction - Set indexSet = myPool.submit(() -> _columnEncoders.stream().parallel() - .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet())).get(); - - // Check if the set is empty - boolean emptySet = myPool.submit(() -> indexSet.stream().parallel().allMatch(Objects::isNull)).get(); - - // Concurrently compact the rows - if(emptySet) { - myPool.submit(() -> { - indexSet.stream().parallel().forEach(row -> { - output.getSparseBlock().get(row).compact(); - }); - }).get(); - } + final SparseBlock sb = output.getSparseBlock(); + if(sb instanceof SparseBlockMCSR) { + myPool.submit(() -> { + IntStream.range(0, output.getNumRows()).parallel().forEach(row -> { + sb.compact(row); + }); + }).get(); + } + else { + ((SparseBlockCSR) sb).compact(); + } } catch(Exception ex) { throw new DMLRuntimeException(ex); @@ -676,8 +684,6 @@ private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) { finally { myPool.shutdown(); } - - output.recomputeNonZeros(); } @Override @@ -699,20 +705,20 @@ public FrameBlock getMetaData(FrameBlock meta, int k) { if(meta == null) meta = new FrameBlock(_columnEncoders.size(), ValueType.STRING); this.allocateMetaData(meta); - if (k > 1) { + if(k > 1) { ExecutorService pool = CommonThreadPool.get(k); try { ArrayList> tasks = new ArrayList<>(); for(ColumnEncoder columnEncoder : _columnEncoders) tasks.add(new ColumnMetaDataTask<>(columnEncoder, meta)); List> taskret = pool.invokeAll(tasks); - for (Future task : taskret) - task.get(); + for(Future task : taskret) + task.get(); } catch(Exception ex) { throw new DMLRuntimeException(ex); } - finally{ + finally { pool.shutdown(); } } @@ -721,13 +727,13 @@ public FrameBlock getMetaData(FrameBlock meta, int k) { columnEncoder.getMetaData(meta); } - //_columnEncoders.stream().parallel().forEach(columnEncoder -> - // columnEncoder.getMetaData(meta)); + // _columnEncoders.stream().parallel().forEach(columnEncoder -> + // columnEncoder.getMetaData(meta)); if(_legacyOmit != null) _legacyOmit.getMetaData(meta); if(_legacyMVImpute != null) _legacyMVImpute.getMetaData(meta); - LOG.debug("Time spent getting metadata "+((double) System.nanoTime() - t0) / 1000000 + " ms"); + LOG.debug("Time spent getting metadata " + ((double) System.nanoTime() - t0) / 1000000 + " ms"); return meta; } @@ -741,7 +747,7 @@ public void initMetaData(FrameBlock meta) { _legacyMVImpute.initMetaData(meta); } - //pass down init to composite encoders + // pass down init to composite encoders public void initEmbeddings(MatrixBlock embeddings) { for(ColumnEncoder columnEncoder : _columnEncoders) columnEncoder.initEmbeddings(embeddings); @@ -906,7 +912,7 @@ public List> getEncoderTypes() { return getEncoderTypes(-1); } - public int getEstNNzRow(){ + public int getEstNNzRow() { int nnz = 0; for(int i = 0; i < _columnEncoders.size(); i++) nnz += _columnEncoders.get(i).getDomainSize(); @@ -964,8 +970,7 @@ public MultiColumnEncoder subRangeEncoder(IndexRange i return new MultiColumnEncoder( encoders.stream().map(e -> ((ColumnEncoderComposite) e)).collect(Collectors.toList())); else - return new MultiColumnEncoder( - encoders.stream().map(ColumnEncoderComposite::new).collect(Collectors.toList())); + return new MultiColumnEncoder(encoders.stream().map(ColumnEncoderComposite::new).collect(Collectors.toList())); } public void mergeReplace(MultiColumnEncoder multiEncoder) { @@ -1053,7 +1058,7 @@ public boolean hasLegacyEncoder() { return hasLegacyEncoder(EncoderMVImpute.class) || hasLegacyEncoder(EncoderOmit.class); } - public boolean isCompressedTransformEncode(CacheBlock in, boolean enabled){ + public boolean isCompressedTransformEncode(CacheBlock in, boolean enabled) { return (enabled || ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_TRANSFORMENCODE)) && in instanceof FrameBlock && _colOffset == 0; } @@ -1185,10 +1190,10 @@ private static class ApplyTasksWrapperTask extends DependencyWrapperTask private final MatrixBlock _out; private final CacheBlock _in; /** Offset because of dummmy coding such that the column id is correct. */ - private int _offset = -1; + private int _offset = -1; - private ApplyTasksWrapperTask(ColumnEncoder encoder, CacheBlock in, - MatrixBlock out, DependencyThreadPool pool) { + private ApplyTasksWrapperTask(ColumnEncoder encoder, CacheBlock in, MatrixBlock out, + DependencyThreadPool pool) { super(pool); _encoder = encoder; _out = out; @@ -1263,8 +1268,8 @@ public Object call() throws Exception { private static class AllocMetaTask implements Callable { private final MultiColumnEncoder _encoder; private final FrameBlock _meta; - - private AllocMetaTask (MultiColumnEncoder encoder, FrameBlock meta) { + + private AllocMetaTask(MultiColumnEncoder encoder, FrameBlock meta) { _encoder = encoder; _meta = meta; } @@ -1280,7 +1285,7 @@ public String toString() { return getClass().getSimpleName(); } } - + private static class ColumnMetaDataTask implements Callable { private final T _colEncoder; private final FrameBlock _out; diff --git a/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java index 6b57bc5a616..af5fd594f98 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java @@ -31,7 +31,11 @@ import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + public class CollectionUtils { + final static Log LOG = LogFactory.getLog(CollectionUtils.class.getName()); @SafeVarargs public static List asList(List... inputs) { @@ -122,7 +126,7 @@ public static boolean intersect(Collection... inputs) { //if the item is in the seen set, return true if (probe.contains(item)) return true; - probe.addAll(inputs[i]); + probe.addAll(nonEmpty[i]); } return false; } diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index 86e7bde452b..87ed1ad68f2 100644 --- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -1120,27 +1120,27 @@ public static String toString(FrameBlock fb, boolean sparse, String separator, S colLength = colsToPrint < clen ? colsToPrint : clen; //print frame header - sb.append("# FRAME: "); - sb.append("nrow = " + fb.getNumRows() + ", "); - sb.append("ncol = " + fb.getNumColumns() + lineseparator); + // sb.append("# FRAME: "); + // sb.append("nrow = " + fb.getNumRows() + ", "); + // sb.append("ncol = " + fb.getNumColumns() + lineseparator); //print column names - sb.append("#"); sb.append(separator); - for( int j=0; j len){ + // block until buffer is free to use + _locks[_pos].get(); + System.arraycopy(b, off, b_pos, 0, len); + // submit write request + _locks[_pos] = _pool.submit(new WriteTask(b_pos, len)); + // copy for asynchronous write because b is reused higher up + _pos = (_pos+1) % _buff.length; + } + else{ + // we already have the byte array in hand, but we do not do it async + // since there would be no guarantee that the caller does not modify the array. + writeBuffer(b, off, len); + } } } catch(Exception ex) { diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index b46792da029..53c0caec53e 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -842,10 +842,10 @@ public static boolean isNonZero(Object obj) { } } - public static int computeNnz(double[] a, int ai, int len) { + public static final int computeNnz(final double[] a,final int ai,final int len) { int lnnz = 0; for( int i=ai; i + * Thus, the powers of two are not a concern since they can be represented exactly using the binary notation, only + * the powers of five affect the binary significand. + *

+ * The mantissas of powers of ten from -308 to 308, extended out to sixty-four bits. The array contains the powers of + * ten approximated as a 64-bit mantissa. It goes from 10^{@value #DOUBLE_MIN_EXPONENT_POWER_OF_TEN} to + * 10^{@value #DOUBLE_MAX_EXPONENT_POWER_OF_TEN} (inclusively). The mantissa is truncated, and never rounded up. Uses + * about 5 KB. + *

+ * + *

+	 * long getMantissaHigh(int q) {
+	 *  MANTISSA_64[q - SMALLEST_POWER_OF_TEN];
+	 * }
+	 * 
+ */ + static final long[] MANTISSA_64 = {0xa5ced43b7e3e9188L, 0xcf42894a5dce35eaL, 0x818995ce7aa0e1b2L, + 0xa1ebfb4219491a1fL, 0xca66fa129f9b60a6L, 0xfd00b897478238d0L, 0x9e20735e8cb16382L, 0xc5a890362fddbc62L, + 0xf712b443bbd52b7bL, 0x9a6bb0aa55653b2dL, 0xc1069cd4eabe89f8L, 0xf148440a256e2c76L, 0x96cd2a865764dbcaL, + 0xbc807527ed3e12bcL, 0xeba09271e88d976bL, 0x93445b8731587ea3L, 0xb8157268fdae9e4cL, 0xe61acf033d1a45dfL, + 0x8fd0c16206306babL, 0xb3c4f1ba87bc8696L, 0xe0b62e2929aba83cL, 0x8c71dcd9ba0b4925L, 0xaf8e5410288e1b6fL, + 0xdb71e91432b1a24aL, 0x892731ac9faf056eL, 0xab70fe17c79ac6caL, 0xd64d3d9db981787dL, 0x85f0468293f0eb4eL, + 0xa76c582338ed2621L, 0xd1476e2c07286faaL, 0x82cca4db847945caL, 0xa37fce126597973cL, 0xcc5fc196fefd7d0cL, + 0xff77b1fcbebcdc4fL, 0x9faacf3df73609b1L, 0xc795830d75038c1dL, 0xf97ae3d0d2446f25L, 0x9becce62836ac577L, + 0xc2e801fb244576d5L, 0xf3a20279ed56d48aL, 0x9845418c345644d6L, 0xbe5691ef416bd60cL, 0xedec366b11c6cb8fL, + 0x94b3a202eb1c3f39L, 0xb9e08a83a5e34f07L, 0xe858ad248f5c22c9L, 0x91376c36d99995beL, 0xb58547448ffffb2dL, + 0xe2e69915b3fff9f9L, 0x8dd01fad907ffc3bL, 0xb1442798f49ffb4aL, 0xdd95317f31c7fa1dL, 0x8a7d3eef7f1cfc52L, + 0xad1c8eab5ee43b66L, 0xd863b256369d4a40L, 0x873e4f75e2224e68L, 0xa90de3535aaae202L, 0xd3515c2831559a83L, + 0x8412d9991ed58091L, 0xa5178fff668ae0b6L, 0xce5d73ff402d98e3L, 0x80fa687f881c7f8eL, 0xa139029f6a239f72L, + 0xc987434744ac874eL, 0xfbe9141915d7a922L, 0x9d71ac8fada6c9b5L, 0xc4ce17b399107c22L, 0xf6019da07f549b2bL, + 0x99c102844f94e0fbL, 0xc0314325637a1939L, 0xf03d93eebc589f88L, 0x96267c7535b763b5L, 0xbbb01b9283253ca2L, + 0xea9c227723ee8bcbL, 0x92a1958a7675175fL, 0xb749faed14125d36L, 0xe51c79a85916f484L, 0x8f31cc0937ae58d2L, + 0xb2fe3f0b8599ef07L, 0xdfbdcece67006ac9L, 0x8bd6a141006042bdL, 0xaecc49914078536dL, 0xda7f5bf590966848L, + 0x888f99797a5e012dL, 0xaab37fd7d8f58178L, 0xd5605fcdcf32e1d6L, 0x855c3be0a17fcd26L, 0xa6b34ad8c9dfc06fL, + 0xd0601d8efc57b08bL, 0x823c12795db6ce57L, 0xa2cb1717b52481edL, 0xcb7ddcdda26da268L, 0xfe5d54150b090b02L, + 0x9efa548d26e5a6e1L, 0xc6b8e9b0709f109aL, 0xf867241c8cc6d4c0L, 0x9b407691d7fc44f8L, 0xc21094364dfb5636L, + 0xf294b943e17a2bc4L, 0x979cf3ca6cec5b5aL, 0xbd8430bd08277231L, 0xece53cec4a314ebdL, 0x940f4613ae5ed136L, + 0xb913179899f68584L, 0xe757dd7ec07426e5L, 0x9096ea6f3848984fL, 0xb4bca50b065abe63L, 0xe1ebce4dc7f16dfbL, + 0x8d3360f09cf6e4bdL, 0xb080392cc4349decL, 0xdca04777f541c567L, 0x89e42caaf9491b60L, 0xac5d37d5b79b6239L, + 0xd77485cb25823ac7L, 0x86a8d39ef77164bcL, 0xa8530886b54dbdebL, 0xd267caa862a12d66L, 0x8380dea93da4bc60L, + 0xa46116538d0deb78L, 0xcd795be870516656L, 0x806bd9714632dff6L, 0xa086cfcd97bf97f3L, 0xc8a883c0fdaf7df0L, + 0xfad2a4b13d1b5d6cL, 0x9cc3a6eec6311a63L, 0xc3f490aa77bd60fcL, 0xf4f1b4d515acb93bL, 0x991711052d8bf3c5L, + 0xbf5cd54678eef0b6L, 0xef340a98172aace4L, 0x9580869f0e7aac0eL, 0xbae0a846d2195712L, 0xe998d258869facd7L, + 0x91ff83775423cc06L, 0xb67f6455292cbf08L, 0xe41f3d6a7377eecaL, 0x8e938662882af53eL, 0xb23867fb2a35b28dL, + 0xdec681f9f4c31f31L, 0x8b3c113c38f9f37eL, 0xae0b158b4738705eL, 0xd98ddaee19068c76L, 0x87f8a8d4cfa417c9L, + 0xa9f6d30a038d1dbcL, 0xd47487cc8470652bL, 0x84c8d4dfd2c63f3bL, 0xa5fb0a17c777cf09L, 0xcf79cc9db955c2ccL, + 0x81ac1fe293d599bfL, 0xa21727db38cb002fL, 0xca9cf1d206fdc03bL, 0xfd442e4688bd304aL, 0x9e4a9cec15763e2eL, + 0xc5dd44271ad3cdbaL, 0xf7549530e188c128L, 0x9a94dd3e8cf578b9L, 0xc13a148e3032d6e7L, 0xf18899b1bc3f8ca1L, + 0x96f5600f15a7b7e5L, 0xbcb2b812db11a5deL, 0xebdf661791d60f56L, 0x936b9fcebb25c995L, 0xb84687c269ef3bfbL, + 0xe65829b3046b0afaL, 0x8ff71a0fe2c2e6dcL, 0xb3f4e093db73a093L, 0xe0f218b8d25088b8L, 0x8c974f7383725573L, + 0xafbd2350644eeacfL, 0xdbac6c247d62a583L, 0x894bc396ce5da772L, 0xab9eb47c81f5114fL, 0xd686619ba27255a2L, + 0x8613fd0145877585L, 0xa798fc4196e952e7L, 0xd17f3b51fca3a7a0L, 0x82ef85133de648c4L, 0xa3ab66580d5fdaf5L, + 0xcc963fee10b7d1b3L, 0xffbbcfe994e5c61fL, 0x9fd561f1fd0f9bd3L, 0xc7caba6e7c5382c8L, 0xf9bd690a1b68637bL, + 0x9c1661a651213e2dL, 0xc31bfa0fe5698db8L, 0xf3e2f893dec3f126L, 0x986ddb5c6b3a76b7L, 0xbe89523386091465L, + 0xee2ba6c0678b597fL, 0x94db483840b717efL, 0xba121a4650e4ddebL, 0xe896a0d7e51e1566L, 0x915e2486ef32cd60L, + 0xb5b5ada8aaff80b8L, 0xe3231912d5bf60e6L, 0x8df5efabc5979c8fL, 0xb1736b96b6fd83b3L, 0xddd0467c64bce4a0L, + 0x8aa22c0dbef60ee4L, 0xad4ab7112eb3929dL, 0xd89d64d57a607744L, 0x87625f056c7c4a8bL, 0xa93af6c6c79b5d2dL, + 0xd389b47879823479L, 0x843610cb4bf160cbL, 0xa54394fe1eedb8feL, 0xce947a3da6a9273eL, 0x811ccc668829b887L, + 0xa163ff802a3426a8L, 0xc9bcff6034c13052L, 0xfc2c3f3841f17c67L, 0x9d9ba7832936edc0L, 0xc5029163f384a931L, + 0xf64335bcf065d37dL, 0x99ea0196163fa42eL, 0xc06481fb9bcf8d39L, 0xf07da27a82c37088L, 0x964e858c91ba2655L, + 0xbbe226efb628afeaL, 0xeadab0aba3b2dbe5L, 0x92c8ae6b464fc96fL, 0xb77ada0617e3bbcbL, 0xe55990879ddcaabdL, + 0x8f57fa54c2a9eab6L, 0xb32df8e9f3546564L, 0xdff9772470297ebdL, 0x8bfbea76c619ef36L, 0xaefae51477a06b03L, + 0xdab99e59958885c4L, 0x88b402f7fd75539bL, 0xaae103b5fcd2a881L, 0xd59944a37c0752a2L, 0x857fcae62d8493a5L, + 0xa6dfbd9fb8e5b88eL, 0xd097ad07a71f26b2L, 0x825ecc24c873782fL, 0xa2f67f2dfa90563bL, 0xcbb41ef979346bcaL, + 0xfea126b7d78186bcL, 0x9f24b832e6b0f436L, 0xc6ede63fa05d3143L, 0xf8a95fcf88747d94L, 0x9b69dbe1b548ce7cL, + 0xc24452da229b021bL, 0xf2d56790ab41c2a2L, 0x97c560ba6b0919a5L, 0xbdb6b8e905cb600fL, 0xed246723473e3813L, + 0x9436c0760c86e30bL, 0xb94470938fa89bceL, 0xe7958cb87392c2c2L, 0x90bd77f3483bb9b9L, 0xb4ecd5f01a4aa828L, + 0xe2280b6c20dd5232L, 0x8d590723948a535fL, 0xb0af48ec79ace837L, 0xdcdb1b2798182244L, 0x8a08f0f8bf0f156bL, + 0xac8b2d36eed2dac5L, 0xd7adf884aa879177L, 0x86ccbb52ea94baeaL, 0xa87fea27a539e9a5L, 0xd29fe4b18e88640eL, + 0x83a3eeeef9153e89L, 0xa48ceaaab75a8e2bL, 0xcdb02555653131b6L, 0x808e17555f3ebf11L, 0xa0b19d2ab70e6ed6L, + 0xc8de047564d20a8bL, 0xfb158592be068d2eL, 0x9ced737bb6c4183dL, 0xc428d05aa4751e4cL, 0xf53304714d9265dfL, + 0x993fe2c6d07b7fabL, 0xbf8fdb78849a5f96L, 0xef73d256a5c0f77cL, 0x95a8637627989aadL, 0xbb127c53b17ec159L, + 0xe9d71b689dde71afL, 0x9226712162ab070dL, 0xb6b00d69bb55c8d1L, 0xe45c10c42a2b3b05L, 0x8eb98a7a9a5b04e3L, + 0xb267ed1940f1c61cL, 0xdf01e85f912e37a3L, 0x8b61313bbabce2c6L, 0xae397d8aa96c1b77L, 0xd9c7dced53c72255L, + 0x881cea14545c7575L, 0xaa242499697392d2L, 0xd4ad2dbfc3d07787L, 0x84ec3c97da624ab4L, 0xa6274bbdd0fadd61L, + 0xcfb11ead453994baL, 0x81ceb32c4b43fcf4L, 0xa2425ff75e14fc31L, 0xcad2f7f5359a3b3eL, 0xfd87b5f28300ca0dL, + 0x9e74d1b791e07e48L, 0xc612062576589ddaL, 0xf79687aed3eec551L, 0x9abe14cd44753b52L, 0xc16d9a0095928a27L, + 0xf1c90080baf72cb1L, 0x971da05074da7beeL, 0xbce5086492111aeaL, 0xec1e4a7db69561a5L, 0x9392ee8e921d5d07L, + 0xb877aa3236a4b449L, 0xe69594bec44de15bL, 0x901d7cf73ab0acd9L, 0xb424dc35095cd80fL, 0xe12e13424bb40e13L, + 0x8cbccc096f5088cbL, 0xafebff0bcb24aafeL, 0xdbe6fecebdedd5beL, 0x89705f4136b4a597L, 0xabcc77118461cefcL, + 0xd6bf94d5e57a42bcL, 0x8637bd05af6c69b5L, 0xa7c5ac471b478423L, 0xd1b71758e219652bL, 0x83126e978d4fdf3bL, + 0xa3d70a3d70a3d70aL, 0xccccccccccccccccL, 0x8000000000000000L, 0xa000000000000000L, 0xc800000000000000L, + 0xfa00000000000000L, 0x9c40000000000000L, 0xc350000000000000L, 0xf424000000000000L, 0x9896800000000000L, + 0xbebc200000000000L, 0xee6b280000000000L, 0x9502f90000000000L, 0xba43b74000000000L, 0xe8d4a51000000000L, + 0x9184e72a00000000L, 0xb5e620f480000000L, 0xe35fa931a0000000L, 0x8e1bc9bf04000000L, 0xb1a2bc2ec5000000L, + 0xde0b6b3a76400000L, 0x8ac7230489e80000L, 0xad78ebc5ac620000L, 0xd8d726b7177a8000L, 0x878678326eac9000L, + 0xa968163f0a57b400L, 0xd3c21bcecceda100L, 0x84595161401484a0L, 0xa56fa5b99019a5c8L, 0xcecb8f27f4200f3aL, + 0x813f3978f8940984L, 0xa18f07d736b90be5L, 0xc9f2c9cd04674edeL, 0xfc6f7c4045812296L, 0x9dc5ada82b70b59dL, + 0xc5371912364ce305L, 0xf684df56c3e01bc6L, 0x9a130b963a6c115cL, 0xc097ce7bc90715b3L, 0xf0bdc21abb48db20L, + 0x96769950b50d88f4L, 0xbc143fa4e250eb31L, 0xeb194f8e1ae525fdL, 0x92efd1b8d0cf37beL, 0xb7abc627050305adL, + 0xe596b7b0c643c719L, 0x8f7e32ce7bea5c6fL, 0xb35dbf821ae4f38bL, 0xe0352f62a19e306eL, 0x8c213d9da502de45L, + 0xaf298d050e4395d6L, 0xdaf3f04651d47b4cL, 0x88d8762bf324cd0fL, 0xab0e93b6efee0053L, 0xd5d238a4abe98068L, + 0x85a36366eb71f041L, 0xa70c3c40a64e6c51L, 0xd0cf4b50cfe20765L, 0x82818f1281ed449fL, 0xa321f2d7226895c7L, + 0xcbea6f8ceb02bb39L, 0xfee50b7025c36a08L, 0x9f4f2726179a2245L, 0xc722f0ef9d80aad6L, 0xf8ebad2b84e0d58bL, + 0x9b934c3b330c8577L, 0xc2781f49ffcfa6d5L, 0xf316271c7fc3908aL, 0x97edd871cfda3a56L, 0xbde94e8e43d0c8ecL, + 0xed63a231d4c4fb27L, 0x945e455f24fb1cf8L, 0xb975d6b6ee39e436L, 0xe7d34c64a9c85d44L, 0x90e40fbeea1d3a4aL, + 0xb51d13aea4a488ddL, 0xe264589a4dcdab14L, 0x8d7eb76070a08aecL, 0xb0de65388cc8ada8L, 0xdd15fe86affad912L, + 0x8a2dbf142dfcc7abL, 0xacb92ed9397bf996L, 0xd7e77a8f87daf7fbL, 0x86f0ac99b4e8dafdL, 0xa8acd7c0222311bcL, + 0xd2d80db02aabd62bL, 0x83c7088e1aab65dbL, 0xa4b8cab1a1563f52L, 0xcde6fd5e09abcf26L, 0x80b05e5ac60b6178L, + 0xa0dc75f1778e39d6L, 0xc913936dd571c84cL, 0xfb5878494ace3a5fL, 0x9d174b2dcec0e47bL, 0xc45d1df942711d9aL, + 0xf5746577930d6500L, 0x9968bf6abbe85f20L, 0xbfc2ef456ae276e8L, 0xefb3ab16c59b14a2L, 0x95d04aee3b80ece5L, + 0xbb445da9ca61281fL, 0xea1575143cf97226L, 0x924d692ca61be758L, 0xb6e0c377cfa2e12eL, 0xe498f455c38b997aL, + 0x8edf98b59a373fecL, 0xb2977ee300c50fe7L, 0xdf3d5e9bc0f653e1L, 0x8b865b215899f46cL, 0xae67f1e9aec07187L, + 0xda01ee641a708de9L, 0x884134fe908658b2L, 0xaa51823e34a7eedeL, 0xd4e5e2cdc1d1ea96L, 0x850fadc09923329eL, + 0xa6539930bf6bff45L, 0xcfe87f7cef46ff16L, 0x81f14fae158c5f6eL, 0xa26da3999aef7749L, 0xcb090c8001ab551cL, + 0xfdcb4fa002162a63L, 0x9e9f11c4014dda7eL, 0xc646d63501a1511dL, 0xf7d88bc24209a565L, 0x9ae757596946075fL, + 0xc1a12d2fc3978937L, 0xf209787bb47d6b84L, 0x9745eb4d50ce6332L, 0xbd176620a501fbffL, 0xec5d3fa8ce427affL, + 0x93ba47c980e98cdfL, 0xb8a8d9bbe123f017L, 0xe6d3102ad96cec1dL, 0x9043ea1ac7e41392L, 0xb454e4a179dd1877L, + 0xe16a1dc9d8545e94L, 0x8ce2529e2734bb1dL, 0xb01ae745b101e9e4L, 0xdc21a1171d42645dL, 0x899504ae72497ebaL, + 0xabfa45da0edbde69L, 0xd6f8d7509292d603L, 0x865b86925b9bc5c2L, 0xa7f26836f282b732L, 0xd1ef0244af2364ffL, + 0x8335616aed761f1fL, 0xa402b9c5a8d3a6e7L, 0xcd036837130890a1L, 0x802221226be55a64L, 0xa02aa96b06deb0fdL, + 0xc83553c5c8965d3dL, 0xfa42a8b73abbf48cL, 0x9c69a97284b578d7L, 0xc38413cf25e2d70dL, 0xf46518c2ef5b8cd1L, + 0x98bf2f79d5993802L, 0xbeeefb584aff8603L, 0xeeaaba2e5dbf6784L, 0x952ab45cfa97a0b2L, 0xba756174393d88dfL, + 0xe912b9d1478ceb17L, 0x91abb422ccb812eeL, 0xb616a12b7fe617aaL, 0xe39c49765fdf9d94L, 0x8e41ade9fbebc27dL, + 0xb1d219647ae6b31cL, 0xde469fbd99a05fe3L, 0x8aec23d680043beeL, 0xada72ccc20054ae9L, 0xd910f7ff28069da4L, + 0x87aa9aff79042286L, 0xa99541bf57452b28L, 0xd3fa922f2d1675f2L, 0x847c9b5d7c2e09b7L, 0xa59bc234db398c25L, + 0xcf02b2c21207ef2eL, 0x8161afb94b44f57dL, 0xa1ba1ba79e1632dcL, 0xca28a291859bbf93L, 0xfcb2cb35e702af78L, + 0x9defbf01b061adabL, 0xc56baec21c7a1916L, 0xf6c69a72a3989f5bL, 0x9a3c2087a63f6399L, 0xc0cb28a98fcf3c7fL, + 0xf0fdf2d3f3c30b9fL, 0x969eb7c47859e743L, 0xbc4665b596706114L, 0xeb57ff22fc0c7959L, 0x9316ff75dd87cbd8L, + 0xb7dcbf5354e9beceL, 0xe5d3ef282a242e81L, 0x8fa475791a569d10L, 0xb38d92d760ec4455L, 0xe070f78d3927556aL, + 0x8c469ab843b89562L, 0xaf58416654a6babbL, 0xdb2e51bfe9d0696aL, 0x88fcf317f22241e2L, 0xab3c2fddeeaad25aL, + 0xd60b3bd56a5586f1L, 0x85c7056562757456L, 0xa738c6bebb12d16cL, 0xd106f86e69d785c7L, 0x82a45b450226b39cL, + 0xa34d721642b06084L, 0xcc20ce9bd35c78a5L, 0xff290242c83396ceL, 0x9f79a169bd203e41L, 0xc75809c42c684dd1L, + 0xf92e0c3537826145L, 0x9bbcc7a142b17ccbL, 0xc2abf989935ddbfeL, 0xf356f7ebf83552feL, 0x98165af37b2153deL, + 0xbe1bf1b059e9a8d6L, 0xeda2ee1c7064130cL, 0x9485d4d1c63e8be7L, 0xb9a74a0637ce2ee1L, 0xe8111c87c5c1ba99L, + 0x910ab1d4db9914a0L, 0xb54d5e4a127f59c8L, 0xe2a0b5dc971f303aL, 0x8da471a9de737e24L, 0xb10d8e1456105dadL, + 0xdd50f1996b947518L, 0x8a5296ffe33cc92fL, 0xace73cbfdc0bfb7bL, 0xd8210befd30efa5aL, 0x8714a775e3e95c78L, + 0xa8d9d1535ce3b396L, 0xd31045a8341ca07cL, 0x83ea2b892091e44dL, 0xa4e4b66b68b65d60L, 0xce1de40642e3f4b9L, + 0x80d2ae83e9ce78f3L, 0xa1075a24e4421730L, 0xc94930ae1d529cfcL, 0xfb9b7cd9a4a7443cL, 0x9d412e0806e88aa5L, + 0xc491798a08a2ad4eL, 0xf5b5d7ec8acb58a2L, 0x9991a6f3d6bf1765L, 0xbff610b0cc6edd3fL, 0xeff394dcff8a948eL, + 0x95f83d0a1fb69cd9L, 0xbb764c4ca7a4440fL, 0xea53df5fd18d5513L, 0x92746b9be2f8552cL, 0xb7118682dbb66a77L, + 0xe4d5e82392a40515L, 0x8f05b1163ba6832dL, 0xb2c71d5bca9023f8L, 0xdf78e4b2bd342cf6L, 0x8bab8eefb6409c1aL, + 0xae9672aba3d0c320L, 0xda3c0f568cc4f3e8L, 0x8865899617fb1871L, 0xaa7eebfb9df9de8dL, 0xd51ea6fa85785631L, + 0x8533285c936b35deL, 0xa67ff273b8460356L, 0xd01fef10a657842cL, 0x8213f56a67f6b29bL, 0xa298f2c501f45f42L, + 0xcb3f2f7642717713L, 0xfe0efb53d30dd4d7L, 0x9ec95d1463e8a506L, 0xc67bb4597ce2ce48L, 0xf81aa16fdc1b81daL, + 0x9b10a4e5e9913128L, 0xc1d4ce1f63f57d72L, 0xf24a01a73cf2dccfL, 0x976e41088617ca01L, 0xbd49d14aa79dbc82L, + 0xec9c459d51852ba2L, 0x93e1ab8252f33b45L, 0xb8da1662e7b00a17L, 0xe7109bfba19c0c9dL, 0x906a617d450187e2L, + 0xb484f9dc9641e9daL, 0xe1a63853bbd26451L, 0x8d07e33455637eb2L, 0xb049dc016abc5e5fL, 0xdc5c5301c56b75f7L, + 0x89b9b3e11b6329baL, 0xac2820d9623bf429L, 0xd732290fbacaf133L, 0x867f59a9d4bed6c0L, 0xa81f301449ee8c70L, + 0xd226fc195c6a2f8cL, 0x83585d8fd9c25db7L, 0xa42e74f3d032f525L, 0xcd3a1230c43fb26fL, 0x80444b5e7aa7cf85L, + 0xa0555e361951c366L, 0xc86ab5c39fa63440L, 0xfa856334878fc150L, 0x9c935e00d4b9d8d2L, 0xc3b8358109e84f07L, + 0xf4a642e14c6262c8L, 0x98e7e9cccfbd7dbdL, 0xbf21e44003acdd2cL, 0xeeea5d5004981478L, 0x95527a5202df0ccbL, + 0xbaa718e68396cffdL, 0xe950df20247c83fdL, 0x91d28b7416cdd27eL, 0xb6472e511c81471dL, 0xe3d8f9e563a198e5L, + 0x8e679c2f5e44ff8fL}; + + public static double parseFloatingPointLiteral(String str, int offset, int endIndex) { + if(endIndex > 100) + return Double.parseDouble(str); + // Skip leading whitespace + int index = skipWhitespace(str, offset, endIndex); + char ch = str.charAt(index); + + // Parse optional sign + final boolean isNegative = ch == '-'; + if(isNegative || ch == '+') { + ch = charAt(str, ++index, endIndex); + } + + // Parse NaN or Infinity (this occurs rarely) + if(ch >= 'I') + return Double.parseDouble(str); + else if(str.charAt(endIndex - 1) >= 'a') + return Double.parseDouble(str); + + final double val = parseDecFloatLiteral(str, index, offset, endIndex); + if(Double.isNaN(val)) + return Double.parseDouble(str); + return isNegative ? -val : val; + } + + private static void illegal() { + throw new NumberFormatException("illegal syntax"); + } + + private static int inc(int significand, char ch) { + return 10 * significand + ch - '0'; + } + + private static long inc(long significand, char ch) { + return 10 * significand + ch - '0'; + } + + private static double parseDecFloatLiteral(String str, int index, int startIndex, int endIndex) { + + long significand = 0; + final int significandStartIndex = index; + int virtualIndexOfPoint = -1; + char ch = 0; + for(; index < endIndex; index++) { + ch = str.charAt(index); + if(isDigit(ch)) { + // This might overflow, we deal with it later. + significand = inc(significand, ch); + } + else if(ch == '.') { + if(virtualIndexOfPoint >= 0) + illegal(); + virtualIndexOfPoint = index; + } + else if((ch | 0x20) == 'e') { + break; // case of e + } + else + illegal(); + } + + final int digitCount; + final int significandEndIndex = index; + int exponent; + if(virtualIndexOfPoint < 0) { + digitCount = significandEndIndex - significandStartIndex; + virtualIndexOfPoint = significandEndIndex; + exponent = 0; + } + else { + digitCount = significandEndIndex - significandStartIndex - 1; + exponent = virtualIndexOfPoint - significandEndIndex + 1; + } + + if((ch | 0x20) == 'e') + return handleExponent(str, index, startIndex, endIndex, significandStartIndex, significandEndIndex, + virtualIndexOfPoint, significand, ch, exponent, digitCount); + else + return handleOverflow(str, index, startIndex, endIndex, digitCount, significand, significandStartIndex, + significandEndIndex, exponent, 0, virtualIndexOfPoint); + + } + + private static Double handleExponent(String str, int index, int startIndex, int endIndex, int sigStart, int sigEnd, + final int virtualIndexOfPoint, final long significand, char ch, int exponent, int digitCount) { + + // Parse exponent number + // --------------------- + int expNumber = 0; + ch = charAt(str, ++index, endIndex); + final boolean isExponentNegative = ch == '-'; + if(isExponentNegative || ch == '+') { + ch = charAt(str, ++index, endIndex); + } + if(!isDigit(ch)) + illegal(); + do { + // Guard against overflow + if(expNumber < MAX_EXPONENT_NUMBER) { + expNumber = inc(expNumber, ch); + } + ch = charAt(str, ++index, endIndex); + } + while(isDigit(ch)); + if(isExponentNegative) { + expNumber = -expNumber; + } + exponent += expNumber; + + return handleOverflow(str, index, startIndex, endIndex, digitCount, significand, sigStart, sigEnd, exponent, + expNumber, virtualIndexOfPoint); + } + + private static double handleOverflow(String str, int index, int startIndex, int endIndex, int digitCount, + long significand, int sigStart, int sigEnd, int exponent, int expNumber, int virtualIndexOfPoint) { + if(digitCount > 19) { + char ch; + int skipCountInTruncatedDigits = 0;// counts +1 if we skipped over the decimal point + significand = 0; + for(index = sigStart; index < sigEnd; index++) { + ch = str.charAt(index); + if(ch == '.') { + skipCountInTruncatedDigits++; + } + else { + if(Long.compareUnsigned(significand, MINIMAL_NINETEEN_DIGIT_INTEGER) < 0) { + significand = inc(significand, ch); + } + else { + break; + } + } + } + final boolean isSignificandTruncated = index < sigEnd; + final int exponentOfTruncatedSignificand = virtualIndexOfPoint - index + skipCountInTruncatedDigits + + expNumber; + return tryDecFloatToDoubleTruncated(significand, exponent, isSignificandTruncated, + exponentOfTruncatedSignificand); + } + else { + return tryDecFloatToDoubleTruncated(significand, exponent, false, 0); + } + } + + private static int skipWhitespace(String str, int index, int endIndex) { + while(index < endIndex && str.charAt(index) <= ' ') { + index++; + } + return index; + } + + static double tryDecFloatToDoubleTruncated(long significand, int exponent, boolean isSignificandTruncated, + final int exponentOfTruncatedSignificand) { + + final double result; + if(isSignificandTruncated) { + // We have too many digits. We may have to round up. + // To know whether rounding up is needed, we may have to examine up to 768 digits. + + // There are cases, in which rounding has no effect. + if(DOUBLE_MIN_EXPONENT_POWER_OF_TEN <= exponentOfTruncatedSignificand && + exponentOfTruncatedSignificand <= DOUBLE_MAX_EXPONENT_POWER_OF_TEN) { + double withoutRounding = tryDecToDoubleWithFastAlgorithm(significand, exponentOfTruncatedSignificand); + double roundedUp = tryDecToDoubleWithFastAlgorithm(significand + 1, exponentOfTruncatedSignificand); + if(!Double.isNaN(withoutRounding) && roundedUp == withoutRounding) { + return withoutRounding; + } + } + + // We have to take a slow path. + result = Double.NaN; + + } + else if(DOUBLE_MIN_EXPONENT_POWER_OF_TEN <= exponent && exponent <= DOUBLE_MAX_EXPONENT_POWER_OF_TEN) { + result = tryDecToDoubleWithFastAlgorithm(significand, exponent); + } + else { + result = Double.NaN; + } + return result; + } + + static double shortcut(long significand, int power) { + // convert the integer into a double. This is lossless since + // 0 <= i <= 2^53 - 1. + double d = (double) significand; + // + // The general idea is as follows. + // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then + // 1) Both s and p can be represented exactly as 64-bit floating-point values + // 2) Because s and p can be represented exactly as floating-point values, + // then s * p and s / p will produce correctly rounded values. + // + if(power < 0) { + d = d / DOUBLE_POWERS_OF_TEN[-power]; + } + else { + d = d * DOUBLE_POWERS_OF_TEN[power]; + } + return d; + } + + static double tryDecToDoubleWithFastAlgorithm(long significand, int power) { + // we start with a fast path + // It was described in Clinger WD (1990). + if(-22 <= power && power <= 22 && Long.compareUnsigned(significand, (1L << DOUBLE_SIGNIFICAND_WIDTH) - 1) <= 0) { + return shortcut(significand, power); + } + + // The fast path has now failed, so we are falling back on the slower path. + + // We are going to need to do some 64-bit arithmetic to get a more precise product. + // We use a table lookup approach. + // It is safe because + // power >= DOUBLE_MIN_EXPONENT_POWER_OF_TEN + // and power <= DOUBLE_MAX_EXPONENT_POWER_OF_TEN + // We recover the mantissa of the power, it has a leading 1. It is always + // rounded down. + long factorMantissa = MANTISSA_64[power - DOUBLE_MIN_EXPONENT_POWER_OF_TEN]; + + // The exponent is 1023 + 64 + power + floor(log(5**power)/log(2)). + // + // 1023 is the exponent bias. + // The 64 comes from the fact that we use a 64-bit word. + // + // Computing floor(log(5**power)/log(2)) could be + // slow. Instead, we use a fast function. + // + // For power in (-400,350), we have that + // (((152170 + 65536) * power ) >> 16); + // is equal to + // floor(log(5**power)/log(2)) + power when power >= 0, + // and it is equal to + // ceil(log(5**-power)/log(2)) + power when power < 0 + // + // + // The 65536 is (1<<16) and corresponds to + // (65536 * power) >> 16 ---> power + // + // ((152170 * power ) >> 16) is equal to + // floor(log(5**power)/log(2)) + // + // Note that this is not magic: 152170/(1<<16) is + // approximately equal to log(5)/log(2). + // The 1<<16 value is a power of two; we could use a + // larger power of 2 if we wanted to. + // + long exponent = (((152170L + 65536L) * power) >> 16) + DOUBLE_EXPONENT_BIAS + 64; + // We want the most significant bit of digits to be 1. Shift if needed. + int lz = Long.numberOfLeadingZeros(significand); + long shiftedSignificand = significand << lz; + // We want the most significant 64 bits of the product. We know + // this will be non-zero because the most significant bit of digits is + // 1. + UInt128 product = fullMultiplication(shiftedSignificand, factorMantissa); + long upper = product.high; + + // The computed 'product' is always sufficient. + // Mathematical proof: + // Noble Mushtak and Daniel Lemire, Fast Number Parsing Without Fallback (to appear) + + // The final mantissa should be 53 bits with a leading 1. + // We shift it so that it occupies 54 bits with a leading 1. + long upperbit = upper >>> 63; + long mantissa = upper >>> (upperbit + 9); + lz += (int) (1 ^ upperbit); + // Here we have mantissa < (1<<54). + + // We have to round to even. The "to even" part + // is only a problem when we are right in between two floating-point values + // which we guard against. + // If we have lots of trailing zeros, we may fall right between two + // floating-point values. + if(((upper & 0x1ff) == 0x1ff) || ((upper & 0x1ff) == 0) && (mantissa & 3) == 1) { + // if mantissa & 1 == 1 we might need to round up. + // + // Scenarios: + // 1. We are not in the middle. Then we should round up. + // + // 2. We are right in the middle. Whether we round up depends + // on the last significant bit: if it is "one" then we round + // up (round to even) otherwise, we do not. + // + // So if the last significant bit is 1, we can safely round up. + // Hence, we only need to bail out if (mantissa & 3) == 1. + // Otherwise, we may need more accuracy or analysis to determine whether + // we are exactly between two floating-point numbers. + // It can be triggered with 1e23. + // Note: because the factor_mantissa and factor_mantissa_low are + // almost always rounded down (except for small positive powers), + // almost always should round up. + return Double.NaN; + } + + mantissa += 1; + mantissa >>>= 1; + + // Here we have mantissa < (1<<53), unless there was an overflow + if(mantissa >= (1L << DOUBLE_SIGNIFICAND_WIDTH)) { + // This will happen when parsing values such as 7.2057594037927933e+16 + mantissa = (1L << (DOUBLE_SIGNIFICAND_WIDTH - 1)); + lz--; // undo previous addition + } + + mantissa &= ~(1L << (DOUBLE_SIGNIFICAND_WIDTH - 1)); + + long realExponent = exponent - lz; + // we have to check that realExponent is in range, otherwise we bail out + if((realExponent < 1) || (realExponent > DOUBLE_MAX_EXPONENT_POWER_OF_TWO + DOUBLE_EXPONENT_BIAS)) { + return Double.NaN; + } + + long bits = mantissa | realExponent << (DOUBLE_SIGNIFICAND_WIDTH - 1); + return Double.longBitsToDouble(bits); + } + + private static boolean isDigit(char c) { + return (char) (c - '0') < 10; + } + + private static char charAt(String str, int i, int endIndex) { + return i < endIndex ? str.charAt(i) : 0; + } + + /** + * Computes {@code uint128 product = (uint64)x * (uint64)y}. + *

+ * References: + *

+ *
Getting the high part of 64 bit integer multiplication
+ *
+ * stackoverflow
+ *
+ * + * @param x uint64 factor x + * @param y uint64 factor y + * @return uint128 product of x and y + */ + static UInt128 fullMultiplication(long x, long y) { + return new UInt128(unsignedMultiplyHigh(x, y), x * y); + } + + public static long unsignedMultiplyHigh(long x, long y) { // if we update to jave 18 use Math internal. + // Compute via multiplyHigh() to leverage the intrinsic + long result = Math.multiplyHigh(x, y); + result += (y & (x >> 63)); // equivalent to `if (x < 0) result += y;` + result += (x & (y >> 63)); // equivalent to `if (y < 0) result += x;` + return result; + } + + static class UInt128 { + final long high, low; + + private UInt128(long high, long low) { + this.high = high; + this.low = low; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 2c2bbc15e98..2be6636ee38 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -851,12 +851,37 @@ public static void compareFrames(FrameBlock expected, FrameBlock actual, boolean final int cols = expected.getNumColumns(); if(checkMeta) checkMetadata(expected, actual); + if(expected.getNumRows() == 0){ + if (expected.getColumns() != null) + fail(); + if (actual.getColumns() != null) + fail(); + } + else{ + for(int j = 0; j < cols; j++) { + Array ec = expected.getColumn(j); + Array ac = actual.getColumn(j); + if(ec.containsNull()) { + if(!ac.containsNull()) { + fail("Expected both columns to be containing null if one null:\n\nExpected containing null:\n" + + ec.toString().substring(0, 1000) + "\n\nActual:\n" + ac.toString().substring(0, 1000)); + } + } + else if(ac.containsNull()) { + fail("Expected both columns to be containing null if one null:\n\nExpected:\n" + + ec.toString().substring(0, 1000) + "\n\nActual containing null:\n" + ac.toString().substring(0, 1000)); + } + } + } for(int i = 0; i < rows; i++) { for(int j = 0; j < cols; j++) { final Object a = expected.get(i, j); final Object b = actual.get(i, j); - if(!(a == null && b == null)) { + if(a == null){ + assertTrue(a == b); + } + else if(!(a == null && b == null)) { try{ final String as = a.toString(); final String bs = b.toString(); @@ -2405,6 +2430,7 @@ public static FrameBlock generateRandomFrameBlock(int rows, int cols, long seed) } public static FrameBlock generateRandomFrameBlockWithSchemaOfStrings(int rows, int cols, long seed){ + FrameLibApplySchema.PAR_ROW_THRESHOLD = 10; ValueType[] schema = generateRandomSchema(cols, seed); FrameBlock f = generateRandomFrameBlock(rows, schema, seed); ValueType[] schemaString = UtilFunctions.nCopies(cols, ValueType.STRING); @@ -3302,6 +3328,12 @@ public static double[][] ceil(double[][] data) { return data; } + public static double[] ceil(double[] data) { + for(int i = 0; i < data.length; i++) + data[i] = Math.ceil(data[i]); + return data; + } + public static double[][] floor(double[][] data, int col) { for(int i=0; i= mb.getNonZeros())) { // guarantee that the nnz is at least the nnz + if(!(cmb.getNonZeros() >= mb.getNonZeros() || cmb.getNonZeros() == -1)) { // guarantee that the nnz is at least the nnz fail(bufferedToString + "\nIncorrect number of non Zeros should guarantee greater than or equals but are " + cmb.getNonZeros() + " and should be: " + mb.getNonZeros()); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index 7b791180e7a..d3f191eb057 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -59,7 +59,7 @@ import org.apache.sysds.runtime.compress.estim.ComEstFactory; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; -import org.apache.sysds.runtime.compress.lib.CLALibAppend; +import org.apache.sysds.runtime.compress.lib.CLALibCBind; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; @@ -256,7 +256,7 @@ else if(ct != null) { case C_BIND_SELF: if(cmb instanceof CompressedMatrixBlock) { CompressedMatrixBlock cmbc = (CompressedMatrixBlock) cmb; - cmb = CLALibAppend.append(cmbc, cmbc, _k); + cmb = CLALibCBind.cbind(cmbc, cmbc, _k); mb = mb.append(mb, new MatrixBlock()); cols *= 2; } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index dd5a65a0f77..53058b27a61 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -37,8 +37,8 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -393,6 +393,23 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'fixColIndexes'"); } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'sparseSelection'"); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock dict) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlockTransposedSparseDictionary'"); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlockTransposedDenseDictionary'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -643,5 +660,22 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'fixColIndexes'"); } + + @Override + public void sparseSelection(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'sparseSelection'"); + } + + @Override + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock dict) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlockTransposedSparseDictionary'"); + } + + @Override + protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlockTransposedDenseDictionary'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java index 54a543ad138..bb440d712ba 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java @@ -1166,26 +1166,17 @@ public void leftMultNoPreAggDenseColRange() { @Test public void leftMultNoPreAggDenseMultiRowColRange() { - try { - leftMultNoPreAgg(3, 0, 3, 5, nRow - 4); - } - catch(Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + leftMultNoPreAgg(3, 0, 3, 5, nRow - 4); } - @Test(expected = NotImplementedException.class) - // @Test + @Test public void leftMultNoPreAggSparseColRange() { leftMultNoPreAgg(3, 0, 1, 5, nRow - 4, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } - @Test(expected = NotImplementedException.class) + @Test public void leftMultNoPreAggSparseMultiRowColRange() { leftMultNoPreAgg(3, 0, 3, 5, nRow - 4, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } @Test @@ -1198,16 +1189,14 @@ public void leftMultNoPreAggDenseMultiRowColStartRange() { leftMultNoPreAgg(3, 0, 3, 5, 9); } - @Test(expected = NotImplementedException.class) + @Test public void leftMultNoPreAggSparseColStartRange() { leftMultNoPreAgg(3, 0, 1, 5, 9, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } - @Test(expected = NotImplementedException.class) + @Test public void leftMultNoPreAggSparseMultiRowColStartRange() { leftMultNoPreAgg(3, 0, 3, 5, 9, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } @Test @@ -1220,30 +1209,24 @@ public void leftMultNoPreAggDenseMultiRowColEndRange() { leftMultNoPreAgg(3, 0, 3, nRow - 10, nRow - 3); } - @Test(expected = NotImplementedException.class) + @Test public void leftMultNoPreAggSparseColEndRange() { leftMultNoPreAgg(3, 0, 1, nRow - 10, nRow - 3, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } - @Test(expected = NotImplementedException.class) + @Test public void leftMultNoPreAggSparseMultiRowColEndRange() { leftMultNoPreAgg(3, 0, 3, nRow - 10, nRow - 3, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } - @Test(expected = NotImplementedException.class) - // @Test + @Test public void leftMultNoPreAggSparseMultiRowColToEnd() { leftMultNoPreAgg(3, 0, 3, nRow - 10, nRow, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } - @Test(expected = NotImplementedException.class) - // @Test + @Test public void leftMultNoPreAggSparseMultiRowColFromStart() { leftMultNoPreAgg(3, 0, 3, 0, 4, 0.1); - throw new NotImplementedException("Make test parse since the check actually says it is correct"); } public void leftMultNoPreAgg(int nRowLeft, int rl, int ru, int cl, int cu) { @@ -1286,7 +1269,7 @@ public void leftMultNoPreAgg(int nRowLeft, int rl, int ru, int cl, int cu, Matri compare(bt, ot); } catch(NotImplementedException e) { - throw e; + LOG.error("not implemented: " + base.getClass().getSimpleName() + " or: " + other.getClass().getSimpleName()); } catch(Exception e) { e.printStackTrace(); @@ -1520,7 +1503,7 @@ else if(base instanceof APreAgg) else if(base instanceof ColGroupConst) { double[] cb = new double[maxCol]; ((ColGroupConst) base).addToCommon(cb); - retO = mmRowSum(cb, rowSum, rl, ru, cl, cu); + retB = mmRowSum(cb, rowSum, rl, ru, cl, cu); } if(other instanceof AMorphingMMColGroup) { @@ -1547,22 +1530,36 @@ else if(other instanceof ColGroupConst) { retO.allocateDenseBlock(); other.leftMultByMatrixNoPreAgg(mb, retO, rl, ru, cl, cu); } - - compare(retB, retO); + try { + compare(retB, retO); + } + catch(Exception e) { + e.printStackTrace(); + // fail(e.getMessage()); + throw e; + } } private MatrixBlock mmPreAggDense(APreAgg g, MatrixBlock mb, double[] cv, double[] rowSum, int rl, int ru, int cl, int cu) { - final MatrixBlock retB = new MatrixBlock(ru, maxCol, false); - retB.allocateDenseBlock(); - double[] preB = new double[g.getPreAggregateSize() * (ru - rl)]; - g.preAggregateDense(mb, preB, rl, ru, cl, cu); - MatrixBlock preAggB = new MatrixBlock(ru - rl, g.getPreAggregateSize(), preB); - MatrixBlock tmpRes = new MatrixBlock(1, retB.getNumColumns(), false); - g.mmWithDictionary(preAggB, tmpRes, retB, 1, rl, ru); - mmRowSum(retB, cv, rowSum, rl, ru); - return retB; + try { + final MatrixBlock retB = new MatrixBlock(ru, maxCol, false); + retB.allocateDenseBlock(); + double[] preB = new double[g.getPreAggregateSize() * (ru - rl)]; + g.preAggregateDense(mb, preB, rl, ru, cl, cu); + MatrixBlock preAggB = new MatrixBlock(ru - rl, g.getPreAggregateSize(), preB); + MatrixBlock tmpRes = new MatrixBlock(1, retB.getNumColumns(), false); + tmpRes.allocateBlock(); + g.mmWithDictionary(preAggB, tmpRes, retB, 1, rl, ru); + mmRowSum(retB, cv, rowSum, rl, ru); + return retB; + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + return null; + } } private MatrixBlock mmRowSum(double[] cv, double[] rowSum, int rl, int ru, int cl, int cu) { @@ -2289,7 +2286,7 @@ private void appendSelfVerification(AColGroup g) { try { AColGroup g2 = g.append(g); - AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}, nRow, nRow*2); + AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}, nRow, nRow * 2); if(g2 != null && g2n != null) { double s2 = g2.getSum(nRow * 2); double s = g.getSum(nRow) * 2; @@ -2300,7 +2297,7 @@ private void appendSelfVerification(AColGroup g) { UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar+", 1), 0, nRow * 2, g2, g2n, nRow * 2); } } - catch(NotImplementedException e){ + catch(NotImplementedException e) { // okay } catch(Exception e) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java new file mode 100644 index 00000000000..84823e2776c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress.colgroup; + +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.estim.EstimationFactors; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class CombineColGroups { + protected static final Log LOG = LogFactory.getLog(CombineColGroups.class.getName()); + + /** Uncompressed ground truth */ + final MatrixBlock mb; + /** ColGroup 1 */ + final AColGroup a; + /** ColGroup 2 */ + final AColGroup b; + + @Parameters + public static Collection data() { + ArrayList tests = new ArrayList<>(); + + try { + addTwoCols(tests, 100, 3); + addTwoCols(tests, 1000, 3); + // addSingleVSMultiCol(tests, 100, 3, 1, 3); + // addSingleVSMultiCol(tests, 100, 3, 3, 4); + addSingleVSMultiCol(tests, 1000, 3, 1, 3, 1.0); + addSingleVSMultiCol(tests, 1000, 3, 3, 4, 1.0); + addSingleVSMultiCol(tests, 1000, 3, 3, 1, 1.0); + addSingleVSMultiCol(tests, 1000, 2, 1, 10, 0.05); + addSingleVSMultiCol(tests, 1000, 2, 10, 10, 0.05); + addSingleVSMultiCol(tests, 1000, 2, 10, 1, 0.05); + } + catch(Exception e) { + e.printStackTrace(); + fail("failed constructing tests"); + } + + return tests; + } + + public CombineColGroups(MatrixBlock mb, AColGroup a, AColGroup b) { + this.mb = mb; + this.a = a; + this.b = b; + + CompressedMatrixBlock.debug = true; + } + + @Test + public void combine() { + try { + AColGroup c = a.combine(b); + MatrixBlock ref = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), false); + ref.allocateDenseBlock(); + c.decompressToDenseBlock(ref.getDenseBlock(), 0, mb.getNumRows()); + ref.recomputeNonZeros(); + String errMessage = a.getClass().getSimpleName() + ": " + a.getColIndices() + " -- " + + b.getClass().getSimpleName() + ": " + b.getColIndices(); + + TestUtils.compareMatricesBitAvgDistance(mb, ref, 0, 0, errMessage); + } + catch(NotImplementedException e) { + // allowed + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private static void addTwoCols(ArrayList tests, int nRow, int distinct) { + MatrixBlock mb = TestUtils.ceil(// + TestUtils.generateTestMatrixBlock(nRow, 2, 0, distinct, 1.0, 231)); + + List c1s = getGroups(mb, ColIndexFactory.createI(0)); + List c2s = getGroups(mb, ColIndexFactory.createI(1)); + + for(int i = 0; i < c1s.size(); i++) { + for(int j = 0; j < c2s.size(); j++) { + tests.add(new Object[] {mb, c1s.get(i), c2s.get(j)}); + } + } + } + + private static void addSingleVSMultiCol(ArrayList tests, int nRow, int distinct, int nColL, int nColR, + double sparsity) { + MatrixBlock mb = TestUtils.ceil(// + TestUtils.generateTestMatrixBlock(nRow, nColL + nColR, 0, distinct, sparsity, 231)); + + List c1s = getGroups(mb, ColIndexFactory.create(nColL)); + List c2s = getGroups(mb, ColIndexFactory.create(nColL, nColR + nColL)); + + for(int i = 0; i < c1s.size(); i++) { + for(int j = 0; j < c2s.size(); j++) { + tests.add(new Object[] {mb, c1s.get(0), c2s.get(0)}); + } + } + } + + private static List getGroups(MatrixBlock mb, IColIndex cols) { + final CompressionSettings cs = new CompressionSettingsBuilder().create(); + + final int nRow = mb.getNumColumns(); + final List es = new ArrayList<>(); + final EstimationFactors f = new EstimationFactors(nRow, nRow, mb.getSparsity()); + es.add(new CompressedSizeInfoColGroup(cols, f, 312152, CompressionType.DDC)); + es.add(new CompressedSizeInfoColGroup(cols, f, 321521, CompressionType.RLE)); + es.add(new CompressedSizeInfoColGroup(cols, f, 321452, CompressionType.SDC)); + es.add(new CompressedSizeInfoColGroup(cols, f, 325151, CompressionType.UNCOMPRESSED)); + final CompressedSizeInfo csi = new CompressedSizeInfo(es); + return ColGroupFactory.compressColGroups(mb, csi, cs); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java index 572e96ac367..76be2caeee9 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java @@ -57,12 +57,9 @@ public void appendEmptyToSDCZero2() { AColGroup e = new ColGroupEmpty(i); AColGroup s = ColGroupSDCSingleZeros.create(i, 10, new PlaceHolderDict(1), OffsetFactory.createOffset(new int[] {5, 10}), null); - AColGroup r = AColGroup.appendN(new AColGroup[] {e, s, e, e, s, s, e}, 20, 7 * 20); - LOG.error(r); assertTrue(r instanceof ColGroupSDCSingleZeros); assertEquals(r.getColIndices(), i); assertEquals(((ColGroupSDCSingleZeros) r).getNumRows(), 7 * 20); - } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java index a0b45351ba7..be9e6340d42 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java @@ -59,27 +59,37 @@ public void testEncode() { TestUtils.compareMatricesBitAvgDistance(in, d, 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @Test public void testEncodeT() { - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 20, 0, distinct, 0.9, 7)); - AColGroup out = sh.encodeT(in); - MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(in, LibMatrixReorg.transpose(d), 0, 0); + try { + + MatrixBlock in = TestUtils + .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 20, 0, distinct, 0.9, 7)); + AColGroup out = sh.encodeT(in); + MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(in, LibMatrixReorg.transpose(d), 0, 0); + } + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); + } } @Test public void testEncode_sparse() { try { - MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(100, 100, 0, distinct, 0.05, 7)); AColGroup out = sh.encode(in); MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); @@ -90,8 +100,10 @@ public void testEncode_sparse() { TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @@ -109,8 +121,10 @@ public void testEncode_sparseT() { TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @@ -137,10 +151,11 @@ public void testUpdate() { d.recomputeNonZeros(); TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } - catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @@ -173,88 +188,116 @@ public void testUpdateT() { TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @Test public void testUpdateSparse() { - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(130, src.getNumColumns() + 30, 0, distinct + 1, 0.1, 7)); - if(!in.isInSparseFormat()) - throw new RuntimeException(); try { - sh.encode(in); + + MatrixBlock in = TestUtils + .round(TestUtils.generateTestMatrixBlock(130, src.getNumColumns() + 30, 0, distinct + 1, 0.1, 7)); + if(!in.isInSparseFormat()) + throw new RuntimeException(); + try { + sh.encode(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + } + ICLAScheme shc = sh.clone(); + shc = shc.update(in); + AColGroup out = shc.encode(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); } - ICLAScheme shc = sh.clone(); - shc = shc.update(in); - AColGroup out = shc.encode(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); - } @Test public void testUpdateSparseT() { - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 1000, 0, distinct + 1, 0.1, 7)); - if(!in.isInSparseFormat()) - throw new RuntimeException(); try { - sh.encodeT(in); + + MatrixBlock in = TestUtils + .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 1000, 0, distinct + 1, 0.1, 7)); + if(!in.isInSparseFormat()) + throw new RuntimeException(); + try { + sh.encodeT(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + // but we can also not have an exception thrown... + } + ICLAScheme shc = sh.clone(); + shc = shc.updateT(in); + + AColGroup out = shc.encodeT(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. - // but we can also not have an exception thrown... + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index")) + return; // all good + e.printStackTrace(); + fail(e.getMessage()); } - ICLAScheme shc = sh.clone(); - shc = shc.updateT(in); - - AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } @Test public void testUpdateSparseTEmptyColumn() { - MatrixBlock in = new MatrixBlock(src.getNumColumns(), 100, 0.0); - MatrixBlock b = new MatrixBlock(1, 100, 1.0); - in = in.append(b, false); - in.denseToSparse(true); - if(!in.isInSparseFormat()) - throw new RuntimeException(); try { - sh.encodeT(in); + + MatrixBlock in = new MatrixBlock(src.getNumColumns(), 100, 0.0); + MatrixBlock b = new MatrixBlock(1, 100, 1.0); + in = in.append(b, false); + in.denseToSparse(true); + if(!in.isInSparseFormat()) + throw new RuntimeException(); + try { + sh.encodeT(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + // but we can also not have an exception thrown... + } + ICLAScheme shc = sh.clone(); + shc = shc.updateT(in); + + AColGroup out = shc.encodeT(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. - // but we can also not have an exception thrown... + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return; // all good expected exception + e.printStackTrace(); + fail(e.getMessage()); } - ICLAScheme shc = sh.clone(); - shc = shc.updateT(in); - - AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } @Test @@ -282,83 +325,111 @@ public void testUpdateLargeBlock() { TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @Test public void testUpdateLargeBlockT() { - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 130, 0, distinct + 5, 1.0, 7)); - in = ReadersTestCompareReaders.createMock(in); try { - sh.encodeT(in); - } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. - // but we can also not have an exception thrown... - } - ICLAScheme shc = sh.clone(); - shc = shc.updateT(in); + MatrixBlock in = TestUtils + .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 130, 0, distinct + 5, 1.0, 7)); + in = ReadersTestCompareReaders.createMock(in); + try { + sh.encodeT(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + // but we can also not have an exception thrown... + } + ICLAScheme shc = sh.clone(); + + shc = shc.updateT(in); - AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + AColGroup out = shc.encodeT(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumColumns(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + } + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); + } } @Test public void testUpdateEmpty() { - MatrixBlock in = new MatrixBlock(5, src.getNumColumns(), 0.0); - try { - sh.encode(in); + + MatrixBlock in = new MatrixBlock(5, src.getNumColumns(), 0.0); + + try { + sh.encode(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + } + ICLAScheme shc = sh.clone(); + shc = shc.update(in); + AColGroup out = shc.encode(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); } - ICLAScheme shc = sh.clone(); - shc = shc.update(in); - AColGroup out = shc.encode(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); - } @Test public void testUpdateEmptyT() { - MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); try { - sh.encodeT(in); - } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. - // but we can also not have an exception thrown... - } - ICLAScheme shc = sh.clone(); - shc = shc.updateT(in); + MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); + try { + sh.encodeT(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + // but we can also not have an exception thrown... + } + ICLAScheme shc = sh.clone(); - AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + shc = shc.updateT(in); + + AColGroup out = shc.encodeT(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + } + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); + } } @Test @@ -386,50 +457,58 @@ public void testUpdateEmptyMyCols() { TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); - fail(e.getMessage()); + fail(e.getMessage() + " " + sh); } } @Test public void testUpdateEmptyMyColsT() { - MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); - in = in.append(new MatrixBlock(1, 5, 1.0), false); try { - sh.encodeT(in); - } - catch(NullPointerException e) { - // all good expected - // we want to have an exception thrown if we try to encode something that is not possible to encode. - // but we can also not have an exception thrown... - } - ICLAScheme shc = sh.clone(); + MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); + in = in.append(new MatrixBlock(1, 5, 1.0), false); + try { + sh.encodeT(in); + } + catch(NullPointerException e) { + // all good expected + // we want to have an exception thrown if we try to encode something that is not possible to encode. + // but we can also not have an exception thrown... + } + ICLAScheme shc = sh.clone(); - shc = shc.updateT(in); + shc = shc.updateT(in); - AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + AColGroup out = shc.encodeT(in); // should be possible now. + MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + } + catch(Exception e) { + if(e.getMessage() != null && e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); + } } @Test public void testUpdateAndEncode() { double newVal = distinct + 4; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); testUpdateAndEncode(in); } @Test public void testUpdateAndEncodeT() { double newVal = distinct + 4; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); testUpdateAndEncodeT(in); } @@ -444,8 +523,7 @@ public void testUpdateAndEncodeSparse() { @Test public void testUpdateAndEncodeSparseT() { double newVal = distinct + 4; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 0.1, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 0.1, 7)); testUpdateAndEncodeT(in); } @@ -461,8 +539,7 @@ public void testUpdateAndEncodeSparseTEmptyColumn() { @Test public void testUpdateAndEncodeLarge() { double newVal = distinct + 4; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); in = ReadersTestCompareReaders.createMock(in); testUpdateAndEncode(in); @@ -471,8 +548,7 @@ public void testUpdateAndEncodeLarge() { @Test public void testUpdateAndEncodeLargeT() { double newVal = distinct + 4; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); in = ReadersTestCompareReaders.createMock(in); testUpdateAndEncodeT(in); } @@ -480,16 +556,14 @@ public void testUpdateAndEncodeLargeT() { @Test public void testUpdateAndEncodeManyNew() { double newVal = distinct + 300; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); testUpdateAndEncode(in); } @Test public void testUpdateAndEncodeTManyNew() { double newVal = distinct + 300; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); testUpdateAndEncodeT(in); } @@ -504,16 +578,14 @@ public void testUpdateAndEncodeSparseManyNew() { @Test public void testUpdateAndEncodeSparseTManyNew() { double newVal = distinct + 300; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 0.1, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 0.1, 7)); testUpdateAndEncodeT(in); } @Test public void testUpdateAndEncodeLargeManyNew() { double newVal = distinct + 300; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0, newVal, 1.0, 7)); in = ReadersTestCompareReaders.createMock(in); testUpdateAndEncode(in); @@ -522,8 +594,7 @@ public void testUpdateAndEncodeLargeManyNew() { @Test public void testUpdateAndEncodeLargeTManyNew() { double newVal = distinct + 300; - MatrixBlock in = TestUtils - .round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); + MatrixBlock in = TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal, 1.0, 7)); in = ReadersTestCompareReaders.createMock(in); testUpdateAndEncodeT(in); } @@ -555,14 +626,23 @@ public void testUpdateAndEncodeEmptyInColsT() { } public void testUpdateAndEncode(MatrixBlock in) { - Pair r = sh.clone().updateAndEncode(in); - AColGroup out = r.getValue(); - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); - d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); - d.recomputeNonZeros(); - TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); + try { + + Pair r = sh.clone().updateAndEncode(in); + AColGroup out = r.getValue(); + MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + d.allocateBlock(); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, src.getNumColumns() - 1); + d.recomputeNonZeros(); + TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0); + } + catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good + e.printStackTrace(); + fail(e.getMessage() + " " + sh); + } } public void testUpdateAndEncodeT(MatrixBlock in) { @@ -577,6 +657,8 @@ public void testUpdateAndEncodeT(MatrixBlock in) { TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } catch(Exception e) { + if(e.getMessage().contains("Invalid SDC group that contains index with size == numRows")) + return;// all good e.printStackTrace(); fail(e.getMessage() + " " + sh); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java index 1f7c872b0ff..064f10e9f34 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java @@ -85,7 +85,6 @@ public SchemeTestSDC(MatrixBlock src, int distinct) { catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); - throw new RuntimeException(); } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java index 5a3f66ca374..17502298fe2 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java +++ b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java @@ -76,8 +76,8 @@ public void combineCustom3() { @Test public void combineCustom4() { - IEncode ae = new DenseEncoding(MapToFactory.create(10, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 10)); - IEncode be = new DenseEncoding(MapToFactory.create(10, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 10)); + IEncode ae = new DenseEncoding(MapToFactory.create(10, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 8)); + IEncode be = new DenseEncoding(MapToFactory.create(10, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 8)); Pair> cec = ae.combineWithMap(be); IEncode ce = cec.getLeft(); Map cem = cec.getRight(); diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java index f9e0ac1ee42..2632a19929d 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java @@ -40,10 +40,11 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -122,7 +123,7 @@ public void sparseSparse() { double[] ad = new double[] {0}; double[] bd = new double[] {0}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(4, 2, new double[] {0, 0, 3, 0, 0, 4, 3, 4}); @@ -142,7 +143,7 @@ public void sparseSparse2() { double[] ad = new double[] {0}; double[] bd = new double[] {0, 0}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(4, 3, new double[] {0, 0, 0, 3, 0, 0, 0, 4, 4, 3, 4, 4}); @@ -162,7 +163,7 @@ public void sparseSparse3() { double[] ad = new double[] {1}; double[] bd = new double[] {2}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(4, 2, new double[] {// @@ -186,7 +187,7 @@ public void sparseSparse4() { double[] ad = new double[] {0, 1}; double[] bd = new double[] {0, 2}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(4, 4, new double[] {// @@ -210,7 +211,7 @@ public void sparseSparse5() { double[] ad = new double[] {0, 1}; double[] bd = new double[] {0, 2}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(6, 4, new double[] {// @@ -236,7 +237,7 @@ public void sparseSparse6() { double[] ad = new double[] {0, 1}; double[] bd = new double[] {0, 2}; - IDictionary c = DictionaryFactory.combineSDC(a, ad, b, bd); + IDictionary c = DictionaryFactory.combineSDCNoFilter(a, ad, b, bd); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(9, 4, new double[] {// @@ -398,12 +399,15 @@ public void combineNotImplementedSparse6() { @Test public void sparseSparseConst1() { try { - IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); + IDictionary ad = Dictionary.create(new double[] {3, 2, 7, 8}); // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); double[] bd = new double[] {0, 2}; - IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 2, bd); + ColGroupDDC a = mockDDC(ad, ColIndexFactory.createI(0,1)); + AColGroupCompressed b = (AColGroupCompressed)ColGroupConst.create(ColIndexFactory.createI(2,3), bd); + + IDictionary c = DictionaryFactory.combineDictionaries(a,b); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(2, 4, new double[] {// @@ -420,12 +424,14 @@ public void sparseSparseConst1() { @Test public void sparseSparseConst2() { try { - IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); - // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); + IDictionary ad = Dictionary.create(new double[] {3, 2, 7, 8}); double[] bd = new double[] {0, 2}; - IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 1, bd); + ColGroupDDC a = mockDDC(ad, ColIndexFactory.createI(0)); + AColGroupCompressed b = (AColGroupCompressed)ColGroupConst.create(ColIndexFactory.createI(2,3), bd); + + IDictionary c = DictionaryFactory.combineDictionaries(a,b); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(2, 3, new double[] {// @@ -445,8 +451,8 @@ public void sparseSparseConst2() { public void testEmpty() { try { IDictionary d = Dictionary.create(new double[] {3, 2, 7, 8}); - AColGroup a = ColGroupDDC.create(ColIndexFactory.create(2), d, MapToFactory.create(10, 2), null); - ColGroupEmpty b = new ColGroupEmpty(ColIndexFactory.create(4)); + AColGroup a = ColGroupDDC.create(ColIndexFactory.createI(1, 2), d, MapToFactory.create(10, 2), null); + ColGroupEmpty b = new ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6)); IDictionary c = DictionaryFactory.combineDictionaries((AColGroupCompressed) a, (AColGroupCompressed) b); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); @@ -466,9 +472,9 @@ public void testEmpty() { public void combineDictionariesSparse1() { try { IDictionary d = Dictionary.create(new double[] {3, 2, 7, 8}); - AColGroup a = ColGroupSDC.create(ColIndexFactory.create(2), 500, d, new double[] {1, 2}, + AColGroup a = ColGroupSDC.create(ColIndexFactory.createI(1, 2), 500, d, new double[] {1, 2}, OffsetFactory.createOffset(new int[] {3, 4}), MapToFactory.create(10, 2), null); - ColGroupEmpty b = new ColGroupEmpty(ColIndexFactory.create(4)); + ColGroupEmpty b = new ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6)); IDictionary c = DictionaryFactory.combineDictionariesSparse((AColGroupCompressed) a, (AColGroupCompressed) b); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); @@ -487,17 +493,19 @@ public void combineDictionariesSparse1() { @Test public void combineDictionariesSparse2() { try { - IDictionary d = Dictionary.create(new double[] {3, 2, 7, 8}); - AColGroup b = ColGroupSDC.create(ColIndexFactory.create(2), 500, d, new double[] {1, 2}, + IDictionary d = Dictionary.create(new double[] {// + 3, 2, // + 7, 8}); + AColGroup a = ColGroupSDC.create(ColIndexFactory.createI(1, 2), 500, d, new double[] {1, 2}, OffsetFactory.createOffset(new int[] {3, 4}), MapToFactory.create(10, 2), null); - ColGroupEmpty a = new ColGroupEmpty(ColIndexFactory.create(4)); + ColGroupEmpty b = new ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6)); IDictionary c = DictionaryFactory.combineDictionariesSparse((AColGroupCompressed) a, (AColGroupCompressed) b); MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); MatrixBlock exp = new MatrixBlock(2, 6, new double[] {// - 0, 0, 0, 0, 3, 2, // - 0, 0, 0, 0, 7, 8,}); + 3, 2, 0, 0, 0, 0, // + 7, 8, 0, 0, 0, 0,}); TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0); } catch(Exception e) { @@ -510,8 +518,8 @@ public void combineDictionariesSparse2() { public void combineMockingEmpty() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); @@ -521,40 +529,45 @@ public void combineMockingEmpty() { @Test public void combineMockingDefault() { - IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); - double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); - - Map m = new HashMap<>(); - m.put(0, 0); - IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); - - assertEquals(red.getNumberOfValues(2), 1); - assertEquals(red, Dictionary.createNoCheck(new double[] {0, 0})); + try { + IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); + double[] ade = new double[] {0}; + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); + Map m = new HashMap<>(); + m.put(0, 0); + IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); + assertEquals(red.getNumberOfValues(2), 1); + assertEquals(Dictionary.createNoCheck(new double[] {0, 0}), red); + assertEquals(red, Dictionary.createNoCheck(new double[] {0, 0})); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test public void combineMockingFirstValue() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(1, 0); IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(red.getNumberOfValues(2), 1); - assertEquals(red, Dictionary.create(new double[] {1, 0})); + assertEquals(red, Dictionary.create(new double[] {0, 1})); } @Test public void combineMockingFirstAndDefault() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(1, 0); @@ -562,15 +575,15 @@ public void combineMockingFirstAndDefault() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(red.getNumberOfValues(2), 2); - assertEquals(red, Dictionary.create(new double[] {1, 0, 0, 0})); + assertEquals(red, Dictionary.create(new double[] {0, 1, 0, 0})); } @Test public void combineMockingMixed() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(1, 0); @@ -579,15 +592,15 @@ public void combineMockingMixed() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(red.getNumberOfValues(2), 3); - assertEquals(Dictionary.create(new double[] {1, 0, 0, 0, 0, 1}), red); + assertEquals(Dictionary.create(new double[] {0, 1, 0, 0, 1, 0}), red); } @Test public void combineMockingMixed2() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockSDC(ad, ade); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(1, 0); @@ -596,7 +609,7 @@ public void combineMockingMixed2() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(red.getNumberOfValues(2), 3); - assertEquals(Dictionary.create(new double[] {1, 0, 0, 0, 0, 2}), red); + assertEquals(Dictionary.create(new double[] {0, 1, 0, 0, 2, 0}), red); } @Test @@ -605,8 +618,8 @@ public void combineMockingSparseDenseEmpty() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockDDC(ad, 1); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); @@ -626,14 +639,14 @@ public void combineMockingSparseDenseOne() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockDDC(ad, 1); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(0, 0); IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(1, red.getNumberOfValues(2)); - assertEquals(Dictionary.createNoCheck(new double[] {1, 0}), red); + assertEquals(Dictionary.createNoCheck(new double[] {0, 1}), red); } catch(Exception e) { e.printStackTrace(); @@ -647,8 +660,8 @@ public void combineMockingSparseDenseMixed1() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockDDC(ad, 1); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(0, 1); @@ -656,7 +669,7 @@ public void combineMockingSparseDenseMixed1() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(2, red.getNumberOfValues(2)); - assertEquals(Dictionary.createNoCheck(new double[] {2, 0, 1, 0}), red); + assertEquals(Dictionary.createNoCheck(new double[] {0, 2, 0, 1}), red); } catch(Exception e) { e.printStackTrace(); @@ -670,8 +683,8 @@ public void combineMockingSparseDenseMixed2() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockDDC(ad, 1); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(0, 1); @@ -680,7 +693,7 @@ public void combineMockingSparseDenseMixed2() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(3, red.getNumberOfValues(2)); - assertEquals(Dictionary.createNoCheck(new double[] {2, 0, 1, 0, 1, 1}), red); + assertEquals(Dictionary.createNoCheck(new double[] {0, 2, 0, 1, 1, 1}), red); } catch(Exception e) { e.printStackTrace(); @@ -694,8 +707,8 @@ public void combineMockingSparseDenseMixed3() { IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4}); double[] ade = new double[] {0}; - AColGroupCompressed a = mockDDC(ad, 1); - AColGroupCompressed b = mockSDC(ad, ade); + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ad, ade, ColIndexFactory.create(2)); Map m = new HashMap<>(); m.put(0, 1); @@ -705,7 +718,87 @@ public void combineMockingSparseDenseMixed3() { IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); assertEquals(4, red.getNumberOfValues(2)); - assertEquals(Dictionary.createNoCheck(new double[] {2, 0, 1, 0, 2, 1, 1, 1}), red); + assertEquals(Dictionary.createNoCheck(new double[] {0, 2, 0, 1, 1, 2, 1, 1}), red); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void combineFailCase1() { + try { + + IDictionary ad = Dictionary.create(new double[] {3, 1, 2}); + IDictionary ab = Dictionary.create(new double[] {2, 3}); + double[] ade = new double[] {1}; + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.create(1)); + AColGroupCompressed b = mockSDC(ab, ade, ColIndexFactory.create(2)); + + Map m = new HashMap<>(); + // 0=8, 1=7, 2=5, 3=0, 4=6, 5=2, 6=4, 7=1, 8=3 + m.put(0, 8); + m.put(1, 7); + m.put(2, 5); + m.put(3, 0); + m.put(4, 6); + m.put(5, 2); + m.put(6, 4); + m.put(7, 1); + m.put(8, 3); + IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); + + assertEquals(9, red.getNumberOfValues(2)); + assertEquals(Dictionary.createNoCheck(// + new double[] {// + 2, 3, // + 3, 1, // + 2, 2, // + 3, 2, // + 3, 3, // + 1, 2, // + 2, 1, // + 1, 1, // + 1, 3,// + }), red); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void combineFailCase2() { + try { + + IDictionary ad = Dictionary.create(new double[] {3, 1, 2}); + IDictionary ab = Dictionary.create(new double[] {2, 3}); + double[] ade = new double[] {1}; + AColGroupCompressed a = mockDDC(ad, ColIndexFactory.createI(1)); + AColGroupCompressed b = mockSDC(ab, ade, ColIndexFactory.createI(2)); + + Map m = new HashMap<>(); + for(int i = 0; i < 9; i++) { + m.put(i, i); + } + + IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); + + assertEquals(9, red.getNumberOfValues(2)); + assertEquals(Dictionary.createNoCheck(// + new double[] {// + 3, 1, // + 1, 1, // + 2, 1, // + 3, 2, // + 1, 2, // + 2, 2, // + 3, 3, // + 1, 3, // + 2, 3,// + }), red); } catch(Exception e) { e.printStackTrace(); @@ -713,20 +806,106 @@ public void combineMockingSparseDenseMixed3() { } } - private ASDC mockSDC(IDictionary ad, double[] def) { + @Test + public void testCombineSDC() { + IDictionary ad = Dictionary.create(new double[] {2, 3}); + IDictionary ab = Dictionary.create(new double[] {1, 2}); + double[] ade = new double[] {1.0}; + double[] abe = new double[] {3.0}; + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.createI(1)); + AColGroupCompressed b = mockSDC(ab, abe, ColIndexFactory.createI(2)); + Map m = new HashMap<>(); + m.put(0, 8); + m.put(1, 0); + m.put(2, 4); + m.put(3, 7); + m.put(4, 6); + m.put(5, 1); + m.put(6, 5); + m.put(7, 2); + m.put(8, 3); + + IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); + + assertEquals(9, red.getNumberOfValues(2)); + assertEquals(Dictionary.createNoCheck(// + new double[] {// + 2, 3, // + 3, 1, // + 2, 2, // + 3, 2, // + 3, 3, // + 1, 2, // + 2, 1, // + 1, 1, // + 1, 3,// + }), red); + } + + @Test + public void testCombineSDCRange() { + IDictionary ad = Dictionary.create(new double[] {2, 3}); + IDictionary ab = Dictionary.create(new double[] {1, 2}); + double[] ade = new double[] {1.0}; + double[] abe = new double[] {3.0}; + AColGroupCompressed a = mockSDC(ad, ade, ColIndexFactory.createI(1)); + AColGroupCompressed b = mockSDC(ab, abe, ColIndexFactory.createI(2)); + Map m = new HashMap<>(); + for(int i = 0; i < 9; i++) { + m.put(i, i); + } + IDictionary red = DictionaryFactory.combineDictionaries(a, b, m); + + assertEquals(9, red.getNumberOfValues(2)); + assertEquals(Dictionary.createNoCheck(// + new double[] {// + 1, 3, // + 2, 3, // + 3, 3, // + 1, 1, // + 2, 1, // + 3, 1, // + 1, 2, // + 2, 2, // + 3, 2,// + }), red); + } + + // private ASDC mockSDC(IDictionary ad, double[] def) { + // ASDC a = mock(ASDC.class); + // when(a.getCompType()).thenReturn(CompressionType.SDC); + // when(a.getDictionary()).thenReturn(ad); + // when(a.getDefaultTuple()).thenReturn(def); + // when(a.getNumCols()).thenReturn(def.length); + // when(a.getColIndices()).thenReturn(ColIndexFactory.create(def.length)); + // return a; + // } + + private ASDC mockSDC(IDictionary ad, double[] def, IColIndex c) { ASDC a = mock(ASDC.class); when(a.getCompType()).thenReturn(CompressionType.SDC); when(a.getDictionary()).thenReturn(ad); when(a.getDefaultTuple()).thenReturn(def); when(a.getNumCols()).thenReturn(def.length); + when(a.getColIndices()).thenReturn(c); return a; } - private ColGroupDDC mockDDC(IDictionary ad, int nCol) { + // private ColGroupDDC mockDDC(IDictionary ad, int nCol) { + // ColGroupDDC a = mock(ColGroupDDC.class); + // when(a.getCompType()).thenReturn(CompressionType.DDC); + // when(a.getDictionary()).thenReturn(ad); + // when(a.getNumCols()).thenReturn(nCol); + // when(a.getColIndices()).thenReturn(ColIndexFactory.create(nCol)); + // return a; + // } + + private ColGroupDDC mockDDC(IDictionary ad, IColIndex c) { ColGroupDDC a = mock(ColGroupDDC.class); when(a.getCompType()).thenReturn(CompressionType.DDC); when(a.getDictionary()).thenReturn(ad); - when(a.getNumCols()).thenReturn(nCol); + when(a.getNumCols()).thenReturn(c.size()); + when(a.getColIndices()).thenReturn(c); return a; } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java index 8dd8a1165bf..c051dfc5368 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java @@ -28,6 +28,8 @@ import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.junit.Test; @@ -153,4 +155,46 @@ public void createZeroRowMatrixBlock() { MatrixBlockDictionary.create(new MatrixBlock(0, 10, 10.0)); } + @Test + public void IdentityDictionaryEquals() { + IDictionary a = new IdentityDictionary(10); + IDictionary b = new IdentityDictionary(10); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals() { + IDictionary a = new IdentityDictionary(10); + IDictionary b = new IdentityDictionary(11); + assertFalse(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals2() { + IDictionary a = new IdentityDictionary(10); + IDictionary b = new IdentityDictionary(11, false); + assertFalse(a.equals(b)); + } + + @Test + public void IdentityDictionaryEquals2() { + IDictionary a = new IdentityDictionary(11, false); + IDictionary b = new IdentityDictionary(11, false); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryEquals2v() { + IDictionary a = new IdentityDictionary(11); + IDictionary b = new IdentityDictionary(11, false); + assertTrue(a.equals(b)); + } + + @Test + public void IdentityDictionaryNotEquals3() { + IDictionary a = new IdentityDictionary(11, true); + IDictionary b = new IdentityDictionary(11, false); + assertFalse(a.equals(b)); + } + } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java index 9307930f1d2..c2d1517a0f5 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java @@ -29,6 +29,7 @@ import java.util.Collection; import java.util.List; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -36,8 +37,12 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; @@ -104,6 +109,73 @@ public static Collection data() { 0, 0, 0, 0}), 5, 4}); + tests.add(new Object[] {new IdentityDictionary(4, true), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1, // + 0, 0, 0, 0}).getMBDict(4), + 5, 4}); + + tests.add(new Object[] {new IdentityDictionary(20, false), // + MatrixBlockDictionary.create(// + new MatrixBlock(20, 20, 20L, // + SparseBlockFactory.createIdentityMatrix(20)), + false), + 20, 20}); + + tests.add(new Object[] {new IdentityDictionary(20, false), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // + }), // + 20, 20}); + + tests.add(new Object[] {new IdentityDictionary(20, true), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + }), // + 21, 20}); + create(tests, 30, 300, 0.2); } catch(Exception e) { @@ -149,6 +221,22 @@ public void sum() { assertEquals(as, bs, 0.0000001); } + @Test + public void sum2() { + int[] counts = getCounts(nRow, 124); + double as = a.sum(counts, nCol); + double bs = b.sum(counts, nCol); + assertEquals(as, bs, 0.0000001); + } + + @Test + public void sum3() { + int[] counts = getCounts(nRow, 124444); + double as = a.sum(counts, nCol); + double bs = b.sum(counts, nCol); + assertEquals(as, bs, 0.0000001); + } + @Test public void getValues() { try { @@ -442,6 +530,11 @@ public void equalsEl() { assertEquals(a, b); } + @Test + public void equalsElOp() { + assertEquals(b, a); + } + @Test public void opRightMinus() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); @@ -479,9 +572,16 @@ public void opRightDiv() { } private void opRight(BinaryOperator op, double[] vals, IColIndex cols) { - IDictionary aa = a.binOpRight(op, vals, cols); - IDictionary bb = b.binOpRight(op, vals, cols); - compare(aa, bb, nRow, nCol); + try { + + IDictionary aa = a.binOpRight(op, vals, cols); + IDictionary bb = b.binOpRight(op, vals, cols); + compare(aa, bb, nRow, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } private void opRight(BinaryOperator op, double[] vals) { @@ -529,6 +629,44 @@ public void testAddToEntry4() { } } + @Test + public void testAddToEntryRep1() { + double[] ret1 = new double[nCol]; + a.addToEntry(ret1, 0, 0, nCol, 16); + double[] ret2 = new double[nCol]; + b.addToEntry(ret2, 0, 0, nCol, 16); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep2() { + double[] ret1 = new double[nCol * 2]; + a.addToEntry(ret1, 0, 1, nCol, 3214); + double[] ret2 = new double[nCol * 2]; + b.addToEntry(ret2, 0, 1, nCol, 3214); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep3() { + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 0, 2, nCol, 222); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 0, 2, nCol, 222); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntryRep4() { + if(a.getNumberOfValues(nCol) > 2) { + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 2, 2, nCol, 321); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 2, 2, nCol, 321); + assertTrue(Arrays.equals(ret1, ret2)); + } + } + @Test public void testAddToEntryVectorized1() { try { @@ -544,6 +682,37 @@ public void testAddToEntryVectorized1() { } } + @Test + public void max() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.MAX)); + } + + @Test + public void min() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.MIN)); + } + + @Test(expected = NotImplementedException.class) + public void cMax() { + aggregate(Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)); + throw new NotImplementedException(); + } + + private void aggregate(Builtin fn) { + try { + double aa = a.aggregate(0, fn); + double bb = b.aggregate(0, fn); + assertEquals(aa, bb, 0.0); + } + catch(NotImplementedException ee) { + throw ee; + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void testAddToEntryVectorized2() { try { @@ -607,13 +776,24 @@ public void containsValueWithReference(double value, double[] reference) { b.containsValueWithReference(value, reference)); } + private static void compare(IDictionary a, IDictionary b, int nCol) { + assertEquals(a.getNumberOfValues(nCol), b.getNumberOfValues(nCol)); + compare(a, b, a.getNumberOfValues(nCol), nCol); + } + private static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { try { - String errorM = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); for(int i = 0; i < nRow; i++) - for(int j = 0; j < nCol; j++) - assertEquals(errorM, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + for(int j = 0; j < nCol; j++) { + double aa = a.getValue(i, j, nCol); + double bb = b.getValue(i, j, nCol); + boolean eq = Math.abs(aa - bb) < 0.0001; + if(!eq) { + assertEquals(errorM + " cell:<" + i + "," + j + ">", a.getValue(i, j, nCol), b.getValue(i, j, nCol), + 0.0001); + } + } } catch(Exception e) { e.printStackTrace(); @@ -641,6 +821,305 @@ public void preaggValuesFromDense() { } } + @Test + public void rightMMPreAggSparse() { + final int nColsOut = 30; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, nColsOut, -10, 10, 0.1, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new RangeIndex(nColsOut); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + + @Test + public void rightMMPreAggSparse2() { + final int nColsOut = 1000; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, nColsOut, -10, 10, 0.01, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new RangeIndex(nColsOut); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, nColsOut); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + @Test + public void rightMMPreAggSparseDifferentColumns() { + final int nColsOut = 3; + MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000, 50, -10, 10, 0.1, 100); + sparse = TestUtils.ceil(sparse); + sparse.denseToSparse(true); + SparseBlock sb = sparse.getSparseBlock(); + if(sb == null) + throw new NotImplementedException(); + + IColIndex agCols = new ArrayIndex(new int[] {4, 10, 38}); + IColIndex thisCols = new RangeIndex(0, nCol); + + int nVals = a.getNumberOfValues(nCol); + try { + + IDictionary aa = a.rightMMPreAggSparse(nVals, sb, thisCols, agCols, 50); + IDictionary bb = b.rightMMPreAggSparse(nVals, sb, thisCols, agCols, 50); + compare(aa, bb, nColsOut); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + + @Test + public void MMDictScalingDense() { + double[] left = TestUtils.ceil(TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214)); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(0, nCol); + int[] scaling = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < a.getNumberOfValues(nCol); i++) + scaling[i] = i + 1; + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol, 0); + retA.allocateDenseBlock(); + a.MMDictScalingDense(left, rowsLeft, colsRight, retA, scaling); + + MatrixBlock retB = new MatrixBlock(5, nCol, 0); + retB.allocateDenseBlock(); + b.MMDictScalingDense(left, rowsLeft, colsRight, retB, scaling); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictScalingDenseOffset() { + double[] left = TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(3, nCol + 3); + int[] scaling = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < a.getNumberOfValues(nCol); i++) + scaling[i] = i; + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol + 3, 0); + retA.allocateDenseBlock(); + a.MMDictScalingDense(left, rowsLeft, colsRight, retA, scaling); + + MatrixBlock retB = new MatrixBlock(5, nCol + 3, 0); + retB.allocateDenseBlock(); + b.MMDictScalingDense(left, rowsLeft, colsRight, retB, scaling); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictDense() { + double[] left = TestUtils.ceil(TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214)); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(0, nCol); + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol, 0); + retA.allocateDenseBlock(); + a.MMDictDense(left, rowsLeft, colsRight, retA); + + MatrixBlock retB = new MatrixBlock(5, nCol, 0); + retB.allocateDenseBlock(); + b.MMDictDense(left, rowsLeft, colsRight, retB); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void MMDictDenseOffset() { + double[] left = TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10, 10, 1.0, 3214); + IColIndex rowsLeft = ColIndexFactory.createI(1, 2, 3); + IColIndex colsRight = ColIndexFactory.create(3, nCol + 3); + + try { + + MatrixBlock retA = new MatrixBlock(5, nCol + 3, 0); + retA.allocateDenseBlock(); + a.MMDictDense(left, rowsLeft, colsRight, retA); + + MatrixBlock retB = new MatrixBlock(5, nCol + 3, 0); + retB.allocateDenseBlock(); + b.MMDictDense(left, rowsLeft, colsRight, retB); + + TestUtils.compareMatricesBitAvgDistance(retA, retB, 10, 10); + } + catch(Exception e) { + + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void sumAllRowsToDouble() { + double[] aa = a.sumAllRowsToDouble(nCol); + double[] bb = b.sumAllRowsToDouble(nCol); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleWithDefault() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleWithDefault(def); + double[] bb = b.sumAllRowsToDoubleWithDefault(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleWithReference(def); + double[] bb = b.sumAllRowsToDoubleWithReference(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleSq() { + double[] aa = a.sumAllRowsToDoubleSq(nCol); + double[] bb = b.sumAllRowsToDoubleSq(nCol); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleSqWithDefault() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleSqWithDefault(def); + double[] bb = b.sumAllRowsToDoubleSqWithDefault(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void sumAllRowsToDoubleSqWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + double[] aa = a.sumAllRowsToDoubleSqWithReference(def); + double[] bb = b.sumAllRowsToDoubleSqWithReference(def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void aggColsMin() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + Builtin m = Builtin.getBuiltinFnObject(BuiltinCode.MIN); + + double[] aa = new double[nCol + 3]; + a.aggregateCols(aa, m, cols); + double[] bb = new double[nCol + 3]; + b.aggregateCols(bb, m, cols); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void aggColsMax() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + Builtin m = Builtin.getBuiltinFnObject(BuiltinCode.MAX); + + double[] aa = new double[nCol + 3]; + a.aggregateCols(aa, m, cols); + double[] bb = new double[nCol + 3]; + b.aggregateCols(bb, m, cols); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void getValue() { + int nCell = nCol * a.getNumberOfValues(nCol); + for(int i = 0; i < nCell; i++) + assertEquals(a.getValue(i), b.getValue(i), 0.0000); + } + + @Test + public void colSum() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + int[] counts = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < counts.length; i++) { + counts[i] = i + 1; + } + + double[] aa = new double[nCol + 3]; + a.colSum(aa, counts, cols); + double[] bb = new double[nCol + 3]; + b.colSum(bb, counts, cols); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + + @Test + public void colProduct() { + IColIndex cols = ColIndexFactory.create(2, nCol + 2); + int[] counts = new int[a.getNumberOfValues(nCol)]; + for(int i = 0; i < counts.length; i++) { + counts[i] = i + 1; + } + + double[] aa = new double[nCol + 3]; + a.colProduct(aa, counts, cols); + double[] bb = new double[nCol + 3]; + b.colProduct(bb, counts, cols); + + TestUtils.compareMatrices(aa, bb, 0.001); + } + public void productWithDefault(double retV, double[] def) { // Shared final int[] counts = getCounts(nRow, 1324); diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java index 3286a3eed61..9fa404ca77f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java @@ -19,6 +19,7 @@ package org.apache.sysds.test.component.compress.indexes; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; @@ -41,6 +42,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoRangesIndex; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.matrix.data.Pair; import org.junit.Test; import org.mockito.Mockito; @@ -1027,4 +1029,64 @@ public void containsAnyArray2() { IColIndex b = new RangeIndex(3, 11); assertTrue(a.containsAny(b)); } + + @Test + public void reordering1(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(2); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{0,2,3}, ra); + assertArrayEquals(new int[]{1}, rb); + } + + @Test + public void reordering2(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(2,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{0,2,4}, ra); + assertArrayEquals(new int[]{1,3}, rb); + } + + @Test + public void reordering3(){ + IColIndex a = ColIndexFactory.createI(1,3,5); + IColIndex b = ColIndexFactory.createI(0, 2,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{1,3,5}, ra); + assertArrayEquals(new int[]{0,2,4}, rb); + } + + @Test + public void reordering4(){ + IColIndex a = ColIndexFactory.createI(1,5); + IColIndex b = ColIndexFactory.createI(0,2,3,4); + + assertFalse(IColIndex.inOrder(a, b)); + Pair r = IColIndex.reorderingIndexes(a, b); + + int[] ra = r.getKey(); + int[] rb = r.getValue(); + + assertArrayEquals(new int[]{1,5}, ra); + assertArrayEquals(new int[]{0,2,3,4}, rb); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java index 871636ed477..1f5deccf779 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.CombinedIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; @@ -145,6 +146,7 @@ public static Collection data() { tests.add(createTwoRange(1, 10, 22, 30)); tests.add(createTwoRange(9, 11, 22, 30)); tests.add(createTwoRange(9, 11, 22, 60)); + tests.add(createCombined(9, 11, 22)); } catch(Exception e) { e.printStackTrace(); @@ -349,6 +351,19 @@ public void equalsSizeDiff_twoRanges2() { assertNotEquals(actual, c); } + @Test + public void equalsCombine(){ + RangeIndex a = new RangeIndex(9, 11); + SingleIndex b = new SingleIndex(22); + IColIndex c = a.combine(b); + if(eq(expected, c)){ + LOG.error(c.size()); + compare(expected, c); + compare(c, actual); + } + + } + @Test public void equalsItself() { assertEquals(actual, actual); @@ -395,10 +410,16 @@ public void combineTwoAbove() { @Test public void combineTwoAround() { - IColIndex b = new TwoIndex(expected[0] - 1, expected[expected.length - 1] + 1); - IColIndex c = actual.combine(b); - assertTrue(c.containsStrict(actual, b)); - assertTrue(c.containsStrict(b, actual)); + try { + IColIndex b = new TwoIndex(expected[0] - 1, expected[expected.length - 1] + 1); + IColIndex c = actual.combine(b); + assertTrue(c.containsStrict(actual, b)); + assertTrue(c.containsStrict(b, actual)); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -417,7 +438,10 @@ public void hashCodeEquals() { @Test public void estimateInMemorySizeIsNotToBig() { - assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 16); + if(actual instanceof CombinedIndex) + assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 64); + else + assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 16); } @Test @@ -594,6 +618,17 @@ private void shift(int i) { compare(expected, actual.shift(i), i); } + private static boolean eq(int[] expected, IColIndex actual) { + if(expected.length == actual.size()) { + for(int i = 0; i < expected.length; i++) + if(expected[i] != actual.get(i)) + return false; + return true; + } + else + return false; + } + public static void compare(int[] expected, IColIndex actual) { assertEquals(expected.length, actual.size()); for(int i = 0; i < expected.length; i++) @@ -673,4 +708,19 @@ private static Object[] createTwoRange(int l1, int u1, int l2, int u2) { exp[j] = i; return new Object[] {exp, c}; } + + private static Object[] createCombined(int l1, int u1, int o) { + RangeIndex a = new RangeIndex(l1, u1); + SingleIndex b = new SingleIndex(o); + IColIndex c = a.combine(b); + int[] exp = new int[u1 - l1 + 1]; + + for(int i = l1, j = 0; i < u1; i++, j++) + exp[j] = i; + + exp[exp.length - 1] = o; + + return new Object[] {exp, c}; + + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java index 3c18cf049bc..daa988b396f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java @@ -157,10 +157,11 @@ protected static void writeR(MatrixBlock src, String path, int rep) throws Excep } protected static void writeAndReadR(MatrixBlock mb, int blen, int rep) throws Exception { + String filename = getName(); try { - String filename = getName(); File f = new File(filename); - f.delete(); + if(f.isFile() || f.isDirectory()) + f.delete(); WriterCompressed.writeCompressedMatrixToHDFS(mb, filename, blen); File f2 = new File(filename); assertTrue(f2.isFile() || f2.isDirectory()); @@ -168,15 +169,21 @@ protected static void writeAndReadR(MatrixBlock mb, int blen, int rep) throws Ex IOCompressionTestUtils.verifyEquivalence(mb, mbr); } catch(Exception e) { - + File f = new File(filename); + if(f.isFile() || f.isDirectory()) + f.delete(); if(rep < 3) { Thread.sleep(1000); writeAndReadR(mb, blen, rep + 1); return; } - e.printStackTrace(); throw e; } + finally{ + File f = new File(filename); + if(f.isFile() || f.isDirectory()) + f.delete(); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java index 73c26e1c26c..8b8cd49e35d 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java @@ -568,11 +568,28 @@ public void slice100to10000() { slice(100, 10000); } + @Test + public void verify(){ + o.verify(o.getSize()); + } + @Test public void slice1to4() { slice(1, 4); } + @Test + public void slice() { + if(data.length > 1) { + int n = data[data.length - 1]; + for(int i = 0; i < n && i < 100; i++) { + for(int j = i; j < n + 1 && j < 100; j++) { + slice(i, j, false); + } + } + } + } + @Test public void sliceAllSpecific() { if(data.length > 1) @@ -580,13 +597,19 @@ public void sliceAllSpecific() { } private void slice(int l, int u) { + slice(l, u, false); + } + + private void slice(int l, int u, boolean str){ try { OffsetSliceInfo a = o.slice(l, u); - a.offsetSlice.toString(); + if(str) + a.offsetSlice.toString(); if(data.length > 0 && data[data.length - 1] > u) { AIterator it = a.offsetSlice.getIterator(); + a.offsetSlice.verify(a.uIndex - a.lIndex); int i = 0; while(i < data.length && data[i] < l) i++; @@ -601,7 +624,7 @@ private void slice(int l, int u) { } catch(Exception e) { e.printStackTrace(); - fail("Failed to slice first 100"); + fail("Failed to slice range: " + l + " -> " + u + " in:\n" + o); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java index cf084e9ccc3..51845568172 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java @@ -86,7 +86,7 @@ public static Collection data() { // Simple tests no loops verifying basic behavior tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false, "sum.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false, "mean.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 1, false, false, "plus.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 1, false, false, "plus.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false, "sliceCols.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false, "sliceIndex.dml", args}); // tests.add(new Object[] {0, 0, 0, 1, 0, 0, 0, 0, false, false, "leftMult.dml", args}); @@ -105,9 +105,9 @@ public static Collection data() { // Builtins: // nr 11: tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true, "functions/scale.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 4, 0, false, true, "functions/scale.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale_continued.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, true, "functions/scale_continued.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true, "functions/scale_continued.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, true, "functions/scale_onlySide.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale_onlySide.dml", args}); diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java index 388c41a820f..f805d27b089 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java @@ -25,6 +25,8 @@ import java.util.Random; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -32,7 +34,11 @@ import org.junit.Test; public class FrameApplySchema { - + protected static final Log LOG = LogFactory.getLog(FrameApplySchema.class.getName()); + + static{ + FrameLibApplySchema.PAR_ROW_THRESHOLD = 10; + } @Test public void testApplySchemaStringToBoolean() { try { @@ -141,6 +147,20 @@ public void testUnkownColumnDefaultToString() { assertEquals(ValueType.UNKNOWN, fb.getSchema()[0]); } + @Test + public void testUnkownColumnDefaultToStringPar() { + try{ + FrameBlock fb = genStringContainingInteger(100, 3); + ValueType[] schema = new ValueType[] {ValueType.UNKNOWN, ValueType.INT32, ValueType.INT32}; + fb = FrameLibApplySchema.applySchema(fb, schema, 3); + assertEquals(ValueType.UNKNOWN, fb.getSchema()[0]); + } + catch(Exception e){ + e.printStackTrace(); + fail(e.getMessage()); + } + } + private FrameBlock genStringContainingInteger(int row, int col) { FrameBlock ret = new FrameBlock(); Random r = new Random(31); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index dc0f03c58ec..92a10ba101d 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -284,7 +284,7 @@ public void changeType(ValueType t) { Array r = a.changeType(t); assertTrue(r.getValueType() == t); } - catch(DMLRuntimeException e) { + catch(DMLRuntimeException | NumberFormatException e) { LOG.debug(e.getMessage()); // okay since we want exceptions // in cases where the the change fail. @@ -423,7 +423,7 @@ public void getFrameArrayType() { if(a.getFrameArrayType() == FrameArrayType.OPTIONAL) return; - assertEquals(a.toString(),t, a.getFrameArrayType()); + assertEquals(a.toString(), t, a.getFrameArrayType()); } @Test @@ -609,12 +609,9 @@ public void testSetRange(int start, int end, int otherSize, int seed) { compareSetSubRange(aa, other, start, end, 0, aa.getValueType()); } - catch(DMLCompressionException e) { + catch(DMLCompressionException | NumberFormatException | NotImplementedException e) { return;// valid } - catch(NumberFormatException e){ - return; // valid - } catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); @@ -1146,11 +1143,11 @@ public void testSetNzString() { @SuppressWarnings("unchecked") public void testSetNzStringWithNull() { Array aa = a.clone(); - Array af = (Array) aa.changeTypeWithNulls(ValueType.STRING); try { + Array af = (Array) aa.changeTypeWithNulls(ValueType.STRING); aa.setFromOtherTypeNz(af); } - catch(DMLCompressionException e) { + catch(DMLCompressionException | NotImplementedException e) { return;// valid } catch(Exception e) { @@ -1184,19 +1181,18 @@ public void testSetFromString() { public void testSetFromStringWithNull() { Array aa = a.clone(); Array af; - if(aa.getFrameArrayType() == FrameArrayType.OPTIONAL // - && aa.getValueType() != ValueType.STRING // - && aa.getValueType() != ValueType.HASH64) { - af = aa.changeTypeWithNulls(ValueType.FP64); - } - else - af = aa.changeTypeWithNulls(ValueType.STRING); - try { + if(aa.getFrameArrayType() == FrameArrayType.OPTIONAL // + && aa.getValueType() != ValueType.STRING // + && aa.getValueType() != ValueType.HASH64) { + af = aa.changeTypeWithNulls(ValueType.FP64); + } + else + af = aa.changeTypeWithNulls(ValueType.STRING); aa.setFromOtherType(0, af.size() - 1, af); } - catch(DMLCompressionException e) { + catch(DMLCompressionException | NotImplementedException e) { return;// valid } catch(Exception e) { @@ -1246,7 +1242,7 @@ public void testSerializationSize() { a.write(fos); long s = fos.size(); long e = a.getExactSerializedSize(); - assertEquals(a.toString(),s, e); + assertEquals(a.toString(), s, e); } catch(IOException e) { throw new RuntimeException("Error in io", e); @@ -1923,7 +1919,7 @@ protected static void compare(Array a, Array b) { final Object av = a.get(i); final Object bv = b.get(i); if((av == null && bv != null) || (bv == null && av != null)) - fail("not both null"); + fail("not both null: " + err); else if(av != null && bv != null) assertTrue(err, av.toString().equals(bv.toString())); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java index 105785ebc99..beb8f677a89 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java @@ -26,6 +26,8 @@ import java.io.IOException; import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -46,6 +48,7 @@ import org.mockito.Mockito; public class NegativeArrayTests { + private static final Log LOG = LogFactory.getLog(NegativeArrayTests.class.getName()); @Test @SuppressWarnings("unchecked") @@ -72,24 +75,6 @@ public void testChangeTypeToInvalid() { s.toString(); } - @Test(expected = NotImplementedException.class) - public void testChangeTypeToUInt8() { - Array a = ArrayFactory.create(new int[] {1, 2, 3}); - a.changeType(ValueType.UINT8); - } - - @Test(expected = NotImplementedException.class) - public void testChangeTypeToUInt8WithNull_noNull() { - Array a = ArrayFactory.create(new int[] {1, 2, 3}); - a.changeTypeWithNulls(ValueType.UINT8); - } - - @Test(expected = NotImplementedException.class) - public void testChangeTypeToUInt8WithNull() { - Array a = ArrayFactory.create(new String[] {"1", "2", null}); - a.changeTypeWithNulls(ValueType.UINT8); - } - @Test(expected = DMLRuntimeException.class) public void getMinMax() { ArrayFactory.create(new int[] {1, 2, 3, 4}).getMinMaxLength(); @@ -220,7 +205,7 @@ public void readFieldsOpt() { @Test(expected = DMLRuntimeException.class) public void readFieldsRagged() { try { - new RaggedArray<>(ArrayFactory.create(new Integer[]{1,2,3}),10).readFields(null); + new RaggedArray<>(ArrayFactory.create(new Integer[] {1, 2, 3}), 10).readFields(null); } catch(IOException e) { fail("not correct exception"); @@ -262,11 +247,6 @@ public void parseInt() { IntegerArray.parseInt("notANumber"); } - @Test(expected = NotImplementedException.class) - public void optionalChangeToUInt8() { - new OptionalArray<>(new Double[3]).changeTypeWithNulls(ValueType.UINT8); - } - @Test(expected = NotImplementedException.class) public void byteArrayString() { new StringArray(new String[10]).getAsByteArray(); @@ -314,14 +294,14 @@ public void testInvalidBLength() { @Test(expected = DMLRuntimeException.class) public void testInvalidALength() { - Array a = ArrayFactory.allocate( ValueType.INT32, 10); + Array a = ArrayFactory.allocate(ValueType.INT32, 10); Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, 10, 14, 20);// one to short } @Test(expected = DMLRuntimeException.class) public void testInvalidRL() { - Array a = ArrayFactory.allocate( ValueType.INT32, 10); + Array a = ArrayFactory.allocate(ValueType.INT32, 10); Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, -1, 15, 20);// one to short } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java index 45dc59b0fb7..df38e1e59d5 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java @@ -161,6 +161,7 @@ public void testInplace() { assertEquals(lcb, left.getNumColumns()); assertEquals(rrb, right.getNumRows()); assertEquals(rcb, right.getNumColumns()); + TestUtils.compareMatricesBitAvgDistance(ret1, left, 0, 0, "Result is incorrect for inplace \n" + op + " " + lspb + " " + rspb + " (" + lrb + "," + lcb + ")" + " (" + rrb + "," + rcb + ")"); } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java index f79b44f5a1d..07ff5937a0a 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.matrix.data.LibCommonsMath; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; @@ -120,6 +121,8 @@ private void testEigen(MatrixBlock in, double tol, int threads, type t) { case QR: m = LibCommonsMath.multiReturnOperations(in, "eigen_qr", threads, 1); break; + default: + throw new DMLRuntimeException("Fail"); } isValidDecomposition(in, m[1], m[0], tol, t.toString()); diff --git a/src/test/java/org/apache/sysds/test/component/matrix/EqualsTest.java b/src/test/java/org/apache/sysds/test/component/matrix/EqualsTest.java index f64035997d6..7ebb1e2adf0 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/EqualsTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/EqualsTest.java @@ -199,12 +199,9 @@ public void unknownNNZEmptyBoth() { @Test public void unknownNNZEmptyOne() { - MatrixBlock m1 = new MatrixBlock(10, 10, 0.0); MatrixBlock m2 = new MatrixBlock(10, 10, 0.0); - m1.setNonZeros(-1); - assertTrue(m1.equals(m2)); assertTrue(m2.equals(m1)); } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java b/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java new file mode 100644 index 00000000000..40e7143fbb3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/MatrixBlockSerializationTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.matrix; + +import static org.junit.Assert.fail; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class MatrixBlockSerializationTest { + + private MatrixBlock mb; + + @Parameters + public static Collection data() { + List tests = new ArrayList<>(); + + try { + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 1.0, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 1.0, 3)}); + + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 0.1, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 0.1, 3)}); + + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1, 100, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 10, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(100, 1000, 0, 10, 0.001, 3)}); + tests.add(new Object[] {TestUtils.generateTestMatrixBlock(1000, 1000, 0, 10, 0.001, 3)}); + tests.add(new Object[] {new MatrixBlock()}); + + } + catch(Exception e) { + e.printStackTrace(); + fail("failed constructing tests"); + } + + return tests; + } + + public MatrixBlockSerializationTest(MatrixBlock mb) { + this.mb = mb; + } + + @Test + public void testSerialization() { + try { + // serialize compressed matrix block + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream fos = new DataOutputStream(bos); + mb.write(fos); + + // deserialize compressed matrix block + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + DataInputStream fis = new DataInputStream(bis); + MatrixBlock in = new MatrixBlock(); + in.readFields(fis); + TestUtils.compareMatrices(mb, in, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinDifferenceStatistics.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinDifferenceStatistics.java index e488296dd12..705001079df 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinDifferenceStatistics.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinDifferenceStatistics.java @@ -87,7 +87,7 @@ private void run(ExecType instType, double error) { writeInputMatrixWithMTD("A", A, false); MatrixBlock C = TestUtils.generateTestMatrixBlock(1, 5, 1 - error, 1 + error, 1.0, 1342); MatrixBlock B = new MatrixBlock(100, 5, false); - LibMatrixBincell.bincellOp(A, C, B, new BinaryOperator(Multiply.getMultiplyFnObject())); + LibMatrixBincell.bincellOp(A, C, B, new BinaryOperator(Multiply.getMultiplyFnObject()), 1); writeInputMatrixWithMTD("B", B, true); String log = runTest(null).toString(); // LOG.error(log); diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java index 2ae285cbf22..5c6bd15036f 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java @@ -74,7 +74,12 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType, DMLCompressionStatistics.reset(); Assert.assertEquals(out + "\ncompression count wrong : ", compressionCount, compressionCountsExpected); - Assert.assertEquals(out + "\nDecompression count wrong : ", decompressionCountExpected, decompressCount); + if(decompressionCountExpected < 0){ + assertTrue(out + "\nDecompression count wrong : " , decompressCount > 1); + } + else{ + Assert.assertEquals(out + "\nDecompression count wrong : ", decompressionCountExpected, decompressCount); + } } catch(Exception e) { diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java index 24290246995..51e481769ae 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java @@ -49,7 +49,7 @@ protected String getTestDir() { @Test public void testTranspose_CP() { - runTest(1500, 20, 1, 1, ExecType.CP, "transpose"); + runTest(1500, 20, 2, 1, ExecType.CP, "transpose"); } @Test @@ -79,12 +79,12 @@ public void testRowAggregate_SP() { @Test public void testSequence_CP() { - runTest(1500, 1, 0, 1, ExecType.CP, "plus_mm_ewbm_sum"); + runTest(1500, 1, -1, 1, ExecType.CP, "plus_mm_ewbm_sum"); } @Test public void testSequence_SP() { - runTest(1500, 1, 0, 1, ExecType.SPARK, "plus_mm_ewbm_sum"); + runTest(1500, 1, 2, 1, ExecType.SPARK, "plus_mm_ewbm_sum"); } @Test @@ -99,17 +99,17 @@ public void testPlus_MM_SP() { @Test public void test_ElementWiseBinaryMultiplyOp_right_CP() { - runTest(1500, 1, 0, 1, ExecType.CP, "ewbm_right"); + runTest(1500, 1, -1, 1, ExecType.CP, "ewbm_right"); } @Test public void test_ElementWiseBinaryMultiplyOp_right_SP() { - runTest(1500, 1, 0, 1, ExecType.SPARK, "ewbm_right"); + runTest(1500, 1, 2, 1, ExecType.SPARK, "ewbm_right"); } @Test public void test_ElementWiseBinaryMultiplyOp_left_CP() { - runTest(1500, 1, 0, 1, ExecType.CP, "ewbm_left"); + runTest(1500, 1, -1, 1, ExecType.CP, "ewbm_left"); } @Test @@ -119,7 +119,7 @@ public void test_ElementWiseBinaryMultiplyOp_left_SP() { @Test public void test_ElementWiseBinaryMultiplyOp_left_SP_larger() { - runTest(1500, 15, 0, 1, ExecType.SPARK, "ewbm_left"); + runTest(1500, 15, -1, 1, ExecType.SPARK, "ewbm_left"); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java index 1fe40002c29..14b6b5f787e 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java @@ -23,6 +23,8 @@ import java.util.Arrays; import java.util.Random; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -38,9 +40,9 @@ import org.junit.Assert; import org.junit.Test; - public class CompressByBinTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(CompressByBinTest.class.getName()); private final static String TEST_NAME = "compressByBins"; private final static String TEST_DIR = "functions/compress/matrixByBin/"; @@ -52,41 +54,48 @@ public class CompressByBinTest extends AutomatedTestBase { private final static int nbins = 10; - //private final static int[] dVector = new int[cols]; + // private final static int[] dVector = new int[cols]; @Override public void setUp() { - addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"X"})); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"X"})); } @Test - public void testCompressBinsMatrixWidthCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); } + public void testCompressBinsMatrixWidthCP() { + runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); + } @Test - public void testCompressBinsMatrixHeightCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); } + public void testCompressBinsMatrixHeightCP() { + runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); + } @Test - public void testCompressBinsFrameWidthCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); } + public void testCompressBinsFrameWidthCP() { + runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); + } @Test - public void testCompressBinsFrameHeightCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); } + public void testCompressBinsFrameHeightCP() { + runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); + } - private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) - { + private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) { Types.ExecMode platformOld = setExecMode(instType); - try - { + try { loadTestConfiguration(getTestConfiguration(TEST_NAME)); String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH),output("meta"), output("res")}; + programArgs = new String[] {"-stats","-args", input("X"), + Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH), output("meta"), output("res")}; double[][] X = generateMatrixData(binMethod); writeInputMatrixWithMTD("X", X, true); - runTest(true, false, null, -1); + runTest(null); checkMetaFile(DataConverter.convertToMatrixBlock(X), binMethod); @@ -99,24 +108,23 @@ private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod bin } } - private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) - { + private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod) { Types.ExecMode platformOld = setExecMode(instType); - try - { + try { loadTestConfiguration(getTestConfiguration(TEST_NAME)); String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-explain", "-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) , output("meta"), output("res")}; + programArgs = new String[] {"-explain", "-args", input("X"), + Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH), output("meta"), output("res")}; Types.ValueType[] schema = new Types.ValueType[cols]; Arrays.fill(schema, Types.ValueType.FP32); FrameBlock Xf = generateFrameData(binMethod, schema); writeInputFrameWithMTD("X", Xf, false, schema, Types.FileFormat.CSV); - runTest(true, false, null, -1); + runTest(null); checkMetaFile(Xf, binMethod); @@ -132,14 +140,15 @@ private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMetho private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { double[][] X; if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) { - //generate actual dataset + // generate actual dataset X = getRandomMatrix(rows, cols, -100, 100, 1, 7); // make sure that bins in [-100, 100] for(int i = 0; i < cols; i++) { X[0][i] = -100; X[1][i] = 100; } - } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { + } + else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { X = new double[rows][cols]; for(int c = 0; c < cols; c++) { double[] vals = new Random().doubles(nbins).toArray(); @@ -150,7 +159,8 @@ private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { j++; } } - } else + } + else throw new RuntimeException("Invalid binning method."); return X; @@ -164,9 +174,10 @@ private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types for(int i = 0; i < cols; i++) { Xf.set(0, i, -100); - Xf.set(rows-1, i, 100); + Xf.set(rows - 1, i, 100); } - } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { + } + else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) { Xf = new FrameBlock(); for(int c = 0; c < schema.length; c++) { double[] vals = new Random().doubles(nbins).toArray(); @@ -180,14 +191,16 @@ private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types Xf.appendColumn(f); } - } else + } + else throw new RuntimeException("Invalid binning method."); return Xf; } - private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningType) throws IOException{ + private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningType) throws IOException { FrameBlock outputMeta = readDMLFrameFromHDFS("meta", Types.FileFormat.CSV); + Assert.assertEquals(nbins, outputMeta.getNumRows()); double[] binStarts = new double[nbins]; @@ -201,9 +214,10 @@ private void checkMetaFile(CacheBlock X, ColumnEncoderBin.BinMethod binningTy Assert.assertEquals(i, binStart, 0.0); j++; } - } else { + } + else { binStarts[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(0)).split("·")[0]); - binEnds[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(nbins-1)).split("·")[1]); + binEnds[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(nbins - 1)).split("·")[1]); } } diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameReadWriteTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameReadWriteTest.java index bec24bae724..f2eeb6fa41d 100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/FrameReadWriteTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameReadWriteTest.java @@ -43,6 +43,8 @@ import org.apache.sysds.test.TestUtils; import org.junit.Test; +import com.google.crypto.tink.subtle.Random; + public class FrameReadWriteTest extends AutomatedTestBase { protected static final Log LOG = LogFactory.getLog(FrameReadWriteTest.class.getName()); @@ -50,14 +52,18 @@ public class FrameReadWriteTest extends AutomatedTestBase { private final static String TEST_NAME = "FrameReadWrite"; private final static String TEST_CLASS_DIR = TEST_DIR + FrameReadWriteTest.class.getSimpleName() + "/"; - private static final AtomicInteger id = new AtomicInteger(0); + private static AtomicInteger id = new AtomicInteger(0); - private final static int rows = 1593; + private final static int rows = 1020; private final static ValueType[] schemaStrings = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.STRING}; private final static ValueType[] schemaMixed = new ValueType[]{ValueType.STRING, ValueType.FP64, ValueType.INT64, ValueType.BOOLEAN}; private final static String DELIMITER = "::"; private final static boolean HEADER = true; + + static{ + id = new AtomicInteger(Random.randInt()); + } @Override public void setUp() { @@ -211,16 +217,14 @@ void initFrameData(FrameBlock frame, double[][] data, ValueType[] lschema) private void writeAndVerifyData(FileFormat fmt, FrameBlock frame1, FrameBlock frame2, FileFormatPropertiesCSV fprop) throws IOException { - writeAndVerifyData(fmt, frame1, fprop); writeAndVerifyData(fmt, frame2, fprop); } private void writeAndVerifyData(FileFormat fmt, FrameBlock fb, FileFormatPropertiesCSV fprop) throws IOException { + final String fname1 = SCRIPT_DIR + TEST_DIR + "/frameData" + id.incrementAndGet(); try{ - - final String fname1 = SCRIPT_DIR + TEST_DIR + "/frameData" + id.incrementAndGet(); final ValueType[] schema = fb.getSchema(); final int nCol = fb.getNumColumns(); @@ -235,14 +239,14 @@ private void writeAndVerifyData(FileFormat fmt, FrameBlock fb, FileFormatPropert //Read frame data from disk FrameBlock frame1Read = reader.readFrameFromHDFS(fname1, schema, nRow, nCol); - TestUtils.compareFrames(fb, frame1Read, true); - - HDFSTool.deleteFileIfExistOnHDFS(fname1); } catch(Exception e){ e.printStackTrace(); fail(e.getMessage()); + }finally{ + + HDFSTool.deleteFileIfExistOnHDFS(fname1); } } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java index f66fc1db3c2..783936df09f 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java @@ -21,6 +21,8 @@ import static org.junit.Assert.fail; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -34,9 +36,9 @@ import org.apache.sysds.test.TestUtils; import org.junit.Test; +public class TransformCSVFrameEncodeReadTest extends AutomatedTestBase { + public static final Log LOG = LogFactory.getLog(TransformCSVFrameEncodeReadTest.class.getName()); -public class TransformCSVFrameEncodeReadTest extends AutomatedTestBase -{ private final static String TEST_NAME1 = "TransformCSVFrameEncodeRead"; private final static String TEST_DIR = "functions/transform/"; private final static String TEST_CLASS_DIR = TEST_DIR + TransformCSVFrameEncodeReadTest.class.getSimpleName() + "/"; @@ -134,9 +136,7 @@ private void runTransformTest( ExecMode rt, String ofmt, boolean subset, boolean fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; programArgs = new String[]{"-args", DATASET_DIR + DATASET, String.valueOf(nrows), output("R") }; - String stdOut = runTest(null).toString(); - //read input/output and compare FrameReader reader2 = parRead ? new FrameReaderTextCSVParallel( new FileFormatPropertiesCSV() ) : @@ -144,6 +144,7 @@ private void runTransformTest( ExecMode rt, String ofmt, boolean subset, boolean FrameBlock fb2 = reader2.readFrameFromHDFS(output("R"), -1L, -1L); String[] fromDisk = DataConverter.toString(fb2).split("\n"); String[] printed = stdOut.split("\n"); + boolean equal = true; String err = ""; for(int i = 0; i < fromDisk.length; i++){ @@ -155,7 +156,6 @@ private void runTransformTest( ExecMode rt, String ofmt, boolean subset, boolean } if(!equal) fail(err); - } catch(Exception ex) { throw new RuntimeException(ex);