From af03383cd166f943d527e7feb8aea950b566b570 Mon Sep 17 00:00:00 2001
From: machichima
Date: Fri, 20 Dec 2024 23:09:43 +0800
Subject: [PATCH] feat: add driver/executor pod specs to the Spark plugin

Expose optional driver_pod / executor_pod (K8sPod) fields on the Spark task
config, thread them through the SparkJob model and its protobuf round-trip,
and cover the new fields in the serialization tests.

Signed-off-by: machichima
---
 .../flytekitplugins/spark/models.py         |  27 +++
 .../flytekitplugins/spark/task.py           |   7 +
 .../flytekit-spark/tests/test_spark_task.py | 154 ++++++++++++++++--
 3 files changed, 178 insertions(+), 10 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py
index e74a9fbe3f..1f185609f4 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/models.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@
 
 from flytekit.exceptions import user as _user_exceptions
 from flytekit.models import common as _common
+from flytekit.models.task import K8sPod
 
 
 class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
         executor_path: str,
         databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
         databricks_instance: Optional[str] = None,
+        driver_pod: Optional[K8sPod] = None,
+        executor_pod: Optional[K8sPod] = None,
     ):
         """
         This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
             databricks_conf = {}
         self._databricks_conf = databricks_conf
         self._databricks_instance = databricks_instance
+        self._driver_pod = driver_pod
+        self._executor_pod = executor_pod
 
     def with_overrides(
         self,
@@ -71,6 +76,8 @@ def with_overrides(
             hadoop_conf=new_hadoop_conf,
             databricks_conf=new_databricks_conf,
             databricks_instance=self.databricks_instance,
+            driver_pod=self.driver_pod,
+            executor_pod=self.executor_pod,
             executor_path=self.executor_path,
         )
 
@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
         """
         return self._databricks_instance
 
+    @property
+    def driver_pod(self) -> Optional[K8sPod]:
+        """
+        Additional pod spec applied to the Spark driver pod.
+        :rtype: Optional[K8sPod]
+        """
+        return self._driver_pod
+
+    @property
+    def executor_pod(self) -> Optional[K8sPod]:
+        """
+        Additional pod spec applied to the Spark executor (worker) pods.
+        :rtype: Optional[K8sPod]
+        """
+        return self._executor_pod
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
             hadoopConf=self.hadoop_conf,
             databricksConf=databricks_conf,
             databricksInstance=self.databricks_instance,
+            driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
+            executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
         )
 
     @classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
             executor_path=pb2_object.executorPath,
             databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
             databricks_instance=pb2_object.databricksInstance,
+            driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
+            executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
         )
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 7d2f718617..5d331bcbf8 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -13,6 +13,7 @@
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
+from flytekit.models.task import K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -31,12 +32,16 @@ class Spark(object):
         hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
         executor_path: Python binary executable to use for PySpark in driver and executor.
         applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
+        driver_pod: Additional pod spec (K8sPod) applied to the Spark driver pod
+        executor_pod: Additional pod spec (K8sPod) applied to the Spark executor pods
     """
 
     spark_conf: Optional[Dict[str, str]] = None
     hadoop_conf: Optional[Dict[str, str]] = None
     executor_path: Optional[str] = None
     applications_path: Optional[str] = None
+    driver_pod: Optional[K8sPod] = None
+    executor_pod: Optional[K8sPod] = None
 
     def __post_init__(self):
         if self.spark_conf is None:
@@ -168,6 +173,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             executor_path=self._default_executor_path or settings.python_interpreter,
             main_class="",
             spark_type=SparkType.PYTHON,
+            driver_pod=self.task_config.driver_pod,
+            executor_pod=self.task_config.executor_pod,
         )
         if isinstance(self.task_config, (Databricks, DatabricksV2)):
             cfg = cast(DatabricksV2, self.task_config)
diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py
index 7ce5f14ebf..77f823e363 100644
--- a/plugins/flytekit-spark/tests/test_spark_task.py
+++ b/plugins/flytekit-spark/tests/test_spark_task.py
@@ -5,15 +5,39 @@
 import pyspark
 import pytest
+from google.protobuf.json_format import MessageToDict
+from flytekit import PodTemplate
 from flytekit.core import context_manager
 from flytekitplugins.spark import Spark
 from flytekitplugins.spark.task import Databricks, new_spark_session
 from pyspark.sql import SparkSession
 
 import flytekit
-from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
-from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
-from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
+from flytekit import (
+    StructuredDataset,
+    StructuredDatasetTransformerEngine,
+    task,
+    ImageSpec,
+)
+from flytekit.configuration import (
+    Image,
+    ImageConfig,
+    SerializationSettings,
+    FastSerializationSettings,
+    DefaultImages,
+)
+from flytekit.core.context_manager import (
+    ExecutionParameters,
+    FlyteContextManager,
+    ExecutionState,
+)
+from flytekit.models.task import K8sPod
+from kubernetes.client.models import (
+    V1Container,
+    V1PodSpec,
+    V1Toleration,
+    V1EnvVar,
+)
 
 
 @pytest.fixture(scope="function")
@@ -68,7 +92,10 @@ def my_spark(a: str) -> int:
     retrieved_settings = my_spark.get_custom(settings)
     assert retrieved_settings["sparkConf"] == {"spark": "1"}
     assert retrieved_settings["executorPath"] == "/usr/bin/python3"
-    assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
 
     pb = ExecutionParameters.new_builder()
     pb.working_dir = "/tmp"
@@ -121,11 +148,13 @@ def test_to_html():
     df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
     sd = StructuredDataset(dataframe=df)
     tf = StructuredDatasetTransformerEngine()
-    output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
+    output = tf.to_html(
+        FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
+    )
     assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output
 
 
-@mock.patch('pyspark.context.SparkContext.addPyFile')
+@mock.patch("pyspark.context.SparkContext.addPyFile")
 def test_spark_addPyFile(mock_add_pyfile):
     @task(
         task_config=Spark(
@@ -151,8 +180,11 @@ def my_spark(a: int) -> int:
     ctx = context_manager.FlyteContextManager.current_context()
     with context_manager.FlyteContextManager.with_context(
-        ctx.with_execution_state(
-            ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
+        ctx.with_execution_state(
+            ctx.new_execution_state().with_params(
+                mode=ExecutionState.Mode.TASK_EXECUTION
+            )
+        ).with_serialization_settings(serialization_settings)
     ) as new_ctx:
         my_spark.pre_execute(new_ctx.user_space_params)
         mock_add_pyfile.assert_called_once()
@@ -173,7 +205,10 @@ def spark1(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark1.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark1._default_executor_path == "/usr/bin/python3"
     assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"
@@ -185,6 +220,105 @@ def spark2(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark2.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark2._default_executor_path == "/usr/bin/python3"
     assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"
+
+
+def test_spark_driver_executor_podSpec():
+    custom_image = ImageSpec(
+        registry="ghcr.io/flyteorg",
+        packages=["flytekitplugins-spark"],
+    )
+
+    driver_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[V1EnvVar(name="x/custom-driver", value="driver")],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-driver",
+                operator="Equal",
+                value="foo-driver",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    executor_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[V1EnvVar(name="x/custom-executor", value="executor")],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-executor",
+                operator="Equal",
+                value="foo-executor",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    @task(
+        task_config=Spark(
+            spark_conf={"spark.driver.memory": "1000M"},
+            driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
+            executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()),
+        ),
+        container_image=custom_image,
+        pod_template=PodTemplate(primary_container_name="primary"),
+    )
+    def my_spark(a: str) -> int:
+        session = flytekit.current_context().spark_session
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return 10
+
+    assert my_spark.task_config is not None
+    assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}
+
+    default_img = Image(name="default", fqn="test", tag="tag")
+    settings = SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env={"FOO": "baz"},
+        image_config=ImageConfig(default_image=default_img, images=[default_img]),
+    )
+
+    retrieved_settings = my_spark.get_custom(settings)
+    assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
+    assert retrieved_settings["executorPath"] == "/usr/bin/python3"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
+    assert retrieved_settings["driverPod"] == MessageToDict(
+        K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl()
+    )
+    assert retrieved_settings["executorPod"] == MessageToDict(
+        K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl()
+    )
+
+    pb = ExecutionParameters.new_builder()
+    pb.working_dir = "/tmp"
+    pb.execution_id = "ex:local:local:local"
+    p = pb.build()
+    new_p = my_spark.pre_execute(p)
+    assert new_p is not None
+    assert new_p.has_attr("SPARK_SESSION")
+
+    assert my_spark.sess is not None
+    configs = my_spark.sess.sparkContext.getConf().getAll()
+    assert ("spark.driver.memory", "1000M") in configs
+    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs
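
Usage sketch (illustrative, not part of the diff): the example below shows how
the new driver_pod/executor_pod fields are meant to be used from user code. It
assumes a flytekit build that includes this patch plus the kubernetes Python
client; the "dedicated" toleration key/values and the task body are
hypothetical placeholders, not anything prescribed by the patch.

    import flytekit
    from flytekit import task
    from flytekit.models.task import K8sPod
    from flytekitplugins.spark import Spark
    from kubernetes.client.models import V1PodSpec, V1Toleration

    # Example of a constraint that spark_conf alone cannot express:
    # schedule the driver and executors onto dedicated, tainted node pools.
    driver_pod = K8sPod(
        pod_spec=V1PodSpec(
            containers=[],
            tolerations=[
                V1Toleration(
                    key="dedicated",
                    operator="Equal",
                    value="spark-driver",
                    effect="NoSchedule",
                )
            ],
        ).to_dict()
    )
    executor_pod = K8sPod(
        pod_spec=V1PodSpec(
            containers=[],
            tolerations=[
                V1Toleration(
                    key="dedicated",
                    operator="Equal",
                    value="spark-executor",
                    effect="NoSchedule",
                )
            ],
        ).to_dict()
    )

    @task(
        task_config=Spark(
            spark_conf={"spark.driver.memory": "1000M"},
            driver_pod=driver_pod,
            executor_pod=executor_pod,
        ),
    )
    def my_spark_task(n: int) -> float:
        sess = flytekit.current_context().spark_session
        return sess.sparkContext.parallelize(range(1, n + 1)).sum() / n

At serialization time the two specs flow through SparkJob.to_flyte_idl() into
the task's custom field as driverPod/executorPod, which is what
test_spark_driver_executor_podSpec asserts; local executions ignore them, and
on a cluster it is up to the backend Spark plugin to apply them to the
generated driver and executor pods.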