Skip to content

Commit

Permalink
feat: convert driver/exec podTemplate to k8spod
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Jan 18, 2025
1 parent 8cc081d commit 7793398
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.core.pod_template import PodTemplate
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sPod
from flytekit.models.task import K8sPod, K8sObjectMetadata

from .models import SparkJob, SparkType

Expand All @@ -40,8 +42,8 @@ class Spark(object):
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[K8sPod] = None
executor_pod: Optional[K8sPod] = None
driver_pod: Optional[PodTemplate] = None
executor_pod: Optional[PodTemplate] = None

def __post_init__(self):
if self.spark_conf is None:
Expand Down Expand Up @@ -173,8 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.task_config.driver_pod,
executor_pod=self.task_config.executor_pod,
driver_pod=self.to_k8s_pod(self.task_config.driver_pod, settings),
executor_pod=self.to_k8s_pod(self.task_config.executor_pod, settings),
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
Expand All @@ -183,6 +185,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:

return MessageToDict(job.to_flyte_idl())

def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSettings) -> K8sPod | None:
"""
Convert the podTemplate to K8sPod
"""
if pod_template is None:
return None

return K8sPod(
pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings),
metadata=K8sObjectMetadata(
labels=pod_template.labels,
annotations=pod_template.annotations,
),
)


def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
import pyspark as _pyspark

Expand Down

0 comments on commit 7793398

Please sign in to comment.