From 8cc081ddea5936e3007041babc5d4e1295589397 Mon Sep 17 00:00:00 2001
From: machichima
Date: Fri, 20 Dec 2024 23:09:43 +0800
Subject: [PATCH 1/7] feat: add driver/executor pod in Spark

Signed-off-by: machichima
---
 .../flytekitplugins/spark/models.py           | 27 +++
 .../flytekitplugins/spark/task.py             |  7 +
 .../flytekit-spark/tests/test_spark_task.py   | 155 ++++++++++++++++--
 3 files changed, 179 insertions(+), 10 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py
index e74a9fbe3f..1f185609f4 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/models.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@
 
 from flytekit.exceptions import user as _user_exceptions
 from flytekit.models import common as _common
+from flytekit.models.task import K8sPod
 
 
 class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
         executor_path: str,
         databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
         databricks_instance: Optional[str] = None,
+        driver_pod: Optional[K8sPod] = None,
+        executor_pod: Optional[K8sPod] = None,
     ):
         """
         This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
             databricks_conf = {}
         self._databricks_conf = databricks_conf
         self._databricks_instance = databricks_instance
+        self._driver_pod = driver_pod
+        self._executor_pod = executor_pod
 
     def with_overrides(
         self,
@@ -71,6 +76,8 @@ def with_overrides(
             hadoop_conf=new_hadoop_conf,
             databricks_conf=new_databricks_conf,
             databricks_instance=self.databricks_instance,
+            driver_pod=self.driver_pod,
+            executor_pod=self.executor_pod,
             executor_path=self.executor_path,
         )
 
@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
         """
         return self._databricks_instance
 
+    @property
+    def driver_pod(self) -> K8sPod:
+        """
+        Additional pod spec for the driver pod.
+        :rtype: K8sPod
+        """
+        return self._driver_pod
+
+    @property
+    def executor_pod(self) -> K8sPod:
+        """
+        Additional pod spec for the worker node pods.
+        :rtype: K8sPod
+        """
+        return self._executor_pod
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
             hadoopConf=self.hadoop_conf,
             databricksConf=databricks_conf,
             databricksInstance=self.databricks_instance,
+            driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
+            executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
        )

     @classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
             executor_path=pb2_object.executorPath,
             databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
             databricks_instance=pb2_object.databricksInstance,
+            driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
+            executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
         )
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 7d2f718617..5d331bcbf8 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -13,6 +13,7 @@
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
+from flytekit.models.task import K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -31,12 +32,16 @@ class Spark(object):
         hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
         executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute. + driver_pod: K8sPod for Spark driver pod + executor_pod: K8sPod for Spark executor pod """ spark_conf: Optional[Dict[str, str]] = None hadoop_conf: Optional[Dict[str, str]] = None executor_path: Optional[str] = None applications_path: Optional[str] = None + driver_pod: Optional[K8sPod] = None + executor_pod: Optional[K8sPod] = None def __post_init__(self): if self.spark_conf is None: @@ -168,6 +173,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, + driver_pod=self.task_config.driver_pod, + executor_pod=self.task_config.executor_pod, ) if isinstance(self.task_config, (Databricks, DatabricksV2)): cfg = cast(DatabricksV2, self.task_config) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 7ce5f14ebf..77f823e363 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -5,15 +5,39 @@ import pyspark import pytest +from google.protobuf.json_format import MessageToDict +from flytekit import PodTemplate from flytekit.core import context_manager from flytekitplugins.spark import Spark from flytekitplugins.spark.task import Databricks, new_spark_session from pyspark.sql import SparkSession import flytekit -from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec -from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages -from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState +from flytekit import ( + StructuredDataset, + StructuredDatasetTransformerEngine, + task, + ImageSpec, +) +from flytekit.configuration import ( + Image, + ImageConfig, + SerializationSettings, + FastSerializationSettings, + DefaultImages, +) +from flytekit.core.context_manager import ( + ExecutionParameters, + FlyteContextManager, + ExecutionState, +) +from flytekit.models.task import K8sPod +from kubernetes.client.models import ( + V1Container, + V1PodSpec, + V1Toleration, + V1EnvVar, +) @pytest.fixture(scope="function") @@ -68,7 +92,10 @@ def my_spark(a: str) -> int: retrieved_settings = my_spark.get_custom(settings) assert retrieved_settings["sparkConf"] == {"spark": "1"} assert retrieved_settings["executorPath"] == "/usr/bin/python3" - assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py" + assert ( + retrieved_settings["mainApplicationFile"] + == "local:///usr/local/bin/entrypoint.py" + ) pb = ExecutionParameters.new_builder() pb.working_dir = "/tmp" @@ -121,11 +148,13 @@ def test_to_html(): df = spark.createDataFrame([("Bob", 10)], ["name", "age"]) sd = StructuredDataset(dataframe=df) tf = StructuredDatasetTransformerEngine() - output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame) + output = tf.to_html( + FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame + ) assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output -@mock.patch('pyspark.context.SparkContext.addPyFile') +@mock.patch("pyspark.context.SparkContext.addPyFile") def test_spark_addPyFile(mock_add_pyfile): @task( task_config=Spark( @@ -151,8 +180,11 @@ def my_spark(a: int) -> int: ctx = context_manager.FlyteContextManager.current_context() with 
context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings) + ctx.with_execution_state( + ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.TASK_EXECUTION + ) + ).with_serialization_settings(serialization_settings) ) as new_ctx: my_spark.pre_execute(new_ctx.user_space_params) mock_add_pyfile.assert_called_once() @@ -173,7 +205,10 @@ def spark1(partitions: int) -> float: print("Starting Spark with Partitions: {}".format(partitions)) return 1.0 - assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + assert ( + spark1.container_image.base_image + == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + ) assert spark1._default_executor_path == "/usr/bin/python3" assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py" @@ -185,6 +220,106 @@ def spark2(partitions: int) -> float: print("Starting Spark with Partitions: {}".format(partitions)) return 1.0 - assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + assert ( + spark2.container_image.base_image + == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + ) assert spark2._default_executor_path == "/usr/bin/python3" assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py" + + +def test_spark_driver_executor_podSpec(): + custom_image = ImageSpec( + registry="ghcr.io/flyteorg", + packages=["flytekitplugins-spark"], + ) + + driver_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[V1EnvVar(name="x/custom-driver", value="driver")], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-driver", + operator="Equal", + value="foo-driver", + effect="NoSchedule", + ), + ], + ) + + executor_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[V1EnvVar(name="x/custom-executor", value="executor")], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-executor", + operator="Equal", + value="foo-executor", + effect="NoSchedule", + ), + ], + ) + + @task( + task_config=Spark( + spark_conf={"spark.driver.memory": "1000M"}, + driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()), + executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()), + ), + # limits=Resources(cpu="50m", mem="2000M"), + container_image=custom_image, + pod_template=PodTemplate(primary_container_name="primary"), + ) + def my_spark(a: str) -> int: + session = flytekit.current_context().spark_session + assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" + return 10 + + assert my_spark.task_config is not None + assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"} + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + retrieved_settings = my_spark.get_custom(settings) + assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"} + assert retrieved_settings["executorPath"] == "/usr/bin/python3" + assert ( + 
retrieved_settings["mainApplicationFile"] + == "local:///usr/local/bin/entrypoint.py" + ) + assert retrieved_settings["driverPod"] == MessageToDict(K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl()) + assert retrieved_settings["executorPod"] == MessageToDict(K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl()) + + pb = ExecutionParameters.new_builder() + pb.working_dir = "/tmp" + pb.execution_id = "ex:local:local:local" + p = pb.build() + new_p = my_spark.pre_execute(p) + assert new_p is not None + assert new_p.has_attr("SPARK_SESSION") + + assert my_spark.sess is not None + configs = my_spark.sess.sparkContext.getConf().getAll() + assert ("spark.driver.memory", "1000M") in configs + assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs From 779339827145c4bf80f65c65be65ea1198c9ed2d Mon Sep 17 00:00:00 2001 From: machichima Date: Sat, 18 Jan 2025 21:03:54 +0800 Subject: [PATCH 2/7] feat: convert driver/exec podTemplate to k8spod Signed-off-by: machichima --- .../flytekitplugins/spark/task.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 5d331bcbf8..3f1cbefb3b 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -10,10 +10,12 @@ from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit +from flytekit.core.pod_template import PodTemplate from flytekit.extend import ExecutionState, TaskPlugins from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec -from flytekit.models.task import K8sPod +from flytekit.models.task import K8sPod, K8sObjectMetadata from .models import SparkJob, SparkType @@ -40,8 +42,8 @@ class Spark(object): hadoop_conf: Optional[Dict[str, str]] = None executor_path: Optional[str] = None applications_path: Optional[str] = None - driver_pod: Optional[K8sPod] = None - executor_pod: Optional[K8sPod] = None + driver_pod: Optional[PodTemplate] = None + executor_pod: Optional[PodTemplate] = None def __post_init__(self): if self.spark_conf is None: @@ -173,8 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, - driver_pod=self.task_config.driver_pod, - executor_pod=self.task_config.executor_pod, + driver_pod=self.to_k8s_pod(self.task_config.driver_pod, settings), + executor_pod=self.to_k8s_pod(self.task_config.executor_pod, settings), ) if isinstance(self.task_config, (Databricks, DatabricksV2)): cfg = cast(DatabricksV2, self.task_config) @@ -183,6 +185,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) + def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSettings) -> K8sPod | None: + """ + Convert the podTemplate to K8sPod + """ + if pod_template is None: + return None + + return K8sPod( + pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings), + metadata=K8sObjectMetadata( + labels=pod_template.labels, + annotations=pod_template.annotations, + ), + ) + + 
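+    # Note: to_k8s_pod() is only called from get_custom() above, i.e. at
+    # serialization time, so the driver/executor pod specs are resolved once
+    # and embedded in the SparkJob proto rather than at task execution time.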
     def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         import pyspark as _pyspark

From b21d1e3a89249a5d6343fc3a61dd5c6b45946eba Mon Sep 17 00:00:00 2001
From: machichima
Date: Sun, 19 Jan 2025 15:52:28 +0800
Subject: [PATCH 3/7] feat: take only the primary container for pyspark

Take only the container whose name matches the primary_container_name
set in the driver/executor PodTemplate.

Signed-off-by: machichima
---
 flytekit/core/utils.py                               | 6 ++++++
 plugins/flytekit-spark/flytekitplugins/spark/task.py | 9 ++++-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py
index 9f1967d2f9..2d6ee73187 100644
--- a/flytekit/core/utils.py
+++ b/flytekit/core/utils.py
@@ -139,6 +139,7 @@ def _serialize_pod_spec(
     pod_template: "PodTemplate",
     primary_container: "task_models.Container",
     settings: SerializationSettings,
+    primary_only: bool = False,
 ) -> Dict[str, Any]:
     # import here to avoid circular import
     from kubernetes.client import ApiClient, V1PodSpec
@@ -169,6 +170,7 @@ def _serialize_pod_spec(
         # with the values given to ContainerTask.
         # The attributes include: image, command, args, resource, and env (env is unioned)
 
+        is_primary = False
         if container.name == cast(PodTemplate, pod_template).primary_container_name:
             if container.image is None:
                 # Copy the image from primary_container only if the image is not specified in the pod spec.
@@ -192,9 +194,13 @@ def _serialize_pod_spec(
             container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + (
                 container.env or []
             )
+            is_primary = True
         else:
             container.image = get_registerable_container_image(container.image, settings.image_config)
 
+        if primary_only and not is_primary:
+            continue
+
         final_containers.append(container)
 
     cast(V1PodSpec, pod_template.pod_spec).containers = final_containers
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 3f1cbefb3b..4eef297181 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,12 +10,12 @@
 from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
 from flytekit.configuration import DefaultImages, SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters
-from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
 from flytekit.core.pod_template import PodTemplate
+from flytekit.core.utils import _serialize_pod_spec
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
-from flytekit.models.task import K8sPod, K8sObjectMetadata
+from flytekit.models.task import K8sObjectMetadata, K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -191,16 +191,15 @@ def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSe
         """
         Convert the podTemplate to K8sPod
         """
         if pod_template is None:
             return None
-
+
         return K8sPod(
-            pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings),
+            pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings, primary_only=True),
             metadata=K8sObjectMetadata(
                 labels=pod_template.labels,
                 annotations=pod_template.annotations,
             ),
         )
-
     def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         import pyspark as _pyspark

From 00d0c3a70a11d6315e4f26d179f993b5aee1e15b Mon Sep 17 00:00:00 2001
From: machichima
Date: Sun, 19 Jan 2025
18:33:54 +0800 Subject: [PATCH 4/7] feat: exclude cmd/args from task podTemplate Exclude those in the podTemplate of spark driver/executor pod Signed-off-by: machichima --- flytekit/core/utils.py | 11 +++++++---- plugins/flytekit-spark/flytekitplugins/spark/task.py | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index 2d6ee73187..f2e6e24bbb 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -139,7 +139,7 @@ def _serialize_pod_spec( pod_template: "PodTemplate", primary_container: "task_models.Container", settings: SerializationSettings, - primary_only: bool = False, + task_type: str = "", ) -> Dict[str, Any]: # import here to avoid circular import from kubernetes.client import ApiClient, V1PodSpec @@ -178,8 +178,10 @@ def _serialize_pod_spec( else: container.image = get_registerable_container_image(container.image, settings.image_config) - container.command = primary_container.command - container.args = primary_container.args + if task_type != "spark": + # for spark driver/executor, do not use the command and args from task podTemplate + container.command = primary_container.command + container.args = primary_container.args limits, requests = {}, {} for resource in primary_container.resources.limits: @@ -198,7 +200,8 @@ def _serialize_pod_spec( else: container.image = get_registerable_container_image(container.image, settings.image_config) - if primary_only and not is_primary: + if task_type == "spark" and not is_primary: + # for spark driver/executor, only take the primary container continue final_containers.append(container) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 4eef297181..7b11cec56b 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -193,7 +193,9 @@ def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSe return None return K8sPod( - pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings, primary_only=True), + pod_spec=_serialize_pod_spec( + pod_template, self._get_container(settings), settings, task_type=self._SPARK_TASK_TYPE + ), metadata=K8sObjectMetadata( labels=pod_template.labels, annotations=pod_template.annotations, From 1b7c1c9cef0e39fde37660ad1663475f49fdcdf0 Mon Sep 17 00:00:00 2001 From: machichima Date: Sun, 19 Jan 2025 18:36:19 +0800 Subject: [PATCH 5/7] test: for custom driver/executor podTemplate Signed-off-by: machichima --- .../flytekit-spark/tests/test_spark_task.py | 127 ++++++++++++++++-- 1 file changed, 117 insertions(+), 10 deletions(-) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 77f823e363..a6dd77ee1f 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -6,7 +6,7 @@ import pytest from google.protobuf.json_format import MessageToDict -from flytekit import PodTemplate +from flytekit import PodTemplate from flytekit.core import context_manager from flytekitplugins.spark import Spark from flytekitplugins.spark.task import Databricks, new_spark_session @@ -31,7 +31,7 @@ FlyteContextManager, ExecutionState, ) -from flytekit.models.task import K8sPod +from flytekit.models.task import K8sObjectMetadata, K8sPod from kubernetes.client.models import ( V1Container, V1PodSpec, @@ -228,6 +228,18 @@ def spark2(partitions: int) -> float: 
assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py" +def clean_dict(d): + """ + Recursively remove keys with None values from dictionaries and lists. + """ + if isinstance(d, dict): + return {k: clean_dict(v) for k, v in d.items() if v is not None} + elif isinstance(d, list): + return [clean_dict(item) for item in d if item is not None] + else: + return d + + def test_spark_driver_executor_podSpec(): custom_image = ImageSpec( registry="ghcr.io/flyteorg", @@ -237,12 +249,17 @@ def test_spark_driver_executor_podSpec(): driver_pod_spec = V1PodSpec( containers=[ V1Container( - name="primary", + name="driver-primary", image="ghcr.io/flyteorg", command=["echo"], args=["wow"], env=[V1EnvVar(name="x/custom-driver", value="driver")], ), + V1Container( + name="not-primary", + command=["echo"], + args=["not_primary"], + ), ], tolerations=[ V1Toleration( @@ -257,12 +274,77 @@ def test_spark_driver_executor_podSpec(): executor_pod_spec = V1PodSpec( containers=[ V1Container( - name="primary", + name="executor-primary", image="ghcr.io/flyteorg", command=["echo"], args=["wow"], env=[V1EnvVar(name="x/custom-executor", value="executor")], ), + V1Container( + name="not-primary", + command=["echo"], + args=["not_primary"], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-executor", + operator="Equal", + value="foo-executor", + effect="NoSchedule", + ), + ], + ) + + driver_pod = PodTemplate( + labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"}, + annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"}, + primary_container_name="driver-primary", + pod_spec=driver_pod_spec, + ) + + executor_pod = PodTemplate( + labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"}, + annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"}, + primary_container_name="executor-primary", + pod_spec=executor_pod_spec, + ) + + expect_driver_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="driver-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[ + V1EnvVar(name="FOO", value="baz"), + V1EnvVar(name="x/custom-driver", value="driver"), + ], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-driver", + operator="Equal", + value="foo-driver", + effect="NoSchedule", + ), + ], + ) + + expect_executor_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="executor-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[ + V1EnvVar(name="FOO", value="baz"), + V1EnvVar(name="x/custom-executor", value="executor"), + ], + ), ], tolerations=[ V1Toleration( @@ -274,13 +356,34 @@ def test_spark_driver_executor_podSpec(): ], ) + driver_pod_spec_dict_remove_None = expect_driver_pod_spec.to_dict() + executor_pod_spec_dict_remove_None = expect_executor_pod_spec.to_dict() + + driver_pod_spec_dict_remove_None = clean_dict(driver_pod_spec_dict_remove_None) + executor_pod_spec_dict_remove_None = clean_dict(executor_pod_spec_dict_remove_None) + + target_driver_k8sPod = K8sPod( + metadata=K8sObjectMetadata( + labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"}, + annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"}, + ), + pod_spec=driver_pod_spec_dict_remove_None, # type: ignore + ) + + target_executor_k8sPod = K8sPod( + metadata=K8sObjectMetadata( + labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"}, + annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"}, + ), + pod_spec=executor_pod_spec_dict_remove_None, # type: ignore + ) + @task( task_config=Spark( spark_conf={"spark.driver.memory": "1000M"}, - driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()), - 
executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()), + driver_pod=driver_pod, + executor_pod=executor_pod, ), - # limits=Resources(cpu="50m", mem="2000M"), container_image=custom_image, pod_template=PodTemplate(primary_container_name="primary"), ) @@ -291,8 +394,8 @@ def my_spark(a: str) -> int: assert my_spark.task_config is not None assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"} - default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( project="project", domain="domain", @@ -308,8 +411,12 @@ def my_spark(a: str) -> int: retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py" ) - assert retrieved_settings["driverPod"] == MessageToDict(K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl()) - assert retrieved_settings["executorPod"] == MessageToDict(K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl()) + assert retrieved_settings["driverPod"] == MessageToDict( + target_driver_k8sPod.to_flyte_idl() + ) + assert retrieved_settings["executorPod"] == MessageToDict( + target_executor_k8sPod.to_flyte_idl() + ) pb = ExecutionParameters.new_builder() pb.working_dir = "/tmp" From 32f6aa97e7a7fbd854975715a8731580e43fadc4 Mon Sep 17 00:00:00 2001 From: machichima Date: Sun, 19 Jan 2025 21:14:12 +0800 Subject: [PATCH 6/7] fix: test and type Signed-off-by: machichima --- .../flytekitplugins/spark/task.py | 6 +++--- .../flytekit-spark/tests/test_spark_task.py | 21 +++++++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 7b11cec56b..942832ba8b 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -175,8 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, - driver_pod=self.to_k8s_pod(self.task_config.driver_pod, settings), - executor_pod=self.to_k8s_pod(self.task_config.executor_pod, settings), + driver_pod=self.to_k8s_pod(settings, self.task_config.driver_pod), + executor_pod=self.to_k8s_pod(settings, self.task_config.executor_pod), ) if isinstance(self.task_config, (Databricks, DatabricksV2)): cfg = cast(DatabricksV2, self.task_config) @@ -185,7 +185,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) - def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSettings) -> K8sPod | None: + def to_k8s_pod(self, settings: SerializationSettings, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: """ Convert the podTemplate to K8sPod """ diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index a6dd77ee1f..ebcba98935 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -40,12 +40,23 @@ ) +# @pytest.fixture(scope="function") +# def reset_spark_session() -> None: +# pyspark.sql.SparkSession.builder.getOrCreate().stop() +# yield +# pyspark.sql.SparkSession.builder.getOrCreate().stop() + + + @pytest.fixture(scope="function") def reset_spark_session() -> None: - pyspark.sql.SparkSession.builder.getOrCreate().stop() + if SparkSession._instantiatedSession: + SparkSession.builder.getOrCreate().stop() + 
SparkSession._instantiatedSession = None yield - pyspark.sql.SparkSession.builder.getOrCreate().stop() - + if SparkSession._instantiatedSession: + SparkSession.builder.getOrCreate().stop() + SparkSession._instantiatedSession = None def test_spark_task(reset_spark_session): databricks_conf = { @@ -240,7 +251,7 @@ def clean_dict(d): return d -def test_spark_driver_executor_podSpec(): +def test_spark_driver_executor_podSpec(reset_spark_session): custom_image = ImageSpec( registry="ghcr.io/flyteorg", packages=["flytekitplugins-spark"], @@ -389,6 +400,8 @@ def test_spark_driver_executor_podSpec(): ) def my_spark(a: str) -> int: session = flytekit.current_context().spark_session + configs = session.sparkContext.getConf().getAll() + assert ("spark.driver.memory", "1000M") in configs assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" return 10 From 167a390e1f99e9f96092134a58978781e2a4e25e Mon Sep 17 00:00:00 2001 From: machichima Date: Sun, 19 Jan 2025 21:42:43 +0800 Subject: [PATCH 7/7] fix: lint and docs Signed-off-by: machichima --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 942832ba8b..40ff840ac9 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -34,8 +34,8 @@ class Spark(object): hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark executor_path: Python binary executable to use for PySpark in driver and executor. applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute. - driver_pod: K8sPod for Spark driver pod - executor_pod: K8sPod for Spark executor pod + driver_pod: PodTemplate for Spark driver pod + executor_pod: PodTemplate for Spark executor pod """ spark_conf: Optional[Dict[str, str]] = None @@ -185,7 +185,9 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) - def to_k8s_pod(self, settings: SerializationSettings, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: + def to_k8s_pod( + self, settings: SerializationSettings, pod_template: Optional[PodTemplate] = None + ) -> Optional[K8sPod]: """ Convert the podTemplate to K8sPod """
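
Taken together, the series lets a Spark task carry separate pod templates for the driver and executor pods. Below is a minimal usage sketch distilled from test_spark_driver_executor_podSpec; the container names, labels, and toleration values are illustrative only, not part of the patches:

```python
import flytekit
from flytekit import PodTemplate, task
from flytekitplugins.spark import Spark
from kubernetes.client.models import V1Container, V1PodSpec, V1Toleration

# Only the container matching primary_container_name is kept (patch 3), and
# its command/args are left for the Spark operator to set instead of being
# copied from the task's container spec (patch 4).
driver_template = PodTemplate(
    primary_container_name="driver-primary",
    labels={"team": "data"},  # illustrative metadata
    pod_spec=V1PodSpec(
        containers=[V1Container(name="driver-primary")],
        tolerations=[
            V1Toleration(
                key="dedicated",
                operator="Equal",
                value="spark-driver",
                effect="NoSchedule",
            ),
        ],
    ),
)

executor_template = PodTemplate(
    primary_container_name="executor-primary",
    pod_spec=V1PodSpec(containers=[V1Container(name="executor-primary")]),
)


@task(
    task_config=Spark(
        spark_conf={"spark.driver.memory": "1000M"},
        driver_pod=driver_template,
        executor_pod=executor_template,
    ),
)
def count_evens(n: int) -> int:
    sess = flytekit.current_context().spark_session
    return sess.sparkContext.parallelize(range(n)).filter(lambda x: x % 2 == 0).count()
```

At registration time, get_custom() converts each PodTemplate via to_k8s_pod() and embeds the result as driverPod/executorPod in the SparkJob proto, which is exactly what the MessageToDict assertions in the test verify.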