diff --git a/cosmos/mocked_dbt_adapters.py b/cosmos/mocked_dbt_adapters.py
index b8d495885..2e6e9bd78 100644
--- a/cosmos/mocked_dbt_adapters.py
+++ b/cosmos/mocked_dbt_adapters.py
@@ -1,7 +1,12 @@
+from __future__ import annotations
+
+from typing import Any
+
 from cosmos.constants import BIGQUERY_PROFILE_TYPE
+from cosmos.exceptions import CosmosValueError
 
 
-def mock_bigquery_adapter() -> None:
+def _mock_bigquery_adapter() -> None:
     from typing import Optional, Tuple
 
     import agate
@@ -17,5 +22,27 @@ def execute(  # type: ignore[no-untyped-def]
 
 
 PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP = {
-    BIGQUERY_PROFILE_TYPE: mock_bigquery_adapter,
+    BIGQUERY_PROFILE_TYPE: _mock_bigquery_adapter,
 }
+
+
+def _associate_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
+    sql = kwargs.get("sql")
+    if not sql:
+        raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")
+    async_op_obj.configuration = {
+        "query": {
+            "query": sql,
+            "useLegacySql": False,
+        }
+    }
+    return async_op_obj
+
+
+PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP = {
+    BIGQUERY_PROFILE_TYPE: _associate_bigquery_async_op_args,
+}
+
+
+def _associate_async_operator_args(async_operator_obj: Any, profile_type: str, **kwargs: Any) -> Any:
+    return PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type](async_operator_obj, **kwargs)
diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index 079ea5625..56056f143 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -57,11 +57,7 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO
 
 
 class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator, DbtRunLocalOperator):  # type: ignore
-    template_fields: Sequence[str] = (
-        "full_refresh",
-        "project_dir",
-        "location",
-    )
+    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("full_refresh", "project_dir", "location")  # type: ignore[operator]
 
     def __init__(  # type: ignore
         self,
@@ -98,18 +94,12 @@ def __init__(  # type: ignore
             deferrable=True,
             **kwargs,
         )
-        self.extra_context = extra_context or {}
-        self.extra_context["profile_type"] = self.profile_type
+        self.async_context = extra_context or {}
+        self.async_context["profile_type"] = self.profile_type
+        self.async_context["async_operator"] = BigQueryInsertJobOperator
 
     def execute(self, context: Context) -> Any | None:
-        sql = self.build_and_run_cmd(context, return_sql=True, sql_context=self.extra_context)
-        self.configuration = {
-            "query": {
-                "query": sql,
-                "useLegacySql": False,
-            }
-        }
-        return super().execute(context)
+        return self.build_and_run_cmd(context, run_as_async=True, async_context=self.async_context)
 
 
 class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator):  # type: ignore
diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py
index 52fb98bac..305a509d7 100644
--- a/cosmos/operators/base.py
+++ b/cosmos/operators/base.py
@@ -258,7 +258,13 @@ def build_cmd(
         return dbt_cmd, env
 
     @abstractmethod
-    def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
+    def build_and_run_cmd(
+        self,
+        context: Context,
+        cmd_flags: list[str],
+        run_as_async: bool = False,
+        async_context: dict[str, Any] | None = None,
+    ) -> Any:
         """Override this method for the operator to execute the dbt command"""
 
     def execute(self, context: Context) -> Any | None:  # type: ignore
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index ab2f1cbc3..129946892 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -67,7 +67,7 @@
     FullOutputSubprocessResult,
 )
 from cosmos.log import get_logger
-from cosmos.mocked_dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP
+from cosmos.mocked_dbt_adapters import PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP, PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP
 from cosmos.operators.base import (
     AbstractDbtBaseOperator,
     DbtBuildMixin,
@@ -411,8 +411,8 @@ def run_command(
         cmd: list[str],
         env: dict[str, str | bytes | os.PathLike[Any]],
         context: Context,
-        return_sql: bool = False,
-        sql_context: dict[str, Any] | None = None,
+        run_as_async: bool = False,
+        async_context: dict[str, Any] | None = None,
     ) -> FullOutputSubprocessResult | dbtRunnerResult | str:
         """
         Copies the dbt project to a temporary directory and runs the command.
@@ -465,8 +465,10 @@ def run_command(
             full_cmd = cmd + flags
 
             self.log.debug("Using environment variables keys: %s", env.keys())
-            if return_sql and sql_context:
-                profile_type = sql_context["profile_type"]
+            if run_as_async:
+                if not async_context:
+                    raise CosmosValueError("async_context is necessary for running the model asynchronously.")
+                profile_type = async_context["profile_type"]
                 mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP.get(profile_type)
                 if not mock_adapter_callable:
                     raise CosmosValueError(
@@ -505,9 +507,10 @@ def run_command(
                 self.callback(tmp_project_dir, **self.callback_args)
 
             self.handle_exception(result)
 
-            if return_sql and sql_context:
-                sql_content = self._read_run_sql_from_target_dir(tmp_project_dir, sql_context)
-                return sql_content
+            if run_as_async and async_context:
+                sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
+                PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type](self, sql=sql)
+                async_context["async_operator"].execute(self, context)
 
             return result
 
@@ -651,12 +654,14 @@ def build_and_run_cmd(
         self,
         context: Context,
         cmd_flags: list[str] | None = None,
-        return_sql: bool = False,
-        sql_context: dict[str, Any] | None = None,
+        run_as_async: bool = False,
+        async_context: dict[str, Any] | None = None,
     ) -> FullOutputSubprocessResult | dbtRunnerResult:
         dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
         dbt_cmd = dbt_cmd or []
-        result = self.run_command(cmd=dbt_cmd, env=env, context=context, return_sql=return_sql, sql_context=sql_context)
+        result = self.run_command(
+            cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context
+        )
         return result
 
     def on_kill(self) -> None:
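For context, a minimal sketch (not part of the change) of the new per-profile "associate args" hook introduced above. The `SimpleNamespace` stand-in and the literal `SELECT 1` query are illustrative assumptions; in the real flow, `run_command` passes the operator instance itself (`self`) and the SQL it reads back from the dbt target directory.

```python
# Hypothetical usage of _associate_bigquery_async_op_args via the
# PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP added in this diff.
from types import SimpleNamespace

from cosmos.constants import BIGQUERY_PROFILE_TYPE
from cosmos.mocked_dbt_adapters import PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP

# Stand-in for the async operator instance (run_command passes `self`).
async_op = SimpleNamespace()

associate = PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[BIGQUERY_PROFILE_TYPE]
associate(async_op, sql="SELECT 1")  # omitting `sql` raises CosmosValueError

# The hook shapes the BigQuery job configuration in place on the operator.
assert async_op.configuration == {
    "query": {"query": "SELECT 1", "useLegacySql": False}
}
```

With this in place, `DbtRunAirflowAsyncOperator.execute` delegates entirely to `build_and_run_cmd(run_as_async=True, async_context=...)`: `run_command` mocks the adapter, reads the compiled SQL, lets the hook populate the job configuration, and then calls `async_context["async_operator"].execute(self, context)` to submit the deferrable BigQuery job.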