Skip to content

Commit

Permalink
Merge branch 'apache:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
min-guk authored Jan 20, 2025
2 parents a6c5875 + 9484f11 commit 1128972
Show file tree
Hide file tree
Showing 161 changed files with 5,227 additions and 1,679 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ jobs:
run: mvn jacoco:report

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5.0.2
uses: codecov/codecov-action@v5.1.2
if: github.repository_owner == 'apache'
with:
fail_ci_if_error: false
Expand Down
114 changes: 114 additions & 0 deletions scripts/builtin/sqrtMatrix.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Computes the matrix square root B of a matrix A, such that
# A = B %*% B.
#
# INPUT:
# ------------------------------------------------------------------------------
# A Input Matrix A
# S Strategy: "COMMON" (Java-based, using commons-math) or "DML" (script-based implementation)
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# B Output Matrix B
# ------------------------------------------------------------------------------


m_sqrtMatrix = function(Matrix[Double] A, String S)
  return(Matrix[Double] B)
{
  if (S == "COMMON") {
    # delegate to the Java-based builtin
    B = sqrtMatrixJava(A)
  } else if (S == "DML") {
    N = nrow(A);
    D = ncol(A);

    # check that matrix is square
    if (D != N) {
      stop("matrixSqrt Input Error: matrix not square!")
    }

    isDiag = isDiagonal(A)
    if (isDiag) {
      # The eigenvalues of a diagonal matrix are its diagonal entries, so apply
      # the same real-root check as the general path below instead of handing
      # negative entries to sqrt().
      if (min(diag(A)) < 0) {
        stop("matrixSqrt exec Error: matrix has imaginary square root");
      }
      B = sqrtDiagMatrix(A);
    } else {
      [eValues, eVectors] = eigen(A);

      # a real square root is only computed when no eigenvalue is negative
      hasNonNegativeEigenValues = (sum(eValues >= 0) == length(eValues));

      if (!hasNonNegativeEigenValues) {
        stop("matrixSqrt exec Error: matrix has imaginary square root");
      }

      isSymmetric = sum(A == t(A)) == length(A);
      allEigenValuesUnique = length(eValues) == length(unique(eValues));

      if (allEigenValuesUnique | isSymmetric) {
        # eigendecomposition route:
        # A = V D V^(-1)  ->  S = sqrt(D)  ->  B = V S V^(-1)
        sqrtD = sqrtDiagMatrix(diag(eValues));
        V_Inv = inv(eVectors);
        B = eVectors %*% sqrtD %*% V_Inv;
      } else {
        # formula: Denman-Beavers iteration
        #   Y_{k+1} = (Y_k + Z_k^(-1)) / 2,   Z_{k+1} = (Z_k + Y_k^(-1)) / 2
        # starting from Y_0 = A and Z_0 = I; Y is returned as the root.
        Y = A
        # identity matrix
        Z = diag(matrix(1.0, rows=N, cols=1))

        # fixed number of 100 iterations, no convergence test
        for (x in 1:100) {
          # both updates use the values from the previous iteration
          Y_new = (1 / 2) * (Y + inv(Z))
          Z_new = (1 / 2) * (Z + inv(Y))
          Y = Y_new
          Z = Z_new
        }
        B = Y
      }
    }
  } else {
    stop("Error: Unknown strategy for matrix square root.")
  }
}

# Square root of a diagonal matrix: takes the element-wise square root of the
# diagonal of X and returns it as a diagonal matrix.
# Assumes X is square and diagonal; only the diagonal of X is read.
sqrtDiagMatrix = function(Matrix[Double] X)
  return(Matrix[Double] sqrt_x)
{
  # No special case for the identity matrix is needed: sqrt(1) == 1, so an
  # identity input is reproduced by the general formula.
  sqrt_x = diag(sqrt(diag(X)));
}

isDiagonal = function (Matrix[Double] X)
  return(boolean diagonal)
{
  # copy of X that keeps only its diagonal entries (all other cells are zero)
  diagOnly = diag(diag(X));
  # X is diagonal exactly when no cell differs from that diagonal-only copy
  diagonal = (sum(diagOnly != X) == 0);
}

2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ public enum Builtins {
STEPLM("steplm",true, ReturnType.MULTI_RETURN),
STFT("stft", false, ReturnType.MULTI_RETURN),
SQRT("sqrt", false),
SQRT_MATRIX("sqrtMatrix", true),
SQRT_MATRIX_JAVA("sqrtMatrixJava", false, ReturnType.SINGLE_RETURN),
SUM("sum", false),
SVD("svd", false, ReturnType.MULTI_RETURN),
TABLE("table", "ctable", false),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ public enum OpOp1 {
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/conf/DMLConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public class DMLConfig
_defaultVals.put(FLOATING_POINT_PRECISION, "double" );
_defaultVals.put(USE_SSL_FEDERATED_COMMUNICATION, "false");
_defaultVals.put(DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, "10");
_defaultVals.put(FEDERATED_TIMEOUT, "-1");
_defaultVals.put(FEDERATED_TIMEOUT, "86400"); // default 1 day compute timeout.
_defaultVals.put(FEDERATED_PLANNER, FederatedPlanner.RUNTIME.name());
_defaultVals.put(FEDERATED_PAR_CONN, "-1"); // vcores
_defaultVals.put(FEDERATED_PAR_INST, "-1"); // vcores
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/hops/UnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent

//ensure cp exec type for single-node operations
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
|| getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() )
{
_etype = ExecType.CP;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if( LineageCacheConfig.getCompAssRW() )
_sbRuleSet.add( new MarkForLineageReuse() );
_sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() );
}
_dagRuleSet.add( new RewriteNonScalarPrint() );
}

// DYNAMIC REWRITES (which do require size information)
if( dynamicRewrites )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)

return hi;
}

private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
if( hi instanceof ReorgOp )
{

private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) {
if( hi instanceof ReorgOp ) {
ReorgOp rop = (ReorgOp) hi;
Hop input = hi.getInput(0);
Hop input = hi.getInput(0);
boolean apply = false;
//equal dims of reshape input and output -> no need for reshape because

//equal dims of reshape input and output -> no need for reshape because
//byrow always refers to both input/output and hence gives the same result
apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
//1x1 dimensions of transpose/reshape -> no need for reorg
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE)
&& rop.getDim1()==1 && rop.getDim2()==1);

//1x1 dimensions of transpose/reshape/roll -> no need for reorg
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
|| rop.getOp()==ReOrgOp.ROLL) && rop.getDim1()==1 && rop.getDim2()==1);

if( apply ) {
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryReorg.");
}
}

return hi;
}

Expand Down Expand Up @@ -1356,44 +1354,78 @@ else if ( applyRight ) {
* @param pos position
* @return high-level operator
*/
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
{
//all patterns headed by full sum over binary operation
if( hi instanceof AggUnaryOp //full sum root over binaryop
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
&& hi.getInput(0) instanceof BinaryOp
&& hi.getInput(0).getParent().size()==1 ) //single parent
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
&& hi.getInput(0) instanceof BinaryOp
&& hi.getInput(0).getParent().size()==1 ) //single parent
{
BinaryOp bop = (BinaryOp) hi.getInput(0);
Hop left = bop.getInput(0);
Hop right = bop.getInput(1);

if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
&& left.getDataType() == DataType.MATRIX
&& right.getDataType() == DataType.MATRIX )

if( left.getDataType() == DataType.MATRIX
&& right.getDataType() == DataType.MATRIX )
{
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
|| bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
? bop.getOp() : null;

if( applyOp != null ) {
//create new subdag sum(A) bop sum(B)
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);

//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
if (HopRewriteUtils.isEqualSize(left, right)) {
//create new subdag sum(A) bop sum(B) for equal-sized matrices
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
}
// Check if right operand is a vector (has dimension of 1 in either rows or columns)
else if (right.getDim1() == 1 || right.getDim2() == 1) {
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);

// Row vector case (1 x n)
if (right.getDim1() == 1) {
// Create nrow(A) operation using dimensions
UnaryOp nRows = HopRewriteUtils.createUnary(left, OpOp1.NROW);
BinaryOp scaledSum = HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
}
// Column vector case (n x 1)
else if (right.getDim2() == 1) {
// Create ncol(A) operation using dimensions
UnaryOp nCols = HopRewriteUtils.createUnary(left, OpOp1.NCOL);
BinaryOp scaledSum = HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
}
}
}
}
}

return hi;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))

hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m))

//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
if( !descendFirst )
Expand Down Expand Up @@ -2131,20 +2130,6 @@ else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.EQUAL)) {
return hi;
}

private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) {
if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) && !hi.getDataType().isScalar()) {
LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
args.put("target", hi);
Hop newHop = HopRewriteUtils.createParameterizedBuiltinOp(
hi, args, ParamBuiltinOp.TOSTRING);
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
hi = newHop;
LOG.debug("Applied fixNonScalarPrint (line " + hi.getBeginLine() + ")");
}

return hi;
}

/**
* NOTE: currently disabled since this rewrite is INVALID in the
* presence of NaNs (because (NaN!=NaN) is true).
Expand Down
Loading

0 comments on commit 1128972

Please sign in to comment.