Skip to content

Commit

Permalink
[SYSTEMDS-3823] Compression test case for bultin kmeans
Browse files Browse the repository at this point in the history
Closes #2194
  • Loading branch information
e-strauss committed Jan 29, 2025
1 parent 41c21bf commit 615cd9a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,9 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
* (the invoker is responsible to recompute nnz after all copies are done)
*/
public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareDestNZ ) {
if (src instanceof CompressedMatrixBlock){
src = ((CompressedMatrixBlock) src).getUncompressed("In-place matrix copy into indexed matrix");
}
if(sparse && src.sparse)
copySparseToSparse(rl, ru, cl, cu, src, awareDestNZ);
else if(sparse && !src.sparse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import static org.junit.Assert.fail;

import java.io.ByteArrayOutputStream;
import java.io.File;

import org.apache.commons.logging.Log;
Expand All @@ -46,6 +47,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
private final static String TEST_NAME5 = "WorkloadAnalysisSliceFinder";
private final static String TEST_NAME6 = "WorkloadAnalysisLmCG";
private final static String TEST_NAME7 = "WorkloadAnalysisL2SVM";
private final static String TEST_NAME8 = "WorkloadAnalysisKmeans";
private final static String TEST_DIR = "functions/compress/workload/";
private final static String TEST_CLASS_DIR = TEST_DIR + WorkloadAnalysisTest.class.getSimpleName() + "/";

Expand Down Expand Up @@ -73,6 +75,7 @@ public void setUp() {
addTestConfiguration(TEST_NAME5, new TestConfiguration(dir, TEST_NAME5, new String[] {"B"}));
addTestConfiguration(TEST_NAME6, new TestConfiguration(dir, TEST_NAME6, new String[] {"B"}));
addTestConfiguration(TEST_NAME7, new TestConfiguration(dir, TEST_NAME7, new String[] {"B"}));
addTestConfiguration(TEST_NAME8, new TestConfiguration(dir, TEST_NAME8, new String[] {"B"}));
}

@Test
Expand Down Expand Up @@ -143,8 +146,23 @@ public void testL2SVMCP() {
runWorkloadAnalysisTest(TEST_NAME7, ExecMode.SINGLE_NODE, 2, false);
}

@Test
public void testKmeansSuccessfulCP() {
runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, false, 30);
}

@Test
public void testKmeansUnsuccessfulCP() {
runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, false, 10);
}

private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates){
runWorkloadAnalysisTest(testname, mode, compressionCount, intermediates, -1);
}

// private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) {
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) {
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates,
int maxIter) {
ExecMode oldPlatform = setExecMode(mode);
boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;

Expand All @@ -154,19 +172,20 @@ private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compres

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B"),
String.valueOf(maxIter)};

writeInputMatrixWithMTD("X", X, false);
writeInputMatrixWithMTD("y", y, false);

String ret = runTest(null).toString();
ByteArrayOutputStream out = runTest(null);
String ret = out != null ? out.toString() : "";
LOG.debug(ret);

// check various additional expectations
long actualCompressionCount = (mode == ExecMode.HYBRID || mode == ExecMode.SINGLE_NODE) ? Statistics
.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");

Assert.assertEquals("Assert that the compression counts expeted matches actual: " + compressionCount + " vs "
Assert.assertEquals("Assert that the compression counts expected matches actual: " + compressionCount + " vs "
+ actualCompressionCount, compressionCount, actualCompressionCount);
if(compressionCount > 0)
Assert.assertTrue(mode == ExecMode.SINGLE_NODE || mode == ExecMode.HYBRID ? heavyHittersContainsString(
Expand All @@ -176,6 +195,7 @@ private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compres

}
catch(Exception e) {
e.printStackTrace();
resetExecMode(oldPlatform);
fail("Failed workload test");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

X = read($1);


print("")
print("kmeans")

[data, Centering, ScaleFactor] = scale(X, TRUE, TRUE)
# terminates with result
[Y_n, C_n] = kmeans(X=data, k=16, runs= 1, max_iter=as.integer($4), eps= 1e-17, seed= 13, is_verbose=TRUE)
print(sum(Y_n))

0 comments on commit 615cd9a

Please sign in to comment.