From 5606584f52caa9388dd7563d5d33102a23ba304e Mon Sep 17 00:00:00 2001 From: 11iy <936630744@qq.com> Date: Tue, 9 Jul 2024 06:01:28 +0800 Subject: [PATCH] Add BuiltinImputeMARTest and ImputeMARTest.dml --- bin/README.md | 33 +- bin/systemds | 458 ------------------ .../builtin/part1/BuiltinImputeMARTest.java | 186 +++++++ .../functions/builtin/imputeMARTest.dml | 90 ++++ 4 files changed, 306 insertions(+), 461 deletions(-) delete mode 100755 bin/systemds create mode 100644 src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeMARTest.java create mode 100644 src/test/scripts/functions/builtin/imputeMARTest.dml diff --git a/bin/README.md b/bin/README.md index 370efa2f093..5ba4e37b33c 100644 --- a/bin/README.md +++ b/bin/README.md @@ -17,7 +17,34 @@ limitations under the License. {% end comment %} --> -# Scripts to run SystemDS +# Apache SystemDS -This directory contains scripts to launch SystemDS. -For details look at: [RunSystemDS](/docs/site/run.md). +**Overview:** SystemDS is an open source ML system for the end-to-end data science lifecycle from data integration, cleaning, +and feature engineering, over efficient, local and distributed ML model training, to deployment and serving. To this +end, we aim to provide a stack of declarative languages with R-like syntax for (1) the different tasks of the data-science +lifecycle, and (2) users with different expertise. These high-level scripts are compiled into hybrid execution plans of +local, in-memory CPU and GPU operations, as well as distributed operations on Apache Spark. In contrast to existing +systems - that either provide homogeneous tensors or 2D Datasets - and in order to serve the entire data science lifecycle, +the underlying data model are DataTensors, i.e., tensors (multi-dimensional arrays) whose first dimension may have a +heterogeneous and nested schema. + + +Resource | Links +---------|------ +**Quick Start** | [Install, Quick Start and Hello World](https://apache.github.io/systemds/site/install.html) +**Documentation:** | [SystemDS Documentation](https://apache.github.io/systemds/) +**Python Documentation** | [Python SystemDS Documentation](https://apache.github.io/systemds/api/python/index.html) +**Issue Tracker** | [Jira Dashboard](https://issues.apache.org/jira/secure/Dashboard.jspa?selectPageId=12335852) + + +**Status and Build:** SystemDS is renamed from SystemML which is an **Apache Top Level Project**. +To build from source visit [SystemDS Install from source](https://apache.github.io/systemds/site/install.html) + +[![Build](https://github.com/apache/systemds/actions/workflows/build.yml/badge.svg?branch=main)](https://github.com/apache/systemds/actions/workflows/build.yml) +[![Documentation](https://github.com/apache/systemds/actions/workflows/documentation.yml/badge.svg?branch=main)](https://github.com/apache/systemds/actions/workflows/documentation.yml) +[![LicenseCheck](https://github.com/apache/systemds/actions/workflows/license.yml/badge.svg?branch=main)](https://github.com/apache/systemds/actions/workflows/license.yml) +[![Java Tests](https://github.com/apache/systemds/actions/workflows/javaTests.yml/badge.svg?branch=main)](https://github.com/apache/systemds/actions/workflows/javaTests.yml) +[![codecov](https://codecov.io/gh/apache/systemds/graph/badge.svg?token=4YfvX8s6Dz)](https://codecov.io/gh/apache/systemds) +[![Python Test](https://github.com/apache/systemds/actions/workflows/python.yml/badge.svg?branch=main)](https://github.com/apache/systemds/actions/workflows/python.yml) +[![Total PyPI downloads](https://static.pepy.tech/personalized-badge/systemds?units=abbreviation&period=total&left_color=grey&right_color=blue&left_text=Total%20PyPI%20Downloads)](https://pepy.tech/project/systemds) +[![Monthly PyPI downloads](https://static.pepy.tech/personalized-badge/systemds?units=abbreviation&left_color=grey&right_color=blue&left_text=Monthly%20PyPI%20Downloads)](https://pepy.tech/project/systemds) diff --git a/bin/systemds b/bin/systemds deleted file mode 100755 index 2e8e629495b..00000000000 --- a/bin/systemds +++ /dev/null @@ -1,458 +0,0 @@ -#!/usr/bin/env bash -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -# If not set by env, set to 1 to run spark-submit instead of local java -# This should be used to run with spark-submit instead of java -if [[ -z "$SYSDS_DISTRIBUTED" ]]; then - SYSDS_DISTRIBUTED=0 -fi - -# if not set by env, set to 1 to disable setup output of this script -if [ -z "$SYSDS_QUIET" ]; then - SYSDS_QUIET=0 -fi - -# if not set by env, set to default exec modes -if [[ -z "$SYSDS_EXEC_MODE" ]]; then - case "$SYSDS_DISTRIBUTED" in - 0) SYSDS_EXEC_MODE=singlenode ;; - *) SYSDS_EXEC_MODE=hybrid ;; - esac -fi - -# an echo toggle -print_out() -{ - if [ $SYSDS_QUIET == 0 ]; then - echo "$1" - fi -} - -if [[ -z $SYSTEMDS_ROOT ]] ; then - SYSTEMDS_ROOT=$(pwd) - print_out "SYSTEMDS_ROOT not set defaulting to current dir $(pwd)" -fi; - -# when using find, look in the directories in this order -DIR_SEARCH_ORDER="$SYSTEMDS_ROOT/target . $SYSTEMDS_ROOT $SYSTEMDS_ROOT/conf $SYSTEMDS_ROOT/lib $SYSTEMDS_ROOT/src" -ordered_find() { - result="" - for dir in $(echo "$DIR_SEARCH_ORDER" | tr ' ' '\n') ; do - if [[ $dir == "$SYSTEMDS_ROOT" ]] || [[ $dir == "." ]]; then - result=$(find "$dir" -maxdepth 1 -iname "$1" -print -quit) - if [[ $result != "" ]]; then break; fi - else - result=$(find "$dir" -iname "$1" -print -quit 2> /dev/null) - if [[ $result != "" ]]; then break; fi - fi - done - echo "$result" -} - -if [ -n "$SYSTEMDS_STANDALONE_OPTS" ]; then - print_out "Overriding SYSTEMDS_STANDALONE_OPTS with env var: $SYSTEMDS_STANDALONE_OPTS" -else - # specify parameters to java when running locally here - SYSTEMDS_STANDALONE_OPTS="-Xmx4g -Xms4g -Xmn400m " -fi - -if [ -n "$SYSTEMDS_REMOTE_DEBUGGING" ]; then - print_out "Overriding SYSTEMDS_REMOTE_DEBUGGING with env var: $SYSTEMDS_REMOTE_DEBUGGING" -else - SYSTEMDS_REMOTE_DEBUGGING=" -agentlib:jdwp=transport=dt_socket,suspend=y,address=8787,server=y " -fi - -# check if log4j config file exists, otherwise unset -# to run with a non fatal complaint by SystemDS -if [ -z "$LOG4JPROP" ] ; then - # before wild card search look obvious places. - if [ -f "$SYSTEMDS_ROOT/conf/log4j.properties" ]; then - LOG4JPROP="$SYSTEMDS_ROOT/conf/log4j.properties" - elif [ -f "$SYSTEMDS_ROOT/log4j.properties" ]; then - LOG4JPROP="$SYSTEMDS_ROOT/log4j.properties" - else # wildcard search - LOG4JPROP=$(ordered_find "log4j*properties") - fi -fi - -# If the LOG4J variable is declared or found. -if [ -f "${LOG4JPROP}" ]; then - LOG4JPROPFULL="-Dlog4j.configuration=file:$LOG4JPROP" -fi - -if [ -n "${SYSTEMDS_DISTRIBUTED_OPTS}" ]; then - print_out "Overriding SYSTEMDS_DISTRIBUTED_OPTS with env var $SYSTEMDS_DISTRIBUTED_OPTS" -else - # specify parameters to pass to spark-submit when running on spark here - SYSTEMDS_DISTRIBUTED_OPTS="\ - --master yarn \ - --deploy-mode client \ - --driver-memory 100g \ - --conf spark.driver.extraJavaOptions=\"-Xms100g -Xmn10g -Dlog4j.configuration=file:$LOG4JPROP\" \ - --conf spark.executor.extraJavaOptions=\"-Dlog4j.configuration=file:$LOG4JPROP\" \ - --conf spark.executor.heartbeatInterval=100s \ - --files $LOG4JPROP \ - --conf spark.network.timeout=512s \ - --num-executors 4 \ - --executor-memory 64g \ - --executor-cores 16 " -fi - - -# error help print -function printUsage { -cat << EOF - -Usage: $0 [-r] [SystemDS.jar] [-f] [arguments] [-help] - - SystemDS.jar : Specify a custom SystemDS.jar file (this will be prepended - to the classpath - or fed to spark-submit - -r : Spawn a debug server for remote debugging (standalone and - spark driver only atm). Default port is 8787 - change within - this script if necessary. See SystemDS documentation on how - to attach a remote debugger. - -f : Optional prefix to the dml-filename for consistency with - previous behavior dml-filename : The script file to run. - This is mandatory unless running as a federated worker - (see below). - arguments : The arguments specified after the DML script are passed to - SystemDS. Specify parameters that need to go to - java/spark-submit by editing this run script. - -help : Print this usage message and SystemDS parameter info - -Worker Usage: $0 [-r] WORKER [SystemDS.jar] [arguments] [-help] - - port : The port to open for the federated worker. - -Federated Monitoring Usage: $0 [-r] FEDMONITORING [SystemDS.jar] [arguments] [-help] - - port : The port to open for the federated monitoring tool. - -Set custom launch configuration by setting/editing SYSTEMDS_STANDALONE_OPTS -and/or SYSTEMDS_DISTRIBUTED_OPTS. - -Set the environment variable SYSDS_DISTRIBUTED=1 to run spark-submit instead of -local java Set SYSDS_QUIET=1 to omit extra information printed by this run -script. - -EOF -} - -# print an error if no argument is supplied. -if [ -z "$1" ] ; then - echo "Wrong Usage. Add -help for additional parameters."; - echo "" - printUsage; - exit -1 -fi - -#This loop handles the parameters to the run-script, not the ones passed to SystemDS. -#To not confuse getopts with SystemDS parameters, only the first two params are considered -#here. If more run-script params are needed, adjust the next line accordingly -PRINT_SYSDS_HELP=0 -while getopts ":hr:f:" options "$1$2"; do - case $options in - h ) echo "Help requested. Will exit after extended usage message!" - printUsage - PRINT_SYSDS_HELP=1 - break - ;; - \? ) echo "Unknown parameter -$OPTARG" - printUsage - exit - ;; - f ) - # silently remove -f (this variant is triggered if there's no - # jar file or WORKER as first parameter) - if echo "$OPTARG" | grep -qi "dml"; then - break - else - print_out "No DML Script found after -f option." - fi - ;; - r ) - print_out "Spawning server for remote debugging" - if [ $SYSDS_DISTRIBUTED == 0 ]; then - SYSTEMDS_STANDALONE_OPTS=${SYSTEMDS_STANDALONE_OPTS}${SYSTEMDS_REMOTE_DEBUGGING} - else - SYSTEMDS_DISTRIBUTED_OPTS=${SYSTEMDS_DISTRIBUTED_OPTS}${SYSTEMDS_REMOTE_DEBUGGING} - fi - shift # remove -r from positional arguments - ;; - * ) - print_out "Error: Unexpected error while processing options;" - printUsage - exit - esac -done - -# Peel off first and/or second argument so that $@ contains arguments to DML script -if echo "$1" | grep -q "jar"; then - SYSTEMDS_JAR_FILE=$1 - shift - # handle optional '-f' before DML file (for consistency) - if echo "$1" | grep -q "\-f"; then - shift - SCRIPT_FILE=$1 - shift - else - SCRIPT_FILE=$1 - shift - fi -elif echo "$1" | grep -q "WORKER"; then - WORKER=1 - shift - if echo "$1" | grep -q "jar"; then - SYSTEMDS_JAR_FILE=$1 - shift - fi - PORT=$1 - re='^[0-9]+$' - if ! [[ $PORT =~ $re ]] ; then - echo "error: Port is not a number" - printUsage - fi - shift -elif echo "$1" | grep -q "FEDMONITORING"; then - FEDMONITORING=1 - shift - if echo "$1" | grep -q "jar"; then - SYSTEMDS_JAR_FILE=$1 - shift - fi - PORT=$1 - re='^[0-9]+$' - if ! [[ $PORT =~ $re ]] ; then - echo "error: Port is not a number" - printUsage - fi - shift -else - # handle optional '-f' before DML file (for consistency) - if echo "$1" | grep -q "\-f"; then - shift - SCRIPT_FILE=$1 - shift - else - SCRIPT_FILE=$1 - shift - fi -fi - -if [ -z "$WORKER" ] ; then - WORKER=0 -fi - -if [ -z "$FEDMONITORING" ] ; then - FEDMONITORING=0 -fi - -# find a SystemDS jar file to run -if [ -z ${SYSTEMDS_JAR_FILE+x} ]; then # If it is not found yet. - if [ ! -z ${SYSTEMDS_ROOT+x} ]; then # Check currently set SYSETMDS_ROOT - # Current SYSTEMDS_ROOT is set and is a directory. - if [ -d "$SYSTEMDS_ROOT/target" ] && [ -d "$SYSTEMDS_ROOT/.git" ]; then - # Current path is most likely a build directory of systemds - SYSTEMDS_JAR_FILE=$(find "$SYSTEMDS_ROOT/target" -maxdepth 1 -iname ""systemds-?.?.?-SNAPSHOT.jar"" -print -quit) - elif [ -d "$SYSTEMDS_ROOT" ] && [ -d "$SYSTEMDS_ROOT/lib" ]; then - # Most likely a release directory. - SYSTEMDS_JAR_FILE=$(find "$SYSTEMDS_ROOT" -maxdepth 1 -iname ""systemds-?.?.?-SNAPSHOT.jar"" -print -quit) - fi - fi -fi - -# If no jar file is found, start searching --- expected + 70 ms execution time -if [ -z ${SYSTEMDS_JAR_FILE+x} ]; then - SYSTEMDS_JAR_FILE=$(ordered_find "systemds.jar") - if [ -z ${SYSTEMDS_JAR_FILE+x} ]; then - SYSTEMDS_JAR_FILE=$(ordered_find "systemds-?.?.?.jar") - if [ -z ${SYSTEMDS_JAR_FILE+x} ]; then - SYSTEMDS_JAR_FILE=$(ordered_find "systemds-?.?.?-SNAPSHOT.jar") - if [ -z ${SYSTEMDS_JAR_FILE+x} ]; then - echo "wARNING: Unable to find SystemDS jar file to launch" - exit -1 - fi - fi - fi -fi - -if [[ "$*" == *-config* ]]; then -# override config file from env var if given as parameter to SystemDS - read -r -d '' -a myArray < <( echo "$@" ) - INDEX=0 - for i in "${myArray[@]}"; do - if [[ ${myArray[INDEX]} == *-config* ]]; then - if [ -f "${myArray[((INDEX+1))]}" ]; then - CONFIG_FILE="${myArray[((INDEX+1))]}" - else - echo Warning! Passed config file "${myArray[((INDEX+1))]}" does not exist. - fi - # remove -config - unset 'myArray[INDEX]' - - # remove -config param if not starting with - - if [[ "${myArray[((INDEX+1))]:0:1}" != "-" ]]; then - unset 'myArray[((INDEX+1))]' - fi - # setting the script arguments without the passed -config for further processing - set -- "${myArray[@]}" - break; - fi - # debug print array item - #echo "${myArray[INDEX]}" - (( INDEX=INDEX+1 )) - done - - if [ -f "$CONFIG_FILE" ] ; then - CONFIG_FILE="-config $CONFIG_FILE" - else - CONFIG_FILE="" - fi -elif [ -z "$CONFIG_FILE" ] ; then - - # default search for config file - if [ -f "$SYSTEMDS_ROOT/conf/SystemDS-config-defaults.xml" ]; then - CONFIG_FILE="$SYSTEMDS_ROOT/conf/SystemDS-config-defaults.xml" - elif [ -f "$SYSTEMDS_ROOT/SystemDS-config-defaults.xml" ]; then - CONFIG_FILE="$SYSTEMDS_ROOT/conf/SystemDS-config-defaults.xml" - else # wildcard search - # same as above: set config file param if the file exists - CONFIG_FILE=$(ordered_find "SystemDS-config-defaults.xml") - fi - - if [ -z "$CONFIG_FILE" ]; then # Second search if still not found. - CONFIG_FILE=$(ordered_find "SystemDS-config.xml") - fi - - if [ -z "$CONFIG_FILE" ]; then - CONFIG_FILE="" - else - CONFIG_FILE="-config $CONFIG_FILE" - fi -else - # CONFIG_FILE was set by env var. Unset if that setting is wrong - if [ -f "${CONFIG_FILE}" ]; then - CONFIG_FILE="-config $CONFIG_FILE" - else - CONFIG_FILE="" - fi -fi - -# override exec mode if given as parameter to SystemDS (e.g. -exec singlenode) -read -r -d '' -a myArray < <( echo "$@" ) -INDEX=0 -for i in "${myArray[@]}"; do - if [[ "$i" == *-exec* ]]; then - SYSDS_EXEC_MODE="${myArray[((INDEX+1))]}" - break; - fi - (( INDEX=INDEX+1 )) -done - -# find absolute path to hadoop home in SYSTEMDS_ROOT -if [ -z "$HADOOP_HOME" ]; then - HADOOP_HOME="$(find "$SYSTEMDS_ROOT" -iname hadoop | tail -n 1 )" -fi - -# detect operating system to set correct path separator -if [ "$OSTYPE" == "win32" ] || [ "$OSTYPE" == "msys" ] || [ "$OSTYPE" == "cygwin" ]; then - PATH_SEP=\; - DIR_SEP=\\ -else - # default directory separator unix style - DIR_SEP=/ - PATH_SEP=: -fi - - -NATIVE_LIBS="$SYSTEMDS_ROOT${DIR_SEP}target${DIR_SEP}classes${DIR_SEP}lib" -PATH=${HADOOP_HOME}${DIR_SEP}bin${PATH_SEP}${PATH}${PATH_SEP}$NATIVE_LIBS -LD_LIBRARY_PATH=${HADOOP_HOME}${DIR_SEP}bin${PATH_SEP}${LD_LIBRARY_PATH} - - -if [ $PRINT_SYSDS_HELP == 1 ]; then - echo "----------------------------------------------------------------------" - echo "Further help on SystemDS arguments:" - java -jar $SYSTEMDS_JAR_FILE org.apache.sysds.api.DMLScript -help - exit 1 -fi - -if [ $SYSDS_QUIET == 0 ]; then - print_out "###############################################################################" - print_out "# SYSTEMDS_ROOT= $SYSTEMDS_ROOT" - print_out "# SYSTEMDS_JAR_FILE= $SYSTEMDS_JAR_FILE" - print_out "# SYSDS_EXEC_MODE= $SYSDS_EXEC_MODE" - print_out "# CONFIG_FILE= $CONFIG_FILE" - print_out "# LOG4JPROP= $LOG4JPROPFULL" - print_out "# HADOOP_HOME= $HADOOP_HOME" - print_out "#" -fi - -# Build the command to run -if [ $WORKER == 1 ]; then - print_out "# starting Federated worker on port $PORT" - CMD=" \ - java $SYSTEMDS_STANDALONE_OPTS \ - $LOG4JPROPFULL \ - -jar $SYSTEMDS_JAR_FILE \ - -w $PORT \ - $CONFIG_FILE \ - $*" -elif [ "$FEDMONITORING" == 1 ]; then - print_out "# starting Federated backend monitoring on port $PORT" - CMD=" \ - java $SYSTEMDS_STANDALONE_OPTS \ - $LOG4JPROPFULL \ - -jar $SYSTEMDS_JAR_FILE \ - -fedMonitoring $PORT \ - $CONFIG_FILE \ - $*" -elif [ $SYSDS_DISTRIBUTED == 0 ]; then - print_out "# Running script $SCRIPT_FILE locally with opts: $*" - - CMD=" \ - java $SYSTEMDS_STANDALONE_OPTS \ - $LOG4JPROPFULL \ - -jar $SYSTEMDS_JAR_FILE \ - -f $SCRIPT_FILE \ - -exec $SYSDS_EXEC_MODE \ - $CONFIG_FILE \ - $*" -else - print_out "# Running script $SCRIPT_FILE distributed with opts: $*" - CMD=" \ - spark-submit $SYSTEMDS_DISTRIBUTED_OPTS \ - $SYSTEMDS_JAR_FILE \ - -f $SCRIPT_FILE \ - -exec $SYSDS_EXEC_MODE \ - $CONFIG_FILE \ - $*" -fi - -if [ $SYSDS_QUIET == 0 ]; then - print_out "# Executing command: $CMD" - print_out "###############################################################################" -fi - -# run -eval "$CMD" diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeMARTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeMARTest.java new file mode 100644 index 00000000000..e5c3587d445 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeMARTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part1; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; + +public class BuiltinImputeMARTest extends AutomatedTestBase { + private final static String TEST_NAME = "imputeMARTest"; + private final static String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinImputeMARTest.class.getSimpleName() + "/"; + private String DATASET = DATASET_DIR + "ChickWeight.csv"; + private final static double eps = 0.16; + private final static int iter = 3; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"LogReg"})); + } + + @Test + public void testMiceImputation() { + runImputationTest("MICE", ExecType.CP); + } + + @Test + public void testMeanImputation() { + runImputationTest("MEAN", ExecType.CP); + } + + @Test + public void compareImputationMethods() { + runImputationTest("MICE", ExecType.CP); + runImputationTest("MEAN", ExecType.CP); + compareResults(); + } + + private void runImputationTest(String method, ExecType instType) { + Types.ExecMode platformOld = setExecMode(instType); + try { + System.out.println("Dataset " + DATASET); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + double[][] mask = {{0.0, 0.0, 1.0, 1.0, 0.0}}; + writeInputMatrixWithMTD("M", mask, true); + + programArgs = new String[]{"-nvargs", + "X=" + DATASET, + "Mask=" + input("M"), + "method=" + method, + "iteration=" + iter, + "LogReg=" + output("LogReg_" + method), + "targetCol=4"}; + + runTest(true, false, null, -1); + testLogisticRegression("LogReg_" + method); + } finally { + rtplatform = platformOld; + } + } + + private void testLogisticRegression(String outputName) { + HashMap logRegResults = readDMLMatrixFromOutputDir(outputName); + + Assert.assertEquals("Incorrect number of logistic regression coefficients", 6, logRegResults.size() - 1); + + for (int i = 1; i <= logRegResults.size() - 1; i++) { + Double coefficient = logRegResults.get(new MatrixValue.CellIndex(i, 1)); + Assert.assertTrue("Logistic regression coefficient is out of reasonable range: " + coefficient, + Math.abs(coefficient) < 10); + } + + Double auc = logRegResults.get(new MatrixValue.CellIndex(logRegResults.size(), 1)); + Assert.assertTrue("Logistic regression model performance (AUC) is below threshold: " + auc, auc > 0.7); + + Double rSquared = logRegResults.get(new MatrixValue.CellIndex(logRegResults.size(), 1)); + Assert.assertTrue("R-squared value is out of range: " + rSquared, rSquared >= 0 && rSquared <= 1); + Assert.assertTrue("Model fit (R-squared) is below acceptable threshold: " + rSquared, rSquared > 0.3); + + int n = 1000; + for (int i = 1; i <= logRegResults.size() - 2; i++) { + Double coefficient = logRegResults.get(new MatrixValue.CellIndex(i, 1)); + Double standardError = Math.abs(coefficient / Math.sqrt(n)); + Double zScore = coefficient / standardError; + Double pValue = 2 * (1 - cdf(Math.abs(zScore))); + Assert.assertTrue("P-value for coefficient " + i + " is not significant: " + pValue, pValue < 0.05); + } + System.out.println("Model fit (R-squared): " + rSquared); + } + + private void compareResults() { + HashMap miceResults = readDMLMatrixFromOutputDir("LogReg_MICE"); + HashMap meanResults = readDMLMatrixFromOutputDir("LogReg_MEAN"); + + System.out.println("Comparison of MICE and Mean Imputation:"); + System.out.println("Coefficient\tMICE\t\tMean\t\tMICE p-value\tMean p-value"); + + int n = 1000; + for (int i = 1; i <= miceResults.size() - 1; i++) { + Double miceCoef = miceResults.get(new MatrixValue.CellIndex(i, 1)); + Double meanCoef = meanResults.get(new MatrixValue.CellIndex(i, 1)); + + Double miceSE = Math.abs(miceCoef / Math.sqrt(n)); + Double meanSE = Math.abs(meanCoef / Math.sqrt(n)); + + Double micePValue = 2 * (1 - cdf(Math.abs(miceCoef / miceSE))); + Double meanPValue = 2 * (1 - cdf(Math.abs(meanCoef / meanSE))); + + System.out.printf("%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f%n", + i, miceCoef, meanCoef, micePValue, meanPValue); + + totalDifference += Math.abs(miceCoef - meanCoef); + } + + double averageDifference = totalDifference / coeffCount; + double threshold = 0.1; + + if (averageDifference < threshold) { + System.out.println("Data type: MCAR (Missing Completely At Random)"); + } else { + Double miceRSquared = miceResults.get(new MatrixValue.CellIndex(miceResults.size(), 1)); + Double meanRSquared = meanResults.get(new MatrixValue.CellIndex(meanResults.size(), 1)); + + System.out.println("MICE R-squared: " + miceRSquared); + System.out.println("Mean Imputation R-squared: " + meanRSquared); + + double rSquaredDifference = miceRSquared - meanRSquared; + double rSquaredThreshold = 0.05; + + if (rSquaredDifference > rSquaredThreshold) { + System.out.println("Data type: MAR (Missing At Random)"); + } else { + System.out.println("Data type: NMAR (Not Missing At Random)"); + } + } + + } + + private double cdf(double x) { + return 0.5 * (1 + erf(x / Math.sqrt(2))); + } + + private double erf(double z) { + double t = 1.0 / (1.0 + 0.5 * Math.abs(z)); + double ans = 1 - t * Math.exp(-z * z - 1.26551223 + + t * (1.00002368 + + t * (0.37409196 + + t * (0.09678418 + + t * (-0.18628806 + + t * (0.27886807 + + t * (-1.13520398 + + t * (1.48851587 + + t * (-0.82215223 + + t * 0.17087277))))))))); + return z >= 0 ? ans : -ans; + } +} diff --git a/src/test/scripts/functions/builtin/imputeMARTest.dml b/src/test/scripts/functions/builtin/imputeMARTest.dml new file mode 100644 index 00000000000..5f233c1d9b3 --- /dev/null +++ b/src/test/scripts/functions/builtin/imputeMARTest.dml @@ -0,0 +1,90 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Prepare data +X = read($X); +M = read($Mask); +imputationMethod = $method; +iter = $iteration; +targetCol = $targetCol; + +# MICE imputation +mice_imputation = function(Matrix[Double] X, Matrix[Double] M, Integer iter) return (Matrix[Double] X_imputed) { + X_imputed = X; + for(i in 1:iter) { + for(j in 1:ncol(X)) { + if(sum(M[,j]) > 0) { + X_temp = X_imputed; + X_temp[, j] = X_temp[, j] * (1-M[,j]); + model = lm(X_imputed[,j], X_temp); + X_imputed[M[,j]==1, j] = predict(model, X_temp[M[,j]==1,]); + } + } + } +} + +# Mean imputation +mean_imputation = function(Matrix[Double] X, Matrix[Double] M) return (Matrix[Double] X_imputed) { + X_imputed = X; + for(j in 1:ncol(X)) { + if(sum(M[,j]) > 0) { + col_mean = mean(X[M[,j]==0, j]); + X_imputed[M[,j]==1, j] = col_mean; + } + } +} + + +if (imputationMethod == "MICE") { + X_imputed = mice_imputation(X, M, iter); +} else if (imputationMethod == "MEAN") { + X_imputed = mean_imputation(X, M); +} else { + stop("Invalid imputation method specified"); +} + +# Prepare data for logistic regression +y = X_imputed[, targetCol]; +X_lr = X_imputed[, -targetCol]; + +# Logistic regression +model = glm(y, X_lr, family="binomial", verbose=FALSE); +beta = model$coefficients; +predictions = model$fitted.values; + +# Get AUC +pos_label = rowIndexMax(matrix(1,1,2)); +auc = as.scalar(auroc(y, predictions, pos_label)); + +# Calculate R-squared +y_mean = mean(y); +ss_tot = sum((y - y_mean)^2); +ss_res = sum((y - predictions)^2); +r_squared = 1 - (ss_res / ss_tot); + +# Output results +write(beta, $LogReg); +write(matrix(auc, 1, 1), $LogReg, format="csv", append=TRUE); +write(matrix(r_squared, 1, 1), $LogReg, format="csv", append=TRUE); + +print("Imputation method: " + imputationMethod); +print("AUC: " + auc); +print("R-squared: " + r_squared); \ No newline at end of file