diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 5d331bcbf8..3f1cbefb3b 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,10 +10,12 @@
 from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
 from flytekit.configuration import DefaultImages, SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters
+from flytekit.core.pod_template import PodTemplate
+from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
-from flytekit.models.task import K8sPod
+from flytekit.models.task import K8sObjectMetadata, K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -40,8 +42,8 @@ class Spark(object):
     hadoop_conf: Optional[Dict[str, str]] = None
     executor_path: Optional[str] = None
     applications_path: Optional[str] = None
-    driver_pod: Optional[K8sPod] = None
-    executor_pod: Optional[K8sPod] = None
+    driver_pod: Optional[PodTemplate] = None
+    executor_pod: Optional[PodTemplate] = None
 
     def __post_init__(self):
         if self.spark_conf is None:
@@ -173,8 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             executor_path=self._default_executor_path or settings.python_interpreter,
             main_class="",
             spark_type=SparkType.PYTHON,
-            driver_pod=self.task_config.driver_pod,
-            executor_pod=self.task_config.executor_pod,
+            driver_pod=self.to_k8s_pod(self.task_config.driver_pod, settings),
+            executor_pod=self.to_k8s_pod(self.task_config.executor_pod, settings),
         )
         if isinstance(self.task_config, (Databricks, DatabricksV2)):
             cfg = cast(DatabricksV2, self.task_config)
@@ -183,6 +185,26 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
 
         return MessageToDict(job.to_flyte_idl())
 
+    def to_k8s_pod(self, pod_template: Optional[PodTemplate], settings: SerializationSettings) -> Optional[K8sPod]:
+        """
+        Convert a ``PodTemplate`` into the ``K8sPod`` model attached to the Spark job.
+
+        Args:
+            pod_template: Driver or executor pod template from the task config, or ``None`` when unset.
+            settings: Serialization settings used to build this task's container and serialize the pod spec.
+
+        Returns:
+            A ``K8sPod`` holding the serialized pod spec together with the template's labels and
+            annotations, or ``None`` when no pod template was provided.
+        """
+        if pod_template is None:
+            return None
+
+        return K8sPod(
+            pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings),
+            metadata=K8sObjectMetadata(
+                labels=pod_template.labels,
+                annotations=pod_template.annotations,
+            ),
+        )
+
     def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         import pyspark as _pyspark
 