[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
### What changes were proposed in this pull request?
Support `OneVsRest` and `OneVsRestModel` on Spark Connect.

### Why are the changes needed?
Feature parity: `OneVsRest` previously worked only in classic PySpark ML, not over Spark Connect.

### Does this PR introduce _any_ user-facing change?
Yes. `OneVsRest` can now be fit, transformed, saved, and loaded against a Spark Connect session.
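
For example (an illustrative sketch adapted from the new tests; the `sc://` URL is a placeholder, and a running Spark Connect server is assumed):

from pyspark.ml.classification import LinearSVC, OneVsRest
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

# Assumes a reachable Spark Connect server; the address is illustrative.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [
        (1.0, Vectors.dense(0.0, 5.0)),
        (0.0, Vectors.dense(1.0, 2.0)),
        (1.0, Vectors.dense(2.0, 1.0)),
        (2.0, Vectors.dense(3.0, 3.0)),
    ],
    ["label", "features"],
)

ovr = OneVsRest(classifier=LinearSVC(maxIter=1, regParam=1.0), parallelism=1)
model = ovr.fit(df)         # now runs over Connect
model.transform(df).show()  # adds rawPrediction / prediction columns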

### How was this patch tested?
New tests: `pyspark.ml.tests.test_ovr`, plus the Connect parity suite `pyspark.ml.tests.connect.test_parity_ovr`.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#49704 from zhengruifeng/ml_connect_ovr_2.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jan 28, 2025
1 parent 9a45019 commit b49ef2a
Showing 5 changed files with 252 additions and 19 deletions.
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -676,6 +676,7 @@ def __hash__(self):
"pyspark.ml.tests.test_persistence",
"pyspark.ml.tests.test_pipeline",
"pyspark.ml.tests.test_tuning",
"pyspark.ml.tests.test_ovr",
"pyspark.ml.tests.test_stat",
"pyspark.ml.tests.test_training_summary",
"pyspark.ml.tests.tuning.test_tuning",
@@ -1129,6 +1130,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_parity_feature",
"pyspark.ml.tests.connect.test_parity_pipeline",
"pyspark.ml.tests.connect.test_parity_tuning",
"pyspark.ml.tests.connect.test_parity_ovr",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
61 changes: 42 additions & 19 deletions python/pyspark/ml/classification.py
@@ -35,6 +35,8 @@
cast,
overload,
TYPE_CHECKING,
Tuple,
Callable,
)

from pyspark import keyword_only, since, inheritable_thread_target
@@ -85,6 +87,8 @@
MLWriter,
MLWritable,
HasTrainingSummary,
try_remote_read,
try_remote_write,
try_remote_attribute_relation,
)
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -94,6 +98,7 @@
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
from pyspark.sql.utils import is_remote

if TYPE_CHECKING:
from pyspark.ml._typing import P, ParamMap
@@ -3572,31 +3577,45 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
if handlePersistence:
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)

def trainSingleClass(index: int) -> CM:
binaryLabelCol = "mc2b$" + str(index)
trainingDataset = multiclassLabeled.withColumn(
binaryLabelCol,
when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
)
paramMap = dict(
[
(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol),
]
)
if weightCol:
paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)
def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
indices = iter(range(numClasses))

def trainSingleClass() -> Tuple[int, CM]:
index = next(indices)

binaryLabelCol = "mc2b$" + str(index)
trainingDataset = multiclassLabeled.withColumn(
binaryLabelCol,
when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
)
paramMap = dict(
[
(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol),
]
)
if weightCol:
paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
return index, classifier.fit(trainingDataset, paramMap)

return [trainSingleClass] * numClasses

tasks = map(
inheritable_thread_target(dataset.sparkSession),
_oneClassFitTasks(numClasses),
)
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
subModels = [None] * numClasses
for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
assert subModels is not None
subModels[j] = subModel

if handlePersistence:
multiclassLabeled.unpersist()

return self._copyValues(OneVsRestModel(models=models))
return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))

def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
"""
@@ -3671,9 +3690,11 @@ def _to_java(self) -> "JavaObject":
return _java_obj

@classmethod
@try_remote_read
def read(cls) -> "OneVsRestReader":
return OneVsRestReader(cls)

@try_remote_write
def write(self) -> MLWriter:
if isinstance(self.getClassifier(), JavaMLWritable):
return JavaMLWriter(self) # type: ignore[arg-type]
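
The two decorators added above are what route read/write through Connect. A plausible sketch of the mechanism follows (hedged: this is not the actual `pyspark.ml.util` implementation, and the `RemoteMLWriter` import path is an assumption):

import functools

from pyspark.sql.utils import is_remote


def try_remote_write(f):
    # When the active session is remote, divert to a Connect-aware writer
    # instead of the JVM-backed one the wrapped method would return.
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if is_remote():
            from pyspark.ml.connect.readwrite import RemoteMLWriter  # assumed

            return RemoteMLWriter(self)
        return f(self, *args, **kwargs)

    return wrapped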
@@ -3787,7 +3808,7 @@ def __init__(self, models: List[ClassificationModel]):
from pyspark.core.context import SparkContext

self.models = models
if not isinstance(models[0], JavaMLWritable):
if is_remote() or not isinstance(models[0], JavaMLWritable):
return
# set java instance
java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3955,9 +3976,11 @@ def _to_java(self) -> "JavaObject":
return _java_obj

@classmethod
@try_remote_read
def read(cls) -> "OneVsRestModelReader":
return OneVsRestModelReader(cls)

@try_remote_write
def write(self) -> MLWriter:
if all(
map(
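The notable part of this file's diff is the reworked `_fit`: instead of `pool.map(trainSingleClass, range(numClasses))`, it builds `numClasses` copies of one closure that each claim their class index from a shared iterator and return an `(index, model)` pair, so `imap_unordered` can consume results as they finish while the returned index restores order. A minimal Spark-free sketch of the same pattern (names illustrative):

from multiprocessing.pool import ThreadPool
from typing import Callable, List, Optional, Tuple


def one_class_fit_tasks(num_classes: int) -> List[Callable[[], Tuple[int, str]]]:
    indices = iter(range(num_classes))  # shared by all task copies

    def train_single_class() -> Tuple[int, str]:
        # CPython's GIL makes next() on the shared iterator effectively
        # atomic here, mirroring the assumption in the code above.
        index = next(indices)
        return index, f"model_{index}"  # stand-in for classifier.fit(...)

    return [train_single_class] * num_classes


num_classes = 3
pool = ThreadPool(processes=2)
sub_models: List[Optional[str]] = [None] * num_classes
# Results may arrive in any order; the returned index slots each one home.
for j, sub_model in pool.imap_unordered(lambda f: f(), one_class_fit_tasks(num_classes)):
    sub_models[j] = sub_model
print(sub_models)  # ['model_0', 'model_1', 'model_2']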
36 changes: 36 additions & 0 deletions python/pyspark/ml/connect/readwrite.py
@@ -95,6 +95,7 @@ def saveInstance(
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
from pyspark.ml.classification import OneVsRest, OneVsRestModel

# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -187,6 +188,26 @@ def saveInstance(
warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
tvsm_writer.save(path)
elif isinstance(instance, OneVsRest):
from pyspark.ml.classification import OneVsRestWriter

if shouldOverwrite:
# TODO(SPARK-50954): Support client side model path overwrite
warnings.warn("Overwrite doesn't take effect for OneVsRest")

ovr_writer = OneVsRestWriter(instance)
ovr_writer.session(session) # type: ignore[arg-type]
ovr_writer.save(path)
elif isinstance(instance, OneVsRestModel):
from pyspark.ml.classification import OneVsRestModelWriter

if shouldOverwrite:
# TODO(SPARK-50954): Support client side model path overwrite
warnings.warn("Overwrite doesn't take effect for OneVsRestModel")

ovrm_writer = OneVsRestModelWriter(instance)
ovrm_writer.session(session) # type: ignore[arg-type]
ovrm_writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for {instance.__class__}")

@@ -215,6 +236,7 @@ def loadInstance(
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
from pyspark.ml.classification import OneVsRest, OneVsRestModel

if (
issubclass(clazz, JavaModel)
@@ -307,5 +329,19 @@ def _get_class() -> Type[RL]:
tvs_reader.session(session)
return tvs_reader.load(path)

elif issubclass(clazz, OneVsRest):
from pyspark.ml.classification import OneVsRestReader

ovr_reader = OneVsRestReader(OneVsRest)
ovr_reader.session(session)
return ovr_reader.load(path)

elif issubclass(clazz, OneVsRestModel):
from pyspark.ml.classification import OneVsRestModelReader

ovrm_reader = OneVsRestModelReader(OneVsRestModel)
ovrm_reader.session(session)
return ovrm_reader.load(path)

else:
raise RuntimeError(f"Unsupported read for {clazz}")
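
With these branches in place, persistence round-trips work on Connect. Note the TODO(SPARK-50954) above: `overwrite()` currently only warns on Connect rather than overwriting. A hedged sketch, continuing the `spark`, `ovr`, and `model` objects from the earlier snippet (the path handling assumes the Connect server can see the client's temp directory, as in the parity tests):

import tempfile

from pyspark.ml.classification import OneVsRestModel

with tempfile.TemporaryDirectory() as d:
    path = f"{d}/ovr_model"
    model.write().save(path)              # dispatches to OneVsRestModelWriter
    reloaded = OneVsRestModel.load(path)  # dispatches to OneVsRestModelReader
    assert str(reloaded) == str(model)    # same check the new test performs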
37 changes: 37 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_ovr.py
@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
from pyspark.ml.tests.connect.test_parity_ovr import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
135 changes: 135 additions & 0 deletions python/pyspark/ml/tests/test_ovr.py
@@ -0,0 +1,135 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import tempfile
import unittest

import numpy as np

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import (
LinearSVC,
LinearSVCModel,
OneVsRest,
OneVsRestModel,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


class OneVsRestTestsMixin:
def test_one_vs_rest(self):
spark = self.spark
df = (
spark.createDataFrame(
[
(0, 1.0, Vectors.dense(0.0, 5.0)),
(1, 0.0, Vectors.dense(1.0, 2.0)),
(2, 1.0, Vectors.dense(2.0, 1.0)),
(3, 2.0, Vectors.dense(3.0, 3.0)),
],
["index", "label", "features"],
)
.coalesce(1)
.sortWithinPartitions("index")
.select("label", "features")
)

svc = LinearSVC(maxIter=1, regParam=1.0)
self.assertEqual(svc.getMaxIter(), 1)
self.assertEqual(svc.getRegParam(), 1.0)

ovr = OneVsRest(classifier=svc, parallelism=1)
self.assertEqual(ovr.getParallelism(), 1)

model = ovr.fit(df)
self.assertIsInstance(model, OneVsRestModel)
self.assertEqual(len(model.models), 3)
for submodel in model.models:
self.assertIsInstance(submodel, LinearSVCModel)

self.assertTrue(
np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
model.models[0].intercept,
)
self.assertTrue(
np.allclose(
model.models[0].coefficients.toArray(),
[-0.1198765502306968, -0.1027513287691687],
atol=1e-4,
),
model.models[0].coefficients,
)

self.assertTrue(
np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
model.models[1].intercept,
)
self.assertTrue(
np.allclose(
model.models[1].coefficients.toArray(),
[-0.0362284418654736, 0.010350983390135305],
atol=1e-4,
),
model.models[1].coefficients,
)

self.assertTrue(
np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
model.models[2].intercept,
)
self.assertTrue(
np.allclose(
model.models[2].coefficients.toArray(),
[0.12886829400126, 0.012273170857262873],
atol=1e-4,
),
model.models[2].coefficients,
)

output = model.transform(df)
expected_cols = ["label", "features", "rawPrediction", "prediction"]
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 4)

# Model save & load
with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
path1 = os.path.join(d, "ovr")
ovr.write().overwrite().save(path1)
ovr2 = OneVsRest.load(path1)
self.assertEqual(str(ovr), str(ovr2))

path2 = os.path.join(d, "ovr_model")
model.write().overwrite().save(path2)
model2 = OneVsRestModel.load(path2)
self.assertEqual(str(model), str(model2))


class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
from pyspark.ml.tests.test_ovr import * # noqa: F401,F403

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
