[FEAT] add driver/executor pod in Spark #3016

Open · wants to merge 1 commit into master
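This PR adds optional driver_pod and executor_pod fields to the Spark task config and the SparkJob model so a task can attach custom Kubernetes pod specs to the Spark driver and executor pods. A minimal usage sketch (the container name, toleration values, and memory setting are illustrative, not part of this PR):

from flytekit import task
from flytekit.models.task import K8sPod
from flytekitplugins.spark import Spark
from kubernetes.client.models import V1Container, V1PodSpec, V1Toleration

# Illustrative driver pod spec: pin the primary container name and tolerate a dedicated Spark node pool.
driver_pod_spec = V1PodSpec(
    containers=[V1Container(name="primary")],
    tolerations=[
        V1Toleration(key="dedicated", operator="Equal", value="spark", effect="NoSchedule"),
    ],
)

@task(
    task_config=Spark(
        spark_conf={"spark.driver.memory": "1000M"},
        driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
        # Reusing the same spec for executors only to keep the sketch short.
        executor_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
    ),
)
def my_spark_task(a: str) -> int:
    return 10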
27 changes: 27 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
executor_path: str,
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
databricks_instance: Optional[str] = None,
driver_pod: Optional[K8sPod] = None,
executor_pod: Optional[K8sPod] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
databricks_conf = {}
self._databricks_conf = databricks_conf
self._databricks_instance = databricks_instance
self._driver_pod = driver_pod
self._executor_pod = executor_pod

def with_overrides(
self,
@@ -71,6 +76,8 @@ def with_overrides(
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_instance=self.databricks_instance,
driver_pod=self.driver_pod,
executor_pod=self.executor_pod,
executor_path=self.executor_path,
)

@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
"""
return self._databricks_instance

@property
def driver_pod(self) -> K8sPod:
"""
Additional pod spec for the Spark driver pod.
:rtype: K8sPod
"""
return self._driver_pod

@property
def executor_pod(self) -> K8sPod:
"""
Additional pod spec for the Spark executor (worker) pods.
:rtype: K8sPod
"""
return self._executor_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksInstance=self.databricks_instance,
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
)

@classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_instance=pb2_object.databricksInstance,
driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
)
7 changes: 7 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -13,6 +13,7 @@
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sPod

from .models import SparkJob, SparkType

@@ -31,12 +32,16 @@ class Spark(object):
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
driver_pod: Additional pod spec (K8sPod) for the Spark driver pod
executor_pod: Additional pod spec (K8sPod) for the Spark executor pods
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[K8sPod] = None
executor_pod: Optional[K8sPod] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -168,6 +173,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.task_config.driver_pod,
executor_pod=self.task_config.executor_pod,
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
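For reference, the new fields surface in the task's serialized custom payload the same way the other Spark settings do, as the JSON (MessageToDict) form of the K8sPod IDL message. A sketch of inspecting them, assuming the my_spark_task example from the PR description above (the project/domain/image values are placeholders):

from flytekit.configuration import Image, ImageConfig, SerializationSettings

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

# my_spark_task is the Spark task defined in the sketch above.
custom = my_spark_task.get_custom(settings)
print(custom["driverPod"], custom["executorPod"])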
155 changes: 145 additions & 10 deletions plugins/flytekit-spark/tests/test_spark_task.py
@@ -5,15 +5,39 @@
import pyspark
import pytest

from google.protobuf.json_format import MessageToDict
from flytekit import PodTemplate
from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
from flytekit import (
StructuredDataset,
StructuredDatasetTransformerEngine,
task,
ImageSpec,
)
from flytekit.configuration import (
Image,
ImageConfig,
SerializationSettings,
FastSerializationSettings,
DefaultImages,
)
from flytekit.core.context_manager import (
ExecutionParameters,
FlyteContextManager,
ExecutionState,
)
from flytekit.models.task import K8sPod
from kubernetes.client.models import (
V1Container,
V1PodSpec,
V1Toleration,
V1EnvVar,
)


@pytest.fixture(scope="function")
@@ -68,7 +92,10 @@ def my_spark(a: str) -> int:
retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark": "1"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
@@ -121,11 +148,13 @@ def test_to_html():
df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
sd = StructuredDataset(dataframe=df)
tf = StructuredDatasetTransformerEngine()
output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
output = tf.to_html(
FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
)
assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output


@mock.patch('pyspark.context.SparkContext.addPyFile')
@mock.patch("pyspark.context.SparkContext.addPyFile")
def test_spark_addPyFile(mock_add_pyfile):
@task(
task_config=Spark(
@@ -151,8 +180,11 @@ def my_spark(a: int) -> int:

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
ctx.with_execution_state(
ctx.new_execution_state().with_params(
mode=ExecutionState.Mode.TASK_EXECUTION
)
).with_serialization_settings(serialization_settings)
) as new_ctx:
my_spark.pre_execute(new_ctx.user_space_params)
mock_add_pyfile.assert_called_once()
@@ -173,7 +205,10 @@ def spark1(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark1.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark1._default_executor_path == "/usr/bin/python3"
assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"

@@ -185,6 +220,106 @@ def spark2(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark2.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark2._default_executor_path == "/usr/bin/python3"
assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"


def test_spark_driver_executor_podSpec():
custom_image = ImageSpec(
registry="ghcr.io/flyteorg",
packages=["flytekitplugins-spark"],
)

driver_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-driver", value="driver")],
),
],
tolerations=[
V1Toleration(
key="x/custom-driver",
operator="Equal",
value="foo-driver",
effect="NoSchedule",
),
],
)

executor_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-executor", value="executor")],
),
],
tolerations=[
V1Toleration(
key="x/custom-executor",
operator="Equal",
value="foo-executor",
effect="NoSchedule",
),
],
)

@task(
task_config=Spark(
spark_conf={"spark.driver.memory": "1000M"},
driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()),
),
# limits=Resources(cpu="50m", mem="2000M"),
container_image=custom_image,
pod_template=PodTemplate(primary_container_name="primary"),
)
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)
assert retrieved_settings["driverPod"] == MessageToDict(K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl())
assert retrieved_settings["executorPod"] == MessageToDict(K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl())

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
pb.execution_id = "ex:local:local:local"
p = pb.build()
new_p = my_spark.pre_execute(p)
assert new_p is not None
assert new_p.has_attr("SPARK_SESSION")

assert my_spark.sess is not None
configs = my_spark.sess.sparkContext.getConf().getAll()
assert ("spark.driver.memory", "1000M") in configs
assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs