feat: add driver/executor pod in Spark
Signed-off-by: machichima <[email protected]>
machichima committed Dec 20, 2024
1 parent c8f98c5 commit af03383
Showing 3 changed files with 179 additions and 10 deletions.
27 changes: 27 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
executor_path: str,
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
databricks_instance: Optional[str] = None,
driver_pod: Optional[K8sPod] = None,
executor_pod: Optional[K8sPod] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
databricks_conf = {}
self._databricks_conf = databricks_conf
self._databricks_instance = databricks_instance
self._driver_pod = driver_pod
self._executor_pod = executor_pod

def with_overrides(
self,
@@ -71,6 +76,8 @@ def with_overrides(
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_instance=self.databricks_instance,
driver_pod=self.driver_pod,
executor_pod=self.executor_pod,
executor_path=self.executor_path,
)

@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
"""
return self._databricks_instance

@property
def driver_pod(self) -> K8sPod:
"""
Additional pod specs for the driver pod.
:rtype: K8sPod
"""
return self._driver_pod

@property
def executor_pod(self) -> K8sPod:
"""
Additional pod specs for the worker node pods.
:rtype: K8sPod
"""
return self._executor_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksInstance=self.databricks_instance,
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
)

@classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_instance=pb2_object.databricksInstance,
driver_pod=pb2_object.driverPod,
executor_pod=pb2_object.executorPod,
)
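
For context (not part of the diff): the K8sPod model behind the new driver_pod / executor_pod fields wraps a plain pod-spec dict and serializes to protobuf, which is what to_flyte_idl() above forwards into the driverPod / executorPod fields. A minimal sketch of that round trip, mirroring the new test further down; the container values here are placeholders.

from google.protobuf.json_format import MessageToDict
from kubernetes.client.models import V1Container, V1PodSpec

from flytekit.models.task import K8sPod

# Placeholder pod spec; any valid kubernetes pod spec works here.
driver_pod_spec = V1PodSpec(
    containers=[V1Container(name="primary", image="ghcr.io/flyteorg")],
)
driver_pod = K8sPod(pod_spec=driver_pod_spec.to_dict())

# SparkJob.to_flyte_idl() calls driver_pod.to_flyte_idl() when a driver pod is set,
# so this is the structure that lands in the serialized job's driverPod field.
print(MessageToDict(driver_pod.to_flyte_idl()))
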
7 changes: 7 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -13,6 +13,7 @@
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sPod

from .models import SparkJob, SparkType

@@ -31,12 +32,16 @@ class Spark(object):
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
driver_pod: K8sPod with additional pod specs for the Spark driver pod
executor_pod: K8sPod with additional pod specs for the Spark executor pods
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[K8sPod] = None
executor_pod: Optional[K8sPod] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -168,6 +173,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.task_config.driver_pod,
executor_pod=self.task_config.executor_pod,
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
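Taken together with the model changes above, the intended user-facing configuration looks roughly like the sketch below (adapted from the new test that follows; the task name, images, tolerations, and memory values are placeholders, not part of this commit).

from flytekit import task
from flytekit.models.task import K8sPod
from flytekitplugins.spark import Spark
from kubernetes.client.models import V1Container, V1PodSpec, V1Toleration

driver_pod_spec = V1PodSpec(
    containers=[V1Container(name="primary", image="ghcr.io/flyteorg")],
    tolerations=[
        V1Toleration(key="x/custom-driver", operator="Equal", value="foo-driver", effect="NoSchedule"),
    ],
)
executor_pod_spec = V1PodSpec(
    containers=[V1Container(name="primary", image="ghcr.io/flyteorg")],
)

@task(
    task_config=Spark(
        spark_conf={"spark.driver.memory": "1000M"},
        driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
        executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()),
    ),
)
def my_spark_task(a: str) -> int:
    # At serialization time the extra pod specs are carried into the SparkJob
    # custom config as driverPod / executorPod (see get_custom above).
    return 10
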
155 changes: 145 additions & 10 deletions plugins/flytekit-spark/tests/test_spark_task.py
@@ -5,15 +5,39 @@
import pyspark
import pytest

from google.protobuf.json_format import MessageToDict
from flytekit import PodTemplate
from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
from flytekit import (
StructuredDataset,
StructuredDatasetTransformerEngine,
task,
ImageSpec,
)
from flytekit.configuration import (
Image,
ImageConfig,
SerializationSettings,
FastSerializationSettings,
DefaultImages,
)
from flytekit.core.context_manager import (
ExecutionParameters,
FlyteContextManager,
ExecutionState,
)
from flytekit.models.task import K8sPod
from kubernetes.client.models import (
V1Container,
V1PodSpec,
V1Toleration,
V1EnvVar,
)


@pytest.fixture(scope="function")
@@ -68,7 +92,10 @@ def my_spark(a: str) -> int:
retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark": "1"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
@@ -121,11 +148,13 @@ def test_to_html():
df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
sd = StructuredDataset(dataframe=df)
tf = StructuredDatasetTransformerEngine()
output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
output = tf.to_html(
FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
)
assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output


@mock.patch('pyspark.context.SparkContext.addPyFile')
@mock.patch("pyspark.context.SparkContext.addPyFile")
def test_spark_addPyFile(mock_add_pyfile):
@task(
task_config=Spark(
@@ -151,8 +180,11 @@ def my_spark(a: int) -> int:

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
ctx.with_execution_state(
ctx.new_execution_state().with_params(
mode=ExecutionState.Mode.TASK_EXECUTION
)
).with_serialization_settings(serialization_settings)
) as new_ctx:
my_spark.pre_execute(new_ctx.user_space_params)
mock_add_pyfile.assert_called_once()
@@ -173,7 +205,10 @@ def spark1(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark1.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark1._default_executor_path == "/usr/bin/python3"
assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"

@@ -185,6 +220,106 @@ def spark2(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark2.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark2._default_executor_path == "/usr/bin/python3"
assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"


def test_spark_driver_executor_podSpec():
custom_image = ImageSpec(
registry="ghcr.io/flyteorg",
packages=["flytekitplugins-spark"],
)

driver_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-driver", value="driver")],
),
],
tolerations=[
V1Toleration(
key="x/custom-driver",
operator="Equal",
value="foo-driver",
effect="NoSchedule",
),
],
)

executor_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-executor", value="executor")],
),
],
tolerations=[
V1Toleration(
key="x/custom-executor",
operator="Equal",
value="foo-executor",
effect="NoSchedule",
),
],
)

@task(
task_config=Spark(
spark_conf={"spark.driver.memory": "1000M"},
driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()),
),
# limits=Resources(cpu="50m", mem="2000M"),
container_image=custom_image,
pod_template=PodTemplate(primary_container_name="primary"),
)
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)
assert retrieved_settings["driverPod"] == MessageToDict(K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl())
assert retrieved_settings["executorPod"] == MessageToDict(K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl())

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
pb.execution_id = "ex:local:local:local"
p = pb.build()
new_p = my_spark.pre_execute(p)
assert new_p is not None
assert new_p.has_attr("SPARK_SESSION")

assert my_spark.sess is not None
configs = my_spark.sess.sparkContext.getConf().getAll()
assert ("spark.driver.memory", "1000M") in configs
assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs
