Skip to content

Commit

Permalink
Test async dag
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Jan 24, 2025
1 parent 859f3ad commit 379d997
Showing 1 changed file with 41 additions and 6 deletions.
47 changes: 41 additions & 6 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from typing import Any, Sequence

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
Expand All @@ -8,10 +9,12 @@
from cosmos.config import ProfileConfig
from cosmos.constants import BIGQUERY_PROFILE_TYPE
from cosmos.exceptions import CosmosValueError
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCloneLocalOperator,
DbtCompileLocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
Expand Down Expand Up @@ -57,7 +60,13 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO

class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator, DbtRunLocalOperator): # type: ignore

template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("full_refresh", "project_dir", "location") # type: ignore[operator]
template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ( # type: ignore[operator]
"full_refresh",
"project_dir",
"gcp_project",
"dataset",
"location",
)

def __init__( # type: ignore
self,
Expand All @@ -84,15 +93,41 @@ def __init__( # type: ignore
self.location = location
self.configuration = configuration or {}
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore

super().__init__(
project_dir=self.project_dir,
profile_config=self.profile_config,
profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.
# We need to pop them.
async_op_kwargs = {}
cosmos_op_kwargs = {}
non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())

for arg_key, arg_value in kwargs.items():
if arg_key == "task_id":
async_op_kwargs[arg_key] = arg_value
cosmos_op_kwargs[arg_key] = arg_value
elif arg_key not in non_async_args:
async_op_kwargs[arg_key] = arg_value
else:
cosmos_op_kwargs[arg_key] = arg_value

# The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
BigQueryInsertJobOperator.__init__(
self,
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
location=self.location,
deferrable=True,
**kwargs,
**async_op_kwargs,
)

DbtRunLocalOperator.__init__(
self,
project_dir=self.project_dir,
profile_config=self.profile_config,
**cosmos_op_kwargs,
)
self.async_context = extra_context or {}
self.async_context["profile_type"] = self.profile_type
Expand Down

0 comments on commit 379d997

Please sign in to comment.