Skip to content

Commit

Permalink
AIP-72: Port Registering of Asset Changes to Task SDK on task complet…
Browse files Browse the repository at this point in the history
…ion (apache#45924)
  • Loading branch information
amoghrajesh authored Jan 24, 2025
1 parent d460972 commit bb77ebf
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 52 deletions.
15 changes: 15 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,18 @@ class AssetAliasResponse(BaseModel):

name: str
group: str


class AssetProfile(BaseModel):
"""
Profile of an Asset.
Asset will have name, uri and asset_type defined.
AssetNameRef will have name and asset_type defined.
AssetUriRef will have uri and asset_type defined.
"""

name: str | None = None
uri: str | None = None
asset_type: str
48 changes: 44 additions & 4 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,19 @@
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator
from pydantic import (
AwareDatetime,
Discriminator,
Field,
Tag,
TypeAdapter,
WithJsonSchema,
field_validator,
)

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
Expand Down Expand Up @@ -52,14 +61,41 @@ class TIEnterRunningPayload(BaseModel):


class TITerminalStatePayload(BaseModel):
"""Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED)."""
"""Schema for updating TaskInstance to a terminal state except SUCCESS state."""

state: TerminalTIState
state: Literal[
TerminalTIState.FAILED,
TerminalTIState.SKIPPED,
TerminalTIState.REMOVED,
TerminalTIState.FAIL_WITHOUT_RETRY,
]

end_date: UtcDateTime
"""When the task completed executing"""


class TISuccessStatePayload(BaseModel):
"""Schema for updating TaskInstance to success state."""

state: Annotated[
Literal[TerminalTIState.SUCCESS],
# Specify a default in the schema, but not in code, so Pydantic marks it as required.
WithJsonSchema(
{
"type": "string",
"enum": [TerminalTIState.SUCCESS],
"default": TerminalTIState.SUCCESS,
}
),
]

end_date: UtcDateTime
"""When the task completed executing"""

task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
outlet_events: Annotated[list[Any], Field(default_factory=list)]


class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""

Expand Down Expand Up @@ -123,7 +159,10 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
state = v.get("state")
else:
state = getattr(v, "state", None)
if state in set(TerminalTIState):

if state == TIState.SUCCESS:
return "success"
elif state in set(TerminalTIState):
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
Expand All @@ -137,6 +176,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
TIStateUpdate = Annotated[
Union[
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TISuccessStatePayload, Tag("success")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Annotated[TIRescheduleStatePayload, Tag("up_for_reschedule")],
Expand Down
14 changes: 13 additions & 1 deletion airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TIRescheduleStatePayload,
TIRunContext,
TIStateUpdate,
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.models.dagrun import DagRun as DR
Expand Down Expand Up @@ -226,7 +227,7 @@ def ti_update_state(
)

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude_unset=True)
data = ti_patch_payload.model_dump(exclude={"task_outlets", "outlet_events"}, exclude_unset=True)

query = update(TI).where(TI.id == ti_id_str).values(data)

Expand All @@ -243,6 +244,17 @@ def ti_update_state(
else:
updated_state = State.FAILED
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
updated_state = ti_patch_payload.state
task_instance = session.get(TI, ti_id_str)
TI.register_asset_changes_in_db(
task_instance,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
82 changes: 51 additions & 31 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
Expand Down Expand Up @@ -160,7 +161,7 @@
from airflow.models.dagrun import DagRun
from airflow.sdk.definitions._internal.abstractoperator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
from airflow.sdk.types import RuntimeTaskInstanceProtocol
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup

Expand Down Expand Up @@ -352,7 +353,29 @@ def _run_raw_task(
if not test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
ti._register_asset_changes(events=context["outlet_events"], session=session)
added_alias_to_task_outlet = False
task_outlets = []
outlet_events = []
events = context["outlet_events"]
for obj in ti.task.outlets or []:
# Lineage can have other types of objects besides assets
asset_type = type(obj).__name__
if isinstance(obj, Asset):
task_outlets.append(AssetProfile(name=obj.name, uri=obj.uri, asset_type=asset_type))
outlet_events.append(attrs.asdict(events[obj])) # type: ignore
elif isinstance(obj, AssetNameRef):
task_outlets.append(AssetProfile(name=obj.name, asset_type=asset_type))
outlet_events.append(attrs.asdict(events)) # type: ignore
elif isinstance(obj, AssetUriRef):
task_outlets.append(AssetProfile(uri=obj.uri, asset_type=asset_type))
outlet_events.append(attrs.asdict(events)) # type: ignore
elif isinstance(obj, AssetAlias):
if not added_alias_to_task_outlet:
task_outlets.append(AssetProfile(asset_type=asset_type))
added_alias_to_task_outlet = True
for asset_alias_event in events[obj].asset_alias_events:
outlet_events.append(attrs.asdict(asset_alias_event))
TaskInstance.register_asset_changes_in_db(ti, task_outlets, outlet_events, session=session)

TaskInstance.save_to_db(ti=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
Expand Down Expand Up @@ -2733,49 +2756,46 @@ def _run_raw_task(
session=session,
)

def _register_asset_changes(
self, *, events: OutletEventAccessorsProtocol, session: Session | None = None
) -> None:
if session:
TaskInstance._register_asset_changes_int(ti=self, events=events, session=session)
else:
TaskInstance._register_asset_changes_int(ti=self, events=events)

@staticmethod
@provide_session
def _register_asset_changes_int(
ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION
def register_asset_changes_in_db(
ti: TaskInstance,
task_outlets: list[AssetProfile],
outlet_events: list[Any],
session: Session = NEW_SESSION,
) -> None:
if TYPE_CHECKING:
assert ti.task

# One task only triggers one asset event for each asset with the same extra.
# This tuple[asset uri, extra] to sets alias names mapping is used to find whether
# there're assets with same uri but different extra that we need to emit more than one asset events.
asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set)

asset_name_refs: set[str] = set()
asset_uri_refs: set[str] = set()

for obj in ti.task.outlets or []:
for obj in task_outlets:
ti.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides assets
if isinstance(obj, Asset):
if obj.asset_type == Asset.__name__:
asset_manager.register_asset_change(
task_instance=ti,
asset=obj,
extra=events[obj].extra,
asset=Asset(name=obj.name, uri=obj.uri), # type: ignore
extra=outlet_events[0]["extra"],
session=session,
)
elif isinstance(obj, AssetNameRef):
asset_name_refs.add(obj.name)
elif isinstance(obj, AssetUriRef):
asset_uri_refs.add(obj.uri)
elif isinstance(obj, AssetAlias):
for asset_alias_event in events[obj].asset_alias_events:
asset_alias_name = asset_alias_event.source_alias_name
asset_unique_key = asset_alias_event.dest_asset_key
frozen_extra = frozenset(asset_alias_event.extra.items())
elif obj.asset_type == AssetNameRef.__name__:
asset_name_refs.add(obj.name) # type: ignore
elif obj.asset_type == AssetUriRef.__name__:
asset_uri_refs.add(obj.uri) # type: ignore
elif obj.asset_type == AssetAlias.__name__:
outlet_events = list(
map(
lambda event: {**event, "dest_asset_key": AssetUniqueKey(**event["dest_asset_key"])},
outlet_events,
)
)
for asset_alias_event in outlet_events:
asset_alias_name = asset_alias_event["source_alias_name"]
asset_unique_key = asset_alias_event["dest_asset_key"]
frozen_extra = frozenset(asset_alias_event["extra"].items())
asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)

asset_unique_keys = {key for key, _ in asset_alias_names}
Expand Down Expand Up @@ -2827,7 +2847,7 @@ def _register_asset_changes_int(
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_model,
extra=events[asset_model].extra,
extra=outlet_events[asset_model].extra,
session=session,
)
asset_stmt = select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs), AssetModel.active.has())
Expand All @@ -2836,7 +2856,7 @@ def _register_asset_changes_int(
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_model,
extra=events[asset_model].extra,
extra=outlet_events[asset_model].extra,
session=session,
)

Expand Down
6 changes: 6 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TISuccessStatePayload,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariablePostBody,
Expand Down Expand Up @@ -136,6 +137,11 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events):
"""Tell the API server that this TI has succeeded."""
body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())
Expand Down
27 changes: 26 additions & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@
from pydantic import BaseModel, ConfigDict, Field


class AssetProfile(BaseModel):
"""
Profile of an Asset.
Asset will have name, uri and asset_type defined.
AssetNameRef will have name and asset_type defined.
AssetUriRef will have uri and asset_type defined.
"""

name: Annotated[str | None, Field(title="Name")] = None
uri: Annotated[str | None, Field(title="Uri")] = None
asset_type: Annotated[str, Field(title="Asset Type")]


class AssetResponse(BaseModel):
"""
Asset schema for responses with fields that are needed for Runtime.
Expand Down Expand Up @@ -134,6 +148,17 @@ class TIRescheduleStatePayload(BaseModel):
end_date: Annotated[datetime, Field(title="End Date")]


class TISuccessStatePayload(BaseModel):
"""
Schema for updating TaskInstance to success state.
"""

state: Annotated[Literal["success"] | None, Field(title="State")] = "success"
end_date: Annotated[datetime, Field(title="End Date")]
task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None
outlet_events: Annotated[list | None, Field(title="Outlet Events")] = None


class TITargetStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a target state, excluding terminal and running states.
Expand Down Expand Up @@ -243,7 +268,7 @@ class TIRunContext(BaseModel):

class TITerminalStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).
Schema for updating TaskInstance to a terminal state except SUCCESS state.
"""

state: TerminalTIState
Expand Down
15 changes: 14 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRunContext,
TISuccessStatePayload,
VariableResponse,
XComResponse,
)
Expand Down Expand Up @@ -191,11 +192,22 @@ class TaskState(BaseModel):
- anything else = FAILED
"""

state: TerminalTIState
state: Literal[
TerminalTIState.FAILED,
TerminalTIState.SKIPPED,
TerminalTIState.REMOVED,
TerminalTIState.FAIL_WITHOUT_RETRY,
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


class SucceedTask(TISuccessStatePayload):
"""Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

type: Literal["SucceedTask"] = "SucceedTask"


class DeferTask(TIDeferredStatePayload):
"""Update a task instance state to deferred."""

Expand Down Expand Up @@ -292,6 +304,7 @@ class GetPrevSuccessfulDagRun(BaseModel):

ToSupervisor = Annotated[
Union[
SucceedTask,
DeferTask,
GetAssetByName,
GetAssetByUri,
Expand Down
Loading

0 comments on commit bb77ebf

Please sign in to comment.