Skip to content

Commit

Permalink
Merge branch 'master' into rename-ephemeral-storage-to-disk
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangJiaWei1103 authored Feb 9, 2025
2 parents 2add341 + 1eb6743 commit ed30790
Show file tree
Hide file tree
Showing 46 changed files with 958 additions and 170 deletions.
4 changes: 2 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ filelock==3.14.0
# via
# snowflake-connector-python
# virtualenv
flyteidl==1.14.1
flyteidl==1.14.3
# via flytekit
frozenlist==1.4.1
# via
Expand Down Expand Up @@ -491,7 +491,7 @@ six==1.16.0
# isodate
# kubernetes
# python-dateutil
snowflake-connector-python==3.12.3
snowflake-connector-python==3.13.1
# via -r dev-requirements.in
sortedcontainers==2.4.0
# via
Expand Down
2 changes: 1 addition & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _dispatch_execute(
logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")

if task_def is not None and not getattr(task_def, "disable_deck", True):
_output_deck(task_def.name.split(".")[-1], ctx.user_space_params)
_output_deck(task_name=task_def.name.split(".")[-1], new_user_params=ctx.user_space_params)

logger.debug("Finished _dispatch_execute")

Expand Down
14 changes: 10 additions & 4 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""

def authenticator_factory():
return get_proxy_authenticator(cfg)

if cfg.proxy_command:
proxy_authenticator = get_proxy_authenticator(cfg)
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))
else:
return in_channel

Expand All @@ -137,8 +140,11 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))

def authenticator_factory():
return get_authenticator(cfg, RemoteClientConfigStore(in_channel))

return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))


def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel:
Expand Down
16 changes: 16 additions & 0 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient
from flytekit.models import common as _common
from flytekit.models import domain as _domain
from flytekit.models import execution as _execution
from flytekit.models import filters as _filters
from flytekit.models import launch_plan as _launch_plan
Expand Down Expand Up @@ -896,6 +897,21 @@ def list_projects_paginated(self, limit=100, token=None, filters=None, sort_by=N
str(projects.token),
)

####################################################################################################################
#
# Domain Endpoints
#
####################################################################################################################

def get_domains(self):
"""
This returns a list of domains.
:rtype: list[flytekit.models.Domain]
"""
domains = super(SynchronousFlyteClient, self).get_domains()
return [_domain.Domain.from_flyte_idl(domain) for domain in domains.domains]

####################################################################################################################
#
# Matching Attributes Endpoints
Expand Down
17 changes: 12 additions & 5 deletions flytekit/clients/grpc_utils/auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamCli
is needed.
"""

def __init__(self, authenticator: Authenticator):
self._authenticator = authenticator
def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
self._get_authenticator = get_authenticator
self._authenticator = None

@property
def authenticator(self) -> Authenticator:
if self._authenticator is None:
self._authenticator = self._get_authenticator()
return self._authenticator

def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
"""
Returns new ClientCallDetails with metadata added.
"""
metadata = client_call_details.metadata
auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata()
auth_metadata = self.authenticator.fetch_grpc_call_auth_metadata()
if auth_metadata:
metadata = []
if client_call_details.metadata:
Expand Down Expand Up @@ -64,7 +71,7 @@ def intercept_unary_unary(
if not hasattr(e, "code"):
raise e
if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return fut
Expand All @@ -76,7 +83,7 @@ def intercept_unary_stream(self, continuation, client_call_details, request):
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
c: grpc.Call = continuation(updated_call_details, request)
if c.code() == grpc.StatusCode.UNAUTHENTICATED:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return c
17 changes: 16 additions & 1 deletion flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing

import grpc
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.project_pb2 import GetDomainRequest, ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
Expand Down Expand Up @@ -520,6 +520,21 @@ def update_project(self, project):
"""
return self._stub.UpdateProject(project, metadata=self._metadata)

####################################################################################################################
#
# Domain Endpoints
#
####################################################################################################################

def get_domains(self):
"""
This will return a list of domains registered with the Flyte Admin Service
:param flyteidl.admin.project_pb2.GetDomainRequest get_domain_request:
:rtype: flyteidl.admin.project_pb2.GetDomainsResponse
"""
get_domain_request = GetDomainRequest()
return self._stub.GetDomains(get_domain_request, metadata=self._metadata)

####################################################################################################################
#
# Matching Attributes Endpoints
Expand Down
46 changes: 30 additions & 16 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,18 @@ class TaskMetadata(object):
See the :std:ref:`IDL <idl:protos/docs/core/core:taskmetadata>` for the protobuf definition.
Args:
cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching <cookbook:caching>`
cache_serialize (bool): Indicates if identical (ie. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching <cookbook:caching>`
cache_version (str): Version to be used for the cached value
cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache
interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with
lower QoS guarantees that can include pre-emption. This can reduce the monetary cost executions incur at the
cost of performance penalties due to potential interruptions
deprecated (str): Can be used to provide a warning message for deprecated task. Absence or empty str indicates
that the task is active and not deprecated
Attributes:
cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching <cookbook:caching>`.
cache_serialize (bool): Indicates if identical (i.e. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching <cookbook:caching>`.
cache_version (str): Version to be used for the cached value.
cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache.
interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees that can include pre-emption.
deprecated (str): Can be used to provide a warning message for a deprecated task. An absence or empty string indicates that the task is active and not deprecated.
retries (int): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times.
timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task
should be executed for. The execution will be terminated if the runtime exceeds the given timeout
(approximately)
pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task.
timeout (Optional[Union[datetime.timedelta, int]]): The maximum duration for which one execution of this task should run. The execution will be terminated if the runtime exceeds this timeout.
pod_template_name (Optional[str]): The name of an existing PodTemplate resource in the cluster which will be used for this task.
generates_deck (bool): Indicates whether the task will generate a Deck URI.
is_eager (bool): Indicates whether the task should be treated as eager.
"""

cache: bool = False
Expand All @@ -141,6 +138,7 @@ class TaskMetadata(object):
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
generates_deck: bool = False
is_eager: bool = False

def __post_init__(self):
Expand Down Expand Up @@ -179,6 +177,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
discovery_version=self.cache_version,
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
generates_deck=self.generates_deck,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
is_eager=self.is_eager,
Expand Down Expand Up @@ -720,11 +719,15 @@ def dispatch_execute(
may be none
* ``DynamicJobSpec`` is returned when a dynamic workflow is executed
"""
if DeckField.TIMELINE.value in self.deck_fields and ctx.user_space_params is not None:
ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck)

# Invoked before the task is executed
new_user_params = self.pre_execute(ctx.user_space_params)

if self.enable_deck and ctx.user_space_params is not None:
if DeckField.TIMELINE.value in self.deck_fields:
ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck)
new_user_params = ctx.user_space_params.with_enable_deck(enable_deck=True).build()

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(
Expand Down Expand Up @@ -827,8 +830,19 @@ def disable_deck(self) -> bool:
"""
If true, this task will not output deck html file
"""
warnings.warn(
"`disable_deck` is deprecated and will be removed in the future.\n" "Please use `enable_deck` instead.",
DeprecationWarning,
)
return self._disable_deck

@property
def enable_deck(self) -> bool:
"""
If true, this task will output deck html file
"""
return not self._disable_deck

@property
def deck_fields(self) -> List[DeckField]:
"""
Expand Down
17 changes: 17 additions & 0 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Builder(object):
logging: Optional[_logging.Logger] = None
task_id: typing.Optional[_identifier.Identifier] = None
output_metadata_prefix: Optional[str] = None
enable_deck: bool = False

def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.stats = current.stats if current else None
Expand All @@ -107,6 +108,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.raw_output_prefix = current.raw_output_prefix if current else None
self.task_id = current.task_id if current else None
self.output_metadata_prefix = current.output_metadata_prefix if current else None
self.enable_deck = current.enable_deck if current else False

def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:
self.attrs[key] = v
Expand All @@ -126,6 +128,7 @@ def build(self) -> ExecutionParameters:
raw_output_prefix=self.raw_output_prefix,
task_id=self.task_id,
output_metadata_prefix=self.output_metadata_prefix,
enable_deck=self.enable_deck,
**self.attrs,
)

Expand All @@ -147,6 +150,11 @@ def with_task_sandbox(self) -> Builder:
b.working_dir = task_sandbox_dir
return b

def with_enable_deck(self, enable_deck: bool) -> Builder:
b = self.new_builder(self)
b.enable_deck = enable_deck
return b

def builder(self) -> Builder:
return ExecutionParameters.Builder(current=self)

Expand All @@ -162,6 +170,7 @@ def __init__(
checkpoint=None,
decks=None,
task_id: typing.Optional[_identifier.Identifier] = None,
enable_deck: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -190,6 +199,7 @@ def __init__(
self._decks = decks
self._task_id = task_id
self._timeline_deck = None
self._enable_deck = enable_deck

@property
def stats(self) -> taggable.TaggableStats:
Expand Down Expand Up @@ -298,6 +308,13 @@ def timeline_deck(self) -> "TimeLineDeck": # type: ignore
self._timeline_deck = time_line_deck
return time_line_deck

@property
def enable_deck(self) -> bool:
"""
Returns whether deck is enabled or not
"""
return self._enable_deck

def __getattr__(self, attr_name: str) -> typing.Any:
"""
This houses certain task specific context. For example in Spark, it houses the SparkSession, etc
Expand Down
7 changes: 7 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from flyteidl.core import tasks_pb2

from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None
self._container_image: typing.Optional[str] = None
self._pod_template: typing.Optional[PodTemplate] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -191,6 +193,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
pod_template: Optional[PodTemplate] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -241,6 +244,10 @@ def with_overrides(

self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

if pod_template is not None:
assert_not_promise(pod_template, "podtemplate")
self._pod_template = pod_template

return self


Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def convert_resources_to_resource_model(


def pod_spec_from_resources(
k8s_pod_name: str,
primary_container_name: Optional[str] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> dict[str, Any]:
) -> V1PodSpec:
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
if resources is None:
return None
Expand All @@ -157,10 +157,10 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
name=primary_container_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
Expand All @@ -169,4 +169,4 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
]
)

return k8s_pod.to_dict()
return pod_spec
Loading

0 comments on commit ed30790

Please sign in to comment.