diff --git a/benchmarks/sql/bench.py b/benchmarks/sql/bench.py index 0257b52d..d8007be9 100644 --- a/benchmarks/sql/bench.py +++ b/benchmarks/sql/bench.py @@ -28,6 +28,7 @@ ) from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline from bench.utils import save +from hydra.core.hydra_config import HydraConfig from neptune.utils import stringify_unsupported from omegaconf import DictConfig @@ -120,7 +121,7 @@ async def bench(config: DictConfig) -> None: log.info("Evaluation finished. Saving results...") - output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) + output_dir = Path(HydraConfig.get().runtime.output_dir) metrics_file = output_dir / "metrics.json" results_file = output_dir / "results.json" diff --git a/benchmarks/sql/bench/contexts/__init__.py b/benchmarks/sql/bench/contexts/__init__.py new file mode 100644 index 00000000..0d1dd11c --- /dev/null +++ b/benchmarks/sql/bench/contexts/__init__.py @@ -0,0 +1,10 @@ +from typing import Dict, Type + +from dbally.context import Context + +from .superhero import SuperheroContext, UserContext + +CONTEXTS_REGISTRY: Dict[str, Type[Context]] = { + UserContext.__name__: UserContext, + SuperheroContext.__name__: SuperheroContext, +} diff --git a/benchmarks/sql/bench/contexts/superhero.py b/benchmarks/sql/bench/contexts/superhero.py new file mode 100644 index 00000000..edd28657 --- /dev/null +++ b/benchmarks/sql/bench/contexts/superhero.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from dbally.context import Context + + +@dataclass +class UserContext(Context): + """ + Current user data. + """ + + name: str = "John Doe" + + +@dataclass +class SuperheroContext(Context): + """ + Current user favourite superhero data. + """ + + name: str = "Batman" diff --git a/benchmarks/sql/bench/metrics/base.py b/benchmarks/sql/bench/metrics/base.py index d0e78072..8df7d1d3 100644 --- a/benchmarks/sql/bench/metrics/base.py +++ b/benchmarks/sql/bench/metrics/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type +from omegaconf import DictConfig from typing_extensions import Self from ..pipelines import EvaluationResult @@ -11,7 +12,7 @@ class Metric(ABC): Base class for metrics. """ - def __init__(self, config: Optional[Dict] = None) -> None: + def __init__(self, config: Optional[DictConfig] = None) -> None: """ Initializes the metric. @@ -38,7 +39,7 @@ class MetricSet: Represents a set of metrics. """ - def __init__(self, *metrics: List[Type[Metric]]) -> None: + def __init__(self, *metrics: Type[Metric]) -> None: """ Initializes the metric set. @@ -48,7 +49,7 @@ def __init__(self, *metrics: List[Type[Metric]]) -> None: self._metrics = metrics self.metrics: List[Metric] = [] - def __call__(self, config: Dict) -> Self: + def __call__(self, config: DictConfig) -> Self: """ Initializes the metrics. 
diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index dc8d83ea..acde042e 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,7 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Union +from functools import cached_property +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union +from omegaconf import DictConfig +from sqlalchemy import Engine, create_engine + +from dbally.context import Context from dbally.iql._exceptions import IQLError from dbally.iql._query import IQLQuery from dbally.iql_generator.prompt import UnsupportedQueryError @@ -9,6 +14,11 @@ from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM +from dbally.views.base import BaseView + +from ..contexts import CONTEXTS_REGISTRY + +ViewT = TypeVar("ViewT", bound=BaseView) @dataclass @@ -23,7 +33,7 @@ class IQL: generated: bool = True @classmethod - def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + def from_query(cls, query: Optional[Union[IQLQuery, BaseException]]) -> "IQL": """ Creates an IQL object from the query. @@ -81,7 +91,12 @@ class EvaluationPipeline(ABC): Collection evaluation pipeline. """ - def get_llm(self, config: Dict) -> LLM: + def __init__(self, config: DictConfig) -> None: + super().__init__() + self.config = config + + @staticmethod + def _get_llm(config: DictConfig) -> LLM: """ Returns the LLM based on the configuration. @@ -95,6 +110,13 @@ def get_llm(self, config: Dict) -> LLM: return LocalLLM(config.model_name.split("/", 1)[1]) return LiteLLM(config.model_name) + @cached_property + def dbs(self) -> Dict[str, Engine]: + """ + Returns the database engines based on the configuration. + """ + return {db: create_engine(f"sqlite:///data/{db}.db") for db in self.config.setup.views} + @abstractmethod async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """ @@ -106,3 +128,25 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: Returns: The evaluation result. """ + + +class ViewEvaluationMixin(Generic[ViewT]): + """ + View evaluation mixin. + """ + + @cached_property + def contexts(self) -> List[Context]: + """ + Returns the contexts based on the configuration. 
+ """ + return [ + CONTEXTS_REGISTRY[context]() for contexts in self.config.setup.contexts.values() for context in contexts + ] + + @cached_property + @abstractmethod + def views(self) -> Dict[str, Type[ViewT]]: + """ + Returns the view classes mapping based on the configuration + """ diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index 19831b0d..784086a7 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -1,57 +1,68 @@ -from typing import Any, Dict - -from sqlalchemy import create_engine +from functools import cached_property +from typing import Any, Dict, Type, Union import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError +from dbally.llms.base import LLM from dbally.view_selection.llm_view_selector import LLMViewSelector from dbally.views.exceptions import ViewExecutionError +from dbally.views.freeform.text2sql.view import BaseText2SQLView +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from ..views import VIEWS_REGISTRY -from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult, ViewEvaluationMixin -class CollectionEvaluationPipeline(EvaluationPipeline): +class CollectionEvaluationPipeline( + EvaluationPipeline, ViewEvaluationMixin[Union[SqlAlchemyBaseView, BaseText2SQLView]] +): """ Collection evaluation pipeline. """ - def __init__(self, config: Dict) -> None: + @cached_property + def selector(self) -> LLM: """ - Constructs the pipeline for evaluating collection predictions. - - Args: - config: The configuration for the pipeline. + Returns the selector LLM. """ - self.collection = self.get_collection(config.setup) + return self._get_llm(self.config.setup.selector_llm) - def get_collection(self, config: Dict) -> Collection: + @cached_property + def generator(self) -> LLM: """ - Sets up the collection based on the configuration. - - Args: - config: The collection configuration. + Returns the generator LLM. + """ + return self._get_llm(self.config.setup.generator_llm) - Returns: - The collection. + @cached_property + def views(self) -> Dict[str, Type[Union[SqlAlchemyBaseView, BaseText2SQLView]]]: + """ + Returns the view classes mapping based on the configuration. + """ + return { + db: cls + for db, views in self.config.setup.views.items() + for view in views + if issubclass(cls := VIEWS_REGISTRY[view], (SqlAlchemyBaseView, BaseText2SQLView)) + } + + @cached_property + def collection(self) -> Collection: + """ + Returns the collection used for evaluation. 
""" - generator_llm = self.get_llm(config.generator_llm) - selector_llm = self.get_llm(config.selector_llm) - view_selector = LLMViewSelector(selector_llm) + view_selector = LLMViewSelector(self.selector) collection = dbally.create_collection( - name=config.name, - llm=generator_llm, + name=self.config.setup.name, + llm=self.generator, view_selector=view_selector, ) collection.n_retries = 0 - for db_name, view_names in config.views.items(): - db = create_engine(f"sqlite:///data/{db_name}.db") - for view_name in view_names: - view_cls = VIEWS_REGISTRY[view_name] - collection.add(view_cls, lambda: view_cls(db)) # pylint: disable=cell-var-from-loop + for db, view in self.views.items(): + collection.add(view, lambda: view(self.dbs[db])) # pylint: disable=cell-var-from-loop return collection @@ -68,6 +79,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: try: result = await self.collection.ask( question=data["question"], + contexts=self.contexts, dry_run=True, return_natural_response=False, ) @@ -85,10 +97,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL(source=result.context["iql"]["filters"]), - aggregation=IQL(source=result.context["iql"]["aggregation"]), + filters=IQL(source=result.metadata["iql"]["filters"]), + aggregation=IQL(source=result.metadata["iql"]["aggregation"]), ), - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index be9d8263..237f2858 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -1,76 +1,53 @@ # pylint: disable=duplicate-code from abc import ABC, abstractmethod +from functools import cached_property from typing import Any, Dict, Type -from sqlalchemy import create_engine - +from dbally.llms.base import LLM from dbally.views.exceptions import ViewExecutionError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from ..views import VIEWS_REGISTRY -from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult, ViewEvaluationMixin, ViewT -class ViewEvaluationPipeline(EvaluationPipeline, ABC): +class ViewEvaluationPipeline(EvaluationPipeline, ViewEvaluationMixin[ViewT], ABC): """ View evaluation pipeline. """ - def __init__(self, config: Dict) -> None: - """ - Constructs the pipeline for evaluating IQL predictions. - - Args: - config: The configuration for the pipeline. - """ - self.llm = self.get_llm(config.setup.llm) - self.dbs = self.get_dbs(config.setup) - self.views = self.get_views(config.setup) - - def get_dbs(self, config: Dict) -> Dict: + @cached_property + def llm(self) -> LLM: """ - Returns the database object based on the database name. - - Args: - config: The database configuration. - - Returns: - The database object. + Returns the LLM based on the configuration. """ - return {db: create_engine(f"sqlite:///data/{db}.db") for db in config.views} + return self._get_llm(self.config.setup.llm) + @cached_property @abstractmethod - def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: + def views(self) -> Dict[str, Type[ViewT]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. 
- - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration """ -class IQLViewEvaluationPipeline(ViewEvaluationPipeline): +class IQLViewEvaluationPipeline(ViewEvaluationPipeline[SqlAlchemyBaseView]): """ IQL view evaluation pipeline. """ - def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: + @cached_property + def views(self) -> Dict[str, Type[SqlAlchemyBaseView]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. - - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration. """ return { - view_name: VIEWS_REGISTRY[view_name] for view_names in config.views.values() for view_name in view_names + view: cls + for views in self.config.setup.views.values() + for view in views + if issubclass(cls := VIEWS_REGISTRY[view], SqlAlchemyBaseView) } async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: @@ -89,6 +66,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: result = await view.ask( query=data["question"], llm=self.llm, + contexts=self.contexts, dry_run=True, n_retries=0, ) @@ -104,10 +82,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL(source=result.context["iql"]["filters"]), - aggregation=IQL(source=result.context["iql"]["aggregation"]), + filters=IQL(source=result.metadata["iql"]["filters"]), + aggregation=IQL(source=result.metadata["iql"]["aggregation"]), ), - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( @@ -135,22 +113,21 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: ) -class SQLViewEvaluationPipeline(ViewEvaluationPipeline): +class SQLViewEvaluationPipeline(ViewEvaluationPipeline[BaseText2SQLView]): """ SQL view evaluation pipeline. """ - def get_views(self, config: Dict) -> Dict[str, Type[BaseText2SQLView]]: + @cached_property + def views(self) -> Dict[str, Type[BaseText2SQLView]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. - - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration. 
""" - return {db_id: VIEWS_REGISTRY[view_name] for db_id, view_name in config.views.items()} + return { + db: cls + for db, view in self.config.setup.views.items() + if issubclass(cls := VIEWS_REGISTRY[view], BaseText2SQLView) + } async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """ @@ -179,7 +156,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: else: prediction = ExecutionResult( view_name=view.__class__.__name__, - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( diff --git a/benchmarks/sql/bench/views/structured/__init__.py b/benchmarks/sql/bench/views/structured/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarks/sql/config/setup/collection.yaml b/benchmarks/sql/config/setup/collection.yaml index 2eafb34a..f1d3a13a 100644 --- a/benchmarks/sql/config/setup/collection.yaml +++ b/benchmarks/sql/config/setup/collection.yaml @@ -5,3 +5,5 @@ defaults: - llm@generator_llm: gpt-3.5-turbo - views/structured@views: - superhero + - contexts: + - superhero diff --git a/benchmarks/sql/config/setup/contexts/superhero.yaml b/benchmarks/sql/config/setup/contexts/superhero.yaml new file mode 100644 index 00000000..fcb6f70f --- /dev/null +++ b/benchmarks/sql/config/setup/contexts/superhero.yaml @@ -0,0 +1,4 @@ +superhero: [ + UserContext, + SuperheroContext, +] diff --git a/benchmarks/sql/config/setup/iql-view.yaml b/benchmarks/sql/config/setup/iql-view.yaml index e652bc3b..be482a85 100644 --- a/benchmarks/sql/config/setup/iql-view.yaml +++ b/benchmarks/sql/config/setup/iql-view.yaml @@ -4,3 +4,5 @@ defaults: - llm: gpt-3.5-turbo - views/structured@views: - superhero + - contexts: + - superhero diff --git a/benchmarks/sql/config/setup/sql-view.yaml b/benchmarks/sql/config/setup/sql-view.yaml index e4e1f7d9..5d9de669 100644 --- a/benchmarks/sql/config/setup/sql-view.yaml +++ b/benchmarks/sql/config/setup/sql-view.yaml @@ -4,3 +4,5 @@ defaults: - llm: gpt-3.5-turbo - views/freeform@views: - superhero + - contexts: + - superhero diff --git a/docs/how-to/use_elastic_vector_store_code.py b/docs/how-to/use_elastic_vector_store_code.py index 4817fcf4..48c3309f 100644 --- a/docs/how-to/use_elastic_vector_store_code.py +++ b/docs/how-to/use_elastic_vector_store_code.py @@ -91,7 +91,7 @@ async def main(country="United States", years_of_experience="2"): f"Find someone from the {country} with more than {years_of_experience} years of experience." ) - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/docs/how-to/use_elasticsearch_store_code.py b/docs/how-to/use_elasticsearch_store_code.py index 1f690c35..181a3c89 100644 --- a/docs/how-to/use_elasticsearch_store_code.py +++ b/docs/how-to/use_elasticsearch_store_code.py @@ -95,7 +95,7 @@ async def main(country="United States", years_of_experience="2"): f"Find someone from the {country} with more than {years_of_experience} years of experience." 
) - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/docs/how-to/views/custom_views_code.py b/docs/how-to/views/custom_views_code.py index c64a2ffb..fb10e7fd 100644 --- a/docs/how-to/views/custom_views_code.py +++ b/docs/how-to/views/custom_views_code.py @@ -64,7 +64,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: print(self._filter) filtered_data = list(filter(self._filter, self.get_data())) - return ViewExecutionResult(results=filtered_data, context={}) + return ViewExecutionResult(results=filtered_data, metadata={}) class CandidateView(FilteredIterableBaseView): diff --git a/examples/intro.py b/examples/intro.py index 7bce4e46..23bd07dd 100644 --- a/examples/intro.py +++ b/examples/intro.py @@ -1,5 +1,6 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, duplicate-code + import asyncio import sqlalchemy @@ -64,7 +65,7 @@ async def main(): result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/examples/semantic_similarity.py b/examples/semantic_similarity.py index 098f167a..178ef6d9 100644 --- a/examples/semantic_similarity.py +++ b/examples/semantic_similarity.py @@ -126,7 +126,7 @@ async def main(): result = await collection.ask("Find someone from the United States with more than 2 years of experience.") - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 5c97a016..887b11a8 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -139,7 +139,7 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict] self._print_syntax("[green bold]REQUEST OUTPUT:") self._print_syntax(f"Number of rows: {len(output.result.results)}") - if "sql" in output.result.context: - self._print_syntax(f"{output.result.context['sql']}", "psql") + if "sql" in output.result.metadata: + self._print_syntax(f"{output.result.metadata['sql']}", "psql") else: self._print_syntax("[red bold]No results found") diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py index 89394f8d..b5f41b7e 100644 --- a/src/dbally/audit/event_handlers/langsmith_event_handler.py +++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py @@ -102,5 +102,5 @@ async def request_end(self, output: RequestEnd, request_context: RunTree) -> Non output: The output of the request. In this case - PSQL query. 
request_context: Optional context passed from request_start method """ - request_context.end(outputs={"sql": output.result.context["sql"]}) + request_context.end(outputs={"sql": output.result.metadata["sql"]}) request_context.post(exclude_child_runs=False) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 542f78e4..a01d7dc7 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -12,6 +12,7 @@ from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions @@ -227,7 +228,8 @@ async def _ask_view( event_tracker: EventTracker, llm_options: Optional[LLMOptions], dry_run: bool, - ): + contexts: List[Context], + ) -> ViewExecutionResult: """ Ask the selected view to provide an answer to the question. @@ -239,12 +241,13 @@ async def _ask_view( dry_run: If True, only generate the query without executing it. Returns: - Any: The result from the selected view. + The result from the selected view. """ selected_view = self.get(selected_view_name) view_result = await selected_view.ask( query=question, llm=self._llm, + contexts=contexts, event_tracker=event_tracker, n_retries=self.n_retries, dry_run=dry_run, @@ -295,9 +298,11 @@ def get_all_event_handlers(self) -> List[EventHandler]: return self._event_handlers return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers())) + # pylint: disable=too-many-arguments async def _handle_fallback( self, question: str, + contexts: Optional[List[Context]], dry_run: bool, return_natural_response: bool, llm_options: Optional[LLMOptions], @@ -319,7 +324,6 @@ async def _handle_fallback( Returns: The result from the fallback collection. - """ if not self._fallback_collection: raise caught_exception @@ -334,6 +338,7 @@ async def _handle_fallback( async with event_tracker.track_event(fallback_event) as span: result = await self._fallback_collection.ask( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, @@ -345,6 +350,7 @@ async def _handle_fallback( async def ask( self, question: str, + contexts: Optional[List[Context]] = None, dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, @@ -362,7 +368,9 @@ async def ask( Args: question: question posed using natural language representation e.g\ - "What job offers for Data Scientists do we have?" + "What job offers for Data Scientists do we have?" + contexts: list of context objects, each being an instance of + a subclass of Context. May contain contexts irrelevant for the currently processed query. 
dry_run: if True, only generate the query without executing it return_natural_response: if True (and dry_run is False as natural response requires query results), the natural response will be included in the answer @@ -404,6 +412,7 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, dry_run=dry_run, + contexts=contexts or [], ) end_time_view = time.monotonic() @@ -415,7 +424,7 @@ async def ask( result = ExecutionResult( results=view_result.results, - context=view_result.context, + metadata=view_result.metadata, execution_time=time.monotonic() - start_time, execution_time_view=end_time_view - start_time_view, view_name=selected_view_name, @@ -426,6 +435,7 @@ async def ask( if self._fallback_collection: result = await self._handle_fallback( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, diff --git a/src/dbally/collection/results.py b/src/dbally/collection/results.py index b33cf5e3..65421a34 100644 --- a/src/dbally/collection/results.py +++ b/src/dbally/collection/results.py @@ -14,7 +14,7 @@ class ViewExecutionResult: """ results: List[Dict[str, Any]] - context: Dict[str, Any] + metadata: Dict[str, Any] @dataclass @@ -37,7 +37,7 @@ class ExecutionResult: """ results: List[Dict[str, Any]] - context: Dict[str, Any] + metadata: Dict[str, Any] execution_time: float execution_time_view: float view_name: str diff --git a/src/dbally/context.py b/src/dbally/context.py new file mode 100644 index 00000000..46c65030 --- /dev/null +++ b/src/dbally/context.py @@ -0,0 +1,11 @@ +from abc import ABC +from typing import ClassVar + + +class Context(ABC): + """ + Base class for all contexts that are used to pass additional knowledge about the caller environment to the view. + """ + + type_name: ClassVar[str] = "Context" + alias_name: ClassVar[str] = "CONTEXT" diff --git a/src/dbally/iql/_exceptions.py b/src/dbally/iql/_exceptions.py index 7fd3709e..1d9dbf39 100644 --- a/src/dbally/iql/_exceptions.py +++ b/src/dbally/iql/_exceptions.py @@ -5,7 +5,9 @@ class IQLError(DbAllyError): - """Base exception for all IQL parsing related exceptions.""" + """ + Base exception for all IQL parsing related exceptions. + """ def __init__(self, message: str, source: str) -> None: super().__init__(message) @@ -13,7 +15,9 @@ def __init__(self, message: str, source: str) -> None: class IQLSyntaxError(IQLError): - """Raised when IQL syntax is invalid.""" + """ + Raised when IQL syntax is invalid. + """ def __init__(self, source: str) -> None: message = f"Syntax error in: {source}" @@ -21,7 +25,9 @@ def __init__(self, source: str) -> None: class IQLNoStatementError(IQLError): - """Raised when IQL does not have any statement.""" + """ + Raised when IQL does not have any statement. + """ def __init__(self, source: str) -> None: message = "Empty IQL" @@ -29,7 +35,9 @@ def __init__(self, source: str) -> None: class IQLMultipleStatementsError(IQLError): - """Raised when IQL contains multiple statements.""" + """ + Raised when IQL contains multiple statements. + """ def __init__(self, nodes: List[ast.stmt], source: str) -> None: message = "Multiple statements in IQL are not supported" @@ -37,25 +45,32 @@ def __init__(self, nodes: List[ast.stmt], source: str) -> None: self.nodes = nodes -class IQLExpressionError(IQLError): - """Raised when IQL expression is invalid.""" +class IQLNoExpressionError(IQLError): + """ + Raised when IQL expression is not found. 
+    """
 
-    def __init__(self, message: str, node: ast.expr, source: str) -> None:
-        message = f"{message}: {source[node.col_offset : node.end_col_offset]}"
+    def __init__(self, node: ast.stmt, source: str) -> None:
+        message = f"No expression found in IQL: {source[node.col_offset : node.end_col_offset]}"
         super().__init__(message, source)
         self.node = node
 
 
-class IQLNoExpressionError(IQLExpressionError):
-    """Raised when IQL expression is not found."""
+class IQLExpressionError(IQLError):
+    """
+    Raised when IQL expression is invalid.
+    """
 
-    def __init__(self, node: ast.stmt, source: str) -> None:
-        message = "No expression found in IQL"
-        super().__init__(message, node, source)
+    def __init__(self, message: str, node: ast.expr, source: str) -> None:
+        message = f"{message}: {source[node.col_offset : node.end_col_offset]}"
+        super().__init__(message, source)
+        self.node = node
 
 
 class IQLArgumentParsingError(IQLExpressionError):
-    """Raised when an argument cannot be parsed into a valid IQL."""
+    """
+    Raised when an argument cannot be parsed into a valid IQL.
+    """
 
     def __init__(self, node: ast.expr, source: str) -> None:
         message = "Not a valid IQL argument"
@@ -63,21 +78,19 @@ def __init__(self, node: ast.expr, source: str) -> None:
 
 
 class IQLUnsupportedSyntaxError(IQLExpressionError):
-    """Raised when trying to parse an unsupported syntax."""
+    """
+    Raised when trying to parse an unsupported syntax.
+    """
 
     def __init__(self, node: ast.expr, source: str, context: Optional[str] = None) -> None:
-        node_name = node.__class__.__name__
-
-        message = f"{node_name} syntax is not supported in IQL"
-
-        if context:
-            message += " " + context
-
+        message = f"{node.__class__.__name__} syntax is not supported in IQL{f' {context}' if context else ''}"
         super().__init__(message, node, source)
 
 
 class IQLFunctionNotExists(IQLExpressionError):
-    """Raised when IQL contains function call to a function that not exists."""
+    """
+    Raised when IQL contains a function call to a function that does not exist.
+    """
 
     def __init__(self, node: ast.Name, source: str) -> None:
         message = f"Function {node.id} not exists"
@@ -85,12 +98,43 @@ def __init__(self, node: ast.Name, source: str) -> None:
 
 
 class IQLIncorrectNumberArgumentsError(IQLExpressionError):
-    """Raised when IQL contains too many arguments for a function."""
+    """
+    Raised when IQL contains too many arguments for a function.
+    """
 
     def __init__(self, node: ast.Call, source: str) -> None:
-        message = f"The method {node.func.id} has incorrect number of arguments"
+        message = f"The method {node.func.id} has incorrect number of arguments"  # type: ignore
         super().__init__(message, node, source)
 
 
 class IQLArgumentValidationError(IQLExpressionError):
-    """Raised when argument is not valid for a given method."""
+    """
+    Raised when an argument is not valid for a given method.
+    """
+
+
+class IQLContextError(IQLExpressionError):
+    """
+    Base exception for all IQL context related exceptions.
+    """
+
+
+class IQLContextNotAllowedError(IQLContextError):
+    """
+    Raised when a context keyword has been passed as an argument to the method that does not support contextualization.
+    """
+
+    def __init__(self, node: ast.Name, source: str) -> None:
+        message = "The context keyword is not allowed here"
+        super().__init__(message, node, source)
+
+
+class IQLContextNotFoundError(IQLContextError):
+    """
+    Raised when a context keyword has been passed as an argument to a method that supports contextualization,
+    but no matching context was found.
+ """ + + def __init__(self, node: ast.Name, source: str) -> None: + message = "The requested context is not found" + super().__init__(message, node, source) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 1bd72bcc..d66e6977 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,12 +1,16 @@ import ast +import asyncio from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union +from typing import Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker +from dbally.context import Context from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, IQLArgumentValidationError, + IQLContextNotAllowedError, + IQLContextNotFoundError, IQLFunctionNotExists, IQLIncorrectNumberArgumentsError, IQLMultipleStatementsError, @@ -16,9 +20,7 @@ IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type - -if TYPE_CHECKING: - from dbally.views.structured import ExposedFunction +from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping RootT = TypeVar("RootT", bound=syntax.Node) @@ -29,10 +31,15 @@ class IQLProcessor(Generic[RootT], ABC): """ def __init__( - self, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + self, + source: str, + allowed_functions: List[ExposedFunction], + allowed_contexts: Optional[List[Context]] = None, + event_tracker: Optional[EventTracker] = None, ) -> None: self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} + self.contexts = {context.alias_name: context for context in allowed_contexts or []} self._event_tracker = event_tracker or EventTracker() async def process(self) -> RootT: @@ -64,54 +71,61 @@ async def process(self) -> RootT: return await self._parse_node(ast_tree.body[0].value) @abstractmethod - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT: - """ - Parses AST node to IQL node. - - Args: - node: AST node to parse. - - Returns: - IQL node. - """ + async def _parse_node(self, node: ast.expr) -> RootT: + ... 
async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: - func = node.func - - if not isinstance(func, ast.Name): + if not isinstance(node.func, ast.Name): raise IQLUnsupportedSyntaxError(node, self.source, context="FunctionCall") - if func.id not in self.allowed_functions: - raise IQLFunctionNotExists(func, self.source) + if node.func.id not in self.allowed_functions: + raise IQLFunctionNotExists(node.func, self.source) - func_def = self.allowed_functions[func.id] - args = [] + func_def = self.allowed_functions[node.func.id] if len(func_def.parameters) != len(node.args): raise IQLIncorrectNumberArgumentsError(node, self.source) - for arg, arg_def in zip(node.args, func_def.parameters): - arg_value = self._parse_arg(arg) + parsed_args = await asyncio.gather( + *[self._parse_arg(arg, arg_def) for arg, arg_def in zip(node.args, func_def.parameters)] + ) - if arg_def.similarity_index: - arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker) + args = [ + self._validate_and_cast_arg(arg, arg_def, node_arg) + for arg, arg_def, node_arg in zip(parsed_args, func_def.parameters, node.args) + ] - check_result = validate_arg_type(arg_def.type, arg_value) + return syntax.FunctionCall(node.func.id, args) - if not check_result.valid: - raise IQLArgumentValidationError(message=check_result.reason or "", node=arg, source=self.source) + async def _parse_arg(self, arg: ast.expr, arg_def: MethodParamWithTyping) -> Any: + if isinstance(arg, ast.List): + return await asyncio.gather(*[self._parse_arg(x, arg_def) for x in arg.elts]) - args.append(check_result.casted_value if check_result.casted_value is not ... else arg_value) + if isinstance(arg, ast.Name): + aliases = [context.alias_name for context in arg_def.contexts] - return syntax.FunctionCall(func.id, args) + if arg.id not in aliases: + raise IQLContextNotAllowedError(arg, self.source) - def _parse_arg(self, arg: ast.expr) -> Any: - if isinstance(arg, ast.List): - return [self._parse_arg(x) for x in arg.elts] + if context := self.contexts.get(arg.id): + return context + + raise IQLContextNotFoundError(arg, self.source) if not isinstance(arg, ast.Constant): raise IQLArgumentParsingError(arg, self.source) - return arg.value + + return ( + await arg_def.similarity_index.similar(arg.value, self._event_tracker) + if arg_def.similarity_index is not None + else arg.value + ) + + def _validate_and_cast_arg(self, arg: Any, arg_def: MethodParamWithTyping, node: ast.expr) -> Any: + check_result = validate_arg_type(arg_def.type, arg) + if not check_result.valid: + raise IQLArgumentValidationError(message=check_result.reason or "", node=node, source=self.source) + return check_result.casted_value if check_result.casted_value is not ... else arg @staticmethod def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str: @@ -155,7 +169,7 @@ class IQLFiltersProcessor(IQLProcessor[syntax.Node]): IQL processor for filters. """ - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: + async def _parse_node(self, node: ast.expr) -> syntax.Node: if isinstance(node, ast.BoolOp): return await self._parse_bool_op(node) if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): @@ -181,7 +195,7 @@ class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]): IQL processor for aggregation. 
""" - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall: + async def _parse_node(self, node: ast.expr) -> syntax.FunctionCall: if isinstance(node, ast.Call): return await self._parse_call(node) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 57b3b4ed..1d8e1ef2 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,12 +1,15 @@ from abc import ABC from typing import TYPE_CHECKING, Generic, List, Optional, Type +from typing_extensions import Self + from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: - from dbally.views.structured import ExposedFunction + from dbally.context import Context + from dbally.views.exposed_functions import ExposedFunction class IQLQuery(Generic[RootT], ABC): @@ -30,14 +33,16 @@ async def parse( cls, source: str, allowed_functions: List["ExposedFunction"], + allowed_contexts: Optional[List["Context"]] = None, event_tracker: Optional[EventTracker] = None, - ) -> "IQLQuery[RootT]": + ) -> Self: """ Parse IQL string to IQLQuery object. Args: source: IQL string that needs to be parsed. allowed_functions: List of IQL functions that are allowed for this query. + allowed_contexts: List of contexts that are allowed for this query. event_tracker: EventTracker object to track events. Returns: @@ -46,7 +51,12 @@ async def parse( Raises: IQLError: If parsing fails. """ - root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process() + root = await cls._processor( + source=source, + allowed_functions=allowed_functions, + allowed_contexts=allowed_contexts, + event_tracker=event_tracker, + ).process() return cls(root=root, source=source) @@ -55,7 +65,7 @@ class IQLFiltersQuery(IQLQuery[syntax.Node]): IQL filters query container. """ - _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor + _processor: Type[IQLProcessor[syntax.Node]] = IQLFiltersProcessor class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): @@ -63,4 +73,4 @@ class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): IQL aggregation query container. 
""" - _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor + _processor: Type[IQLProcessor[syntax.FunctionCall]] = IQLAggregationProcessor diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 7932cff7..7f0bce48 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from typing import _GenericAlias # type: ignore -from typing import Any, Callable, Dict, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, Literal, Optional, Type, Union, _GenericAlias # type: ignore + +from typing_extensions import Annotated, get_args, get_origin @dataclass @@ -10,15 +11,30 @@ class _ValidationResult: reason: Optional[str] = None +def _check_annotated(required_type: Annotated, value: Any) -> _ValidationResult: + type_args = get_args(required_type) + return validate_arg_type(type_args[0], value) + + def _check_literal(required_type: _GenericAlias, value: Any) -> _ValidationResult: - if value not in required_type.__args__: - return _ValidationResult( - False, reason=f"{value} must be one of [{', '.join(repr(x) for x in required_type.__args__)}]" - ) + type_args = get_args(required_type) + if value not in type_args: + return _ValidationResult(False, reason=f"{value} must be one of [{', '.join(repr(x) for x in type_args)}]") return _ValidationResult(True) +def _check_union(required_type: _GenericAlias, value: Any) -> _ValidationResult: + type_args = get_args(required_type) + + for subtype in get_args(required_type): + res = validate_arg_type(subtype, value) + if res.valid: + return _ValidationResult(True) + + return _ValidationResult(False, reason=f"{repr(value)} is not of type {', '.join(repr(x) for x in type_args)}") + + def _check_float(required_type: Type[float], value: Any) -> _ValidationResult: if isinstance(value, float): return _ValidationResult(True) @@ -47,7 +63,9 @@ def _check_bool(required_type: Type[bool], value: Any) -> _ValidationResult: TYPE_VALIDATOR: Dict[Any, Callable[[Any, Any], _ValidationResult]] = { + Annotated: _check_annotated, Literal: _check_literal, + Union: _check_union, float: _check_float, int: _check_int, bool: _check_bool, @@ -65,11 +83,9 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> Returns: _ValidationResult instance """ - actual_type = required_type.__origin__ if isinstance(required_type, _GenericAlias) else required_type - - custom_type_checker = TYPE_VALIDATOR.get(actual_type) + actual_type = get_origin(required_type) or required_type - if custom_type_checker: + if custom_type_checker := TYPE_VALIDATOR.get(actual_type): return custom_type_checker(required_type, value) if isinstance(value, actual_type): diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 4ea65340..c6700ef3 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -3,6 +3,7 @@ from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker +from dbally.context import Context from dbally.iql import IQLError, IQLQuery from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( @@ -29,18 +30,8 @@ class IQLGeneratorState: State of the IQL generator. 
""" - filters: Optional[Union[IQLFiltersQuery, Exception]] = None - aggregation: Optional[Union[IQLAggregationQuery, Exception]] = None - - @property - def failed(self) -> bool: - """ - Checks if the generation failed. - - Returns: - True if the generation failed, False otherwise. - """ - return isinstance(self.filters, Exception) or isinstance(self.aggregation, Exception) + filters: Optional[Union[IQLFiltersQuery, BaseException]] = None + aggregation: Optional[Union[IQLAggregationQuery, BaseException]] = None class IQLGenerator: @@ -76,6 +67,7 @@ async def __call__( question: str, filters: List[ExposedFunction], aggregations: List[ExposedFunction], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -89,6 +81,7 @@ async def __call__( question: User question. filters: List of filters exposed by the view. aggregations: List of aggregations exposed by the view. + contexts: List of contexts to be injected after filters and aggregation generation. examples: List of examples to be injected during filters and aggregation generation. llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. @@ -98,10 +91,11 @@ async def __call__( Returns: Generated IQL operations. """ - filters, aggregation = await asyncio.gather( + iql_filters, iql_aggregation = await asyncio.gather( self._filters_generation( question=question, methods=filters, + contexts=contexts, examples=examples, llm=llm, llm_options=llm_options, @@ -111,6 +105,7 @@ async def __call__( self._aggregation_generation( question=question, methods=aggregations, + contexts=contexts, examples=examples, llm=llm, llm_options=llm_options, @@ -120,8 +115,8 @@ async def __call__( return_exceptions=True, ) return IQLGeneratorState( - filters=filters, - aggregation=aggregation, + filters=iql_filters, + aggregation=iql_aggregation, ) @@ -145,11 +140,13 @@ def __init__( self.assessor = IQLQuestionAssessor(assessor_prompt) self.generator = IQLQueryGenerator[IQLQueryT](generator_prompt) + # pylint: disable=too-many-arguments async def __call__( self, *, question: str, methods: List[ExposedFunction], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -163,6 +160,7 @@ async def __call__( llm: LLM used to generate IQL. question: User question. methods: List of methods exposed by the view. + contexts: List of contexts to be injected as method arguments. examples: List of examples to be injected into the conversation. event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. @@ -189,6 +187,7 @@ async def __call__( return await self.generator( question=question, methods=methods, + contexts=contexts, examples=examples, llm=llm, llm_options=llm_options, @@ -213,7 +212,7 @@ async def __call__( llm_options: Optional[LLMOptions] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> bool: + ) -> Optional[bool]: """ Decides whether the question requires generating IQL or not. 
@@ -260,23 +259,26 @@ class IQLQueryGenerator(Generic[IQLQueryT]): def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: self.prompt = prompt + # pylint: disable=too-many-arguments async def __call__( self, *, question: str, methods: List[ExposedFunction], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, llm_options: Optional[LLMOptions] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> IQLQueryT: + ) -> Optional[IQLQueryT]: """ Generates IQL query for the given question. Args: question: User question. - filters: List of filters exposed by the view. + methods: List of methods exposed by the view. + contexts: List of contexts to be injected as method arguments. examples: List of examples to be injected into the conversation. llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. @@ -294,6 +296,7 @@ async def __call__( prompt_format = IQLGenerationPromptFormat( question=question, methods=methods, + contexts=contexts, examples=examples, ) formatted_prompt = self.prompt.format_prompt(prompt_format) @@ -309,13 +312,12 @@ async def __call__( return await formatted_prompt.response_parser( response=response, allowed_functions=methods, + allowed_contexts=contexts, event_tracker=event_tracker, ) - except LLMError as exc: - if retry == n_retries: - raise exc - except IQLError as exc: + except (IQLError, LLMError) as exc: if retry == n_retries: raise exc - formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) + if isinstance(exc, IQLError): + formatted_prompt = formatted_prompt.add_assistant_message(response) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index f2c29d62..64c99d8a 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -3,6 +3,7 @@ from typing import List, Optional from dbally.audit.event_tracker import EventTracker +from dbally.context import Context from dbally.exceptions import DbAllyError from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample @@ -20,6 +21,7 @@ class UnsupportedQueryError(DbAllyError): async def _iql_filters_parser( response: str, allowed_functions: List[ExposedFunction], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLFiltersQuery: """ @@ -28,6 +30,7 @@ async def _iql_filters_parser( Args: response: LLM response. allowed_functions: List of functions that can be used in the IQL. + allowed_contexts: List of contexts that can be used in the IQL. event_tracker: Event tracker to be used for auditing. Returns: @@ -42,6 +45,7 @@ async def _iql_filters_parser( return await IQLFiltersQuery.parse( source=response, allowed_functions=allowed_functions, + allowed_contexts=allowed_contexts, event_tracker=event_tracker, ) @@ -49,6 +53,7 @@ async def _iql_filters_parser( async def _iql_aggregation_parser( response: str, allowed_functions: List[ExposedFunction], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLAggregationQuery: """ @@ -57,6 +62,7 @@ async def _iql_aggregation_parser( Args: response: LLM response. allowed_functions: List of functions that can be used in the IQL. + allowed_contexts: List of contexts that can be used in the IQL. 
event_tracker: Event tracker to be used for auditing. Returns: @@ -71,6 +77,7 @@ async def _iql_aggregation_parser( return await IQLAggregationQuery.parse( source=response, allowed_functions=allowed_functions, + allowed_contexts=allowed_contexts, event_tracker=event_tracker, ) @@ -98,7 +105,7 @@ class DecisionPromptFormat(PromptFormat): IQL prompt format, providing a question and filters to be used in the conversation. """ - def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> None: + def __init__(self, *, question: str, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new IQLGenerationPromptFormat instance. @@ -120,6 +127,7 @@ def __init__( *, question: str, methods: List[ExposedFunction], + contexts: List[Context], examples: Optional[List[FewShotExample]] = None, ) -> None: """ @@ -128,12 +136,13 @@ def __init__( Args: question: Question to be asked. methods: List of methods exposed by the view. + contexts: List of contexts to be used in the conversation. examples: List of examples to be injected into the conversation. - aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question self.methods = "\n".join(str(method) for method in methods) + self.contexts = "\n".join(str(context) for context in contexts) FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 38473e98..eb9d9c82 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -73,7 +73,7 @@ async def generate_response( if tokens_count > self._max_tokens_count: prompt_format = QueryExplanationPromptFormat( question=question, - context=result.context, + metadata=result.metadata, results=result.results, ) formatted_prompt = self._explainer_prompt_template.format_prompt(prompt_format) diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py index 17f63898..0d465c5f 100644 --- a/src/dbally/nl_responder/prompts.py +++ b/src/dbally/nl_responder/prompts.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pandas as pd @@ -48,9 +48,9 @@ def __init__( self, *, question: str, - context: Dict[str, Any], + metadata: Dict[str, Any], results: List[Dict[str, Any]], - examples: List[FewShotExample] = None, + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new QueryExplanationPromptFormat instance. 
@@ -63,7 +63,7 @@ def __init__( """ super().__init__(examples) self.question = question - self.query = next((context.get(key) for key in ("iql", "sql", "query") if context.get(key)), question) + self.query = next((metadata.get(key) for key in ("iql", "sql", "query") if metadata.get(key)), question) self.number_of_results = len(results) diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 66cbe5b4..d85edfa3 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -1,14 +1,17 @@ import abc from typing import Dict, List, Optional, Tuple +from typing_extensions import TypeAlias + from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex -IndexLocation = Tuple[str, str, str] +IndexLocation: TypeAlias = Tuple[str, str, str] class BaseView(metaclass=abc.ABCMeta): @@ -22,6 +25,7 @@ async def ask( self, query: str, llm: LLM, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, @@ -33,6 +37,8 @@ async def ask( Args: query: The natural language query to execute. llm: The LLM used to execute the query. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of Context. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. @@ -53,9 +59,9 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc def list_few_shots(self) -> List[FewShotExample]: """ - List all examples to be injected into few-shot prompt. + Lists all examples to be injected into few-shot prompt. Returns: - List of few-shot examples + List of few-shot examples. """ return [] diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index efccb7ef..bb03fe68 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,34 +1,13 @@ import re from dataclasses import dataclass -from typing import List, Optional, Union, _GenericAlias # type: ignore +from typing import List, Optional, Type, Union, _GenericAlias # type: ignore -from typing_extensions import _AnnotatedAlias +from typing_extensions import _AnnotatedAlias, get_origin +from dbally.context import Context from dbally.similarity import AbstractSimilarityIndex -def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: - """ - Parses the type of a method parameter and returns a string representation of it. 
- - Args: - param_type: type of the parameter - - Returns: - str: string representation of the type - """ - if hasattr(param_type, "__name__"): - return param_type.__name__ - - if param_type.__module__ == "typing": - return re.sub(r"\btyping\.", "", str(param_type)) - - if isinstance(param_type, _AnnotatedAlias): - return parse_param_type(param_type.__origin__) - - return str(param_type) - - @dataclass class MethodParamWithTyping: """ @@ -39,18 +18,51 @@ class MethodParamWithTyping: type: Union[type, _GenericAlias] def __str__(self) -> str: - return f"{self.name}: {parse_param_type(self.type)}" + return f"{self.name}: {self._parse_type()}" + + @property + def contexts(self) -> List[Type[Context]]: + """ + Returns the contexts if the type is annotated with them. + """ + return [arg for arg in getattr(self.type, "__args__", []) if issubclass(arg, Context)] @property def similarity_index(self) -> Optional[AbstractSimilarityIndex]: """ Returns the SimilarityIndex object if the type is annotated with it. """ - if hasattr(self.type, "__metadata__"): - similarity_indexes = [meta for meta in self.type.__metadata__ if isinstance(meta, AbstractSimilarityIndex)] - return similarity_indexes[0] if similarity_indexes else None + return next( + (arg for arg in getattr(self.type, "__metadata__", []) if isinstance(arg, AbstractSimilarityIndex)), None + ) + + def _parse_type(self) -> str: + """ + Parses the type of a method parameter and returns a string representation of it. + + Returns: + String representation of the type. + """ + + def _parse_type_inner(param_type: Union[type, _GenericAlias]) -> str: + if get_origin(param_type) is Union: + return " | ".join(_parse_type_inner(arg) for arg in self.type.__args__) + + if param_type.__module__ == "typing": + return re.sub(r"\btyping\.", "", str(param_type)) + + if issubclass(param_type, Context): + return param_type.type_name + + if hasattr(param_type, "__name__"): + return param_type.__name__ + + if isinstance(self.type, _AnnotatedAlias): + return _parse_type_inner(self.type.__origin__) + + return str(param_type) - return None + return _parse_type_inner(self.type) @dataclass diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 1dfa8f62..c53007c0 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,6 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -99,6 +100,7 @@ async def ask( self, query: str, llm: LLM, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, @@ -111,6 +113,7 @@ async def ask( Args: query: The natural language query to execute. llm: The LLM used to execute the query. + contexts: Currently not used. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. 
@@ -148,7 +151,7 @@ async def ask( ) if dry_run: - return ViewExecutionResult(results=[], context={"sql": sql}) + return ViewExecutionResult(results=[], metadata={"sql": sql}) rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker) break @@ -164,7 +167,7 @@ async def ask( # pylint: disable=protected-access return ViewExecutionResult( results=[dict(row._mapping) for row in rows], - context={ + metadata={ "sql": sql, }, ) diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 8bf93363..38e64eda 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -23,33 +23,34 @@ def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction] Lists all methods decorated with the given decorator. Args: - decorator: The decorator to filter the methods + decorator: The decorator to filter the methods. Returns: - List of exposed methods + List of exposed methods. """ - methods = [] - for method_name in dir(cls): - method = getattr(cls, method_name) - if ( - hasattr(method, "_methodDecorator") - and method._methodDecorator == decorator # pylint: disable=protected-access - ): - annotations = method.__annotations__.items() - methods.append( - ExposedFunction( - name=method_name, - description=textwrap.dedent(method.__doc__).strip() if method.__doc__ else "", - parameters=[ - MethodParamWithTyping(n, t) for n, t in annotations if n not in cls.HIDDEN_ARGUMENTS - ], + # pylint: disable=protected-access + return [ + ExposedFunction( + name=method_name, + description=textwrap.dedent(method.__doc__).strip() if method.__doc__ else "", + parameters=[ + MethodParamWithTyping( + name=name, + type=type, ) - ) - return methods + for name, type in method.__annotations__.items() + if name not in cls.HIDDEN_ARGUMENTS + ], + ) + for method_name in dir(cls) + if (method := getattr(cls, method_name)) + and hasattr(method, "_methodDecorator") + and method._methodDecorator == decorator + ] def list_filters(self) -> List[ExposedFunction]: """ - List filters in the given view + List filters in the given view. Returns: Filters defined inside the View and decorated with `decorators.view_filter`. @@ -58,7 +59,7 @@ def list_filters(self) -> List[ExposedFunction]: def list_aggregations(self) -> List[ExposedFunction]: """ - List aggregations in the given view + List aggregations in the given view. Returns: Aggregations defined inside the View and decorated with `decorators.view_aggregation`. @@ -72,12 +73,12 @@ def _method_with_args_from_call( Converts a IQL FunctionCall node to a method object and its arguments. Args: - func: IQL FunctionCall node + func: IQL FunctionCall node. method_decorator: The decorator that the method should have - (currently allows discrimination between filters and aggregations) + (currently allows discrimination between filters and aggregations). Returns: - Tuple with the method object and its arguments + Tuple with the method object and its arguments. """ decorator_name = method_decorator.__name__ @@ -114,10 +115,10 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: Converts a IQL FunctonCall filter to a method call. If the method is a coroutine, it will be awaited. Args: - func: IQL FunctionCall node + func: IQL FunctionCall node. Returns: - The result of the method call + The result of the method call. 
""" method, args = self._method_with_args_from_call(func, decorators.view_filter) return await self._call_method(method, args) @@ -127,10 +128,10 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. Args: - func: IQL FunctionCall node + func: IQL FunctionCall node. Returns: - The result of the method call + The result of the method call. """ method, args = self._method_with_args_from_call(func, decorators.view_aggregation) return await self._call_method(method, args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index e4da84c4..7be901db 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -128,7 +128,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult( results=results.to_dict(orient="records"), - context={ + metadata={ "filter_mask": self._filter_mask, "groupbys": self._aggregation_group.groupbys, "aggregations": self._aggregation_group.aggregations, diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 3a7c7981..8a887194 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -104,5 +104,5 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult( results=results, - context={"sql": sql}, + metadata={"sql": sql}, ) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 2e5cff85..2e477216 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,6 +4,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM @@ -34,6 +35,7 @@ async def ask( self, query: str, llm: LLM, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, @@ -46,6 +48,7 @@ async def ask( Args: query: The natural language query to execute. llm: The LLM used to execute the query. + contexts: The context data to be used in the query. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. @@ -57,15 +60,17 @@ async def ask( Raises: ViewExecutionError: When an error occurs while executing the view. 
""" + contexts = contexts or [] filters = self.list_filters() - examples = self.list_few_shots() aggregations = self.list_aggregations() + examples = self.list_few_shots() iql_generator = self.get_iql_generator() iql = await iql_generator( question=query, filters=filters, aggregations=aggregations, + contexts=contexts, examples=examples, llm=llm, event_tracker=event_tracker, @@ -73,7 +78,7 @@ async def ask( n_retries=n_retries, ) - if iql.failed: + if isinstance(iql.filters, BaseException) or isinstance(iql.aggregation, BaseException): raise ViewExecutionError( view_name=self.__class__.__name__, iql=iql, @@ -86,7 +91,7 @@ async def ask( await self.apply_aggregation(iql.aggregation) result = self.execute(dry_run=dry_run) - result.context["iql"] = { + result.metadata["iql"] = { "filters": str(iql.filters) if iql.filters else None, "aggregation": str(iql.aggregation) if iql.aggregation else None, } @@ -149,8 +154,11 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc """ indexes = defaultdict(list) filters = self.list_filters() - for filter_ in filters: - for param in filter_.parameters: + aggregations = self.list_aggregations() + + for method in filters + aggregations: + for param in method.parameters: if param.similarity_index: - indexes[param.similarity_index].append((self.__class__.__name__, filter_.name, param.name)) + indexes[param.similarity_index].append((self.__class__.__name__, method.name, param.name)) + return indexes diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index bed83d0a..c74e5eda 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -1,11 +1,16 @@ import re -from typing import List +from dataclasses import dataclass +from typing import List, Union import pytest +from typing_extensions import Annotated +from dbally.context import Context from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, + IQLContextNotAllowedError, + IQLContextNotFoundError, IQLFunctionNotExists, IQLIncorrectNumberArgumentsError, IQLMultipleStatementsError, @@ -43,13 +48,124 @@ async def test_iql_filter_parser(): name_filter, city_filter, company_filter = and_op.children assert isinstance(name_filter, syntax.FunctionCall) - assert name_filter.arguments[0] == ["John", "Anne"] + assert name_filter.arguments == [["John", "Anne"]] assert isinstance(city_filter, syntax.FunctionCall) - assert city_filter.arguments[0] == "cracow" + assert city_filter.arguments == ["cracow"] assert isinstance(company_filter, syntax.FunctionCall) - assert company_filter.arguments[0] == "deepsense.ai" + assert company_filter.arguments == ["deepsense.ai"] + + +async def test_iql_filter_context_parser(): + @dataclass + class TestCustomContext(Context): + city: str + + test_context = TestCustomContext(city="cracow") + parsed = await IQLFiltersQuery.parse( + "not (filter_by_city('Bydgoszcz') and filter_by_city(CONTEXT) and filter_by_company('deepsense.ai'))", + allowed_functions=[ + ExposedFunction( + name="filter_by_name", description="", parameters=[MethodParamWithTyping(name="name", type=List[str])] + ), + ExposedFunction( + name="filter_by_city", + description="", + parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])], + ), + ExposedFunction( + name="filter_by_company", + description="", + parameters=[ + MethodParamWithTyping( + name="company", type=Annotated[Union[str, TestCustomContext], lambda x: x, 
"context"] + ) + ], + ), + ], + allowed_contexts=[ + test_context, + ], + ) + + not_op = parsed.root + assert isinstance(not_op, syntax.Not) + + and_op = not_op.child + assert isinstance(and_op, syntax.And) + + name_filter, city_filter, company_filter = and_op.children + + assert isinstance(name_filter, syntax.FunctionCall) + assert name_filter.arguments == ["Bydgoszcz"] + + assert isinstance(city_filter, syntax.FunctionCall) + assert city_filter.arguments == [test_context] + + assert isinstance(company_filter, syntax.FunctionCall) + assert company_filter.arguments == ["deepsense.ai"] + + +async def test_iql_filter_context_not_allowed_error(): + @dataclass + class TestCustomContext(Context): + city: str + + with pytest.raises(IQLContextNotAllowedError) as exc_info: + await IQLFiltersQuery.parse( + "not (filter_by_city('Bydgoszcz') and filter_by_city(CONTEXT) and filter_by_company('deepsense.ai'))", + allowed_functions=[ + ExposedFunction( + name="filter_by_name", + description="", + parameters=[MethodParamWithTyping(name="name", type=List[str])], + ), + ExposedFunction( + name="filter_by_city", description="", parameters=[MethodParamWithTyping(name="city", type=str)] + ), + ExposedFunction( + name="filter_by_company", + description="", + parameters=[MethodParamWithTyping(name="company", type=str)], + ), + ], + allowed_contexts=[ + TestCustomContext(city="cracow"), + ], + ) + + assert exc_info.match(re.escape("The context keyword is not allowed here")) + + +async def test_iql_filter_context_not_found_error(): + @dataclass + class TestCustomContext(Context): + city: str + + with pytest.raises(IQLContextNotFoundError) as exc_info: + await IQLFiltersQuery.parse( + "not (filter_by_city('Bydgoszcz') and filter_by_city(CONTEXT) and filter_by_company('deepsense.ai'))", + allowed_functions=[ + ExposedFunction( + name="filter_by_name", + description="", + parameters=[MethodParamWithTyping(name="name", type=List[str])], + ), + ExposedFunction( + name="filter_by_city", + description="", + parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])], + ), + ExposedFunction( + name="filter_by_company", + description="", + parameters=[MethodParamWithTyping(name="company", type=str)], + ), + ], + ) + + assert exc_info.match(re.escape("The requested context is not found: CONTEXT")) async def test_iql_filter_parser_arg_error(): @@ -242,6 +358,80 @@ async def test_iql_aggregation_parser(): assert parsed.root.arguments == ["Paris"] +async def test_iql_aggregation_context_parser(): + @dataclass + class TestCustomContext(Context): + city: str + + test_context = TestCustomContext(city="cracow") + parsed = await IQLAggregationQuery.parse( + "mean_age_by_city(CONTEXT)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=Union[str, TestCustomContext]), + ], + ), + ], + allowed_contexts=[ + test_context, + ], + ) + + assert isinstance(parsed.root, syntax.FunctionCall) + assert parsed.root.name == "mean_age_by_city" + assert parsed.root.arguments == [test_context] + + +async def test_iql_aggregation_context_not_allowed_error(): + @dataclass + class TestCustomContext(Context): + city: str + + with pytest.raises(IQLContextNotAllowedError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(CONTEXT)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + allowed_contexts=[ + 
TestCustomContext(city="cracow"), + ], + ) + + assert exc_info.match(re.escape("The context keyword is not allowed here")) + + +async def test_iql_aggregation_context_not_found_error(): + @dataclass + class TestCustomContext(Context): + city: str + + with pytest.raises(IQLContextNotFoundError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(CONTEXT)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=Union[str, TestCustomContext]), + ], + ), + ], + ) + + assert exc_info.match(re.escape("The requested context is not found: CONTEXT")) + + async def test_iql_aggregation_parser_arg_error(): with pytest.raises(IQLArgumentParsingError) as exc_info: await IQLAggregationQuery.parse( diff --git a/tests/unit/iql/test_type_validators.py b/tests/unit/iql/test_type_validators.py index 88af46e7..4356a518 100644 --- a/tests/unit/iql/test_type_validators.py +++ b/tests/unit/iql/test_type_validators.py @@ -1,4 +1,6 @@ -from typing import List, Literal +from typing import List, Literal, Union + +from typing_extensions import Annotated from dbally.iql._type_validators import validate_arg_type @@ -20,6 +22,38 @@ def test_list_validator(): assert result.valid is False +def test_annotated_validator(): + result = validate_arg_type(Annotated[str, "This is some value"], "smth") + assert result.valid is True + assert result.casted_value == ... + assert result.reason is None + + result = validate_arg_type(Annotated[str, "This is some value"], 5) + assert result.valid is False + + +def test_union_validator(): + result = validate_arg_type(Union[str, int], "smth") + assert result.valid is True + assert result.casted_value == ... + assert result.reason is None + + result = validate_arg_type(Union[str, int], 5) + assert result.valid is True + assert result.casted_value == ... + assert result.reason is None + + result = validate_arg_type(Union[str, int], 5.0) + assert result.valid is True + assert result.casted_value == ... + assert result.reason is None + + result = validate_arg_type(Union[str, int], [1, 2, 3]) + assert result.valid is False + assert result.casted_value == ... + assert result.reason == "[1, 2, 3] is not of type , " + + def test_simple_types(): assert validate_arg_type(int, 5).valid is True assert validate_arg_type(int, "smth").valid is False diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 69174389..1c125f32 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -36,7 +36,7 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) class MockIQLGenerator(IQLGenerator): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index ab4b6c7e..00aa016b 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -27,7 +27,7 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) class BarView(MethodsBaseView): @@ -49,4 +49,4 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... 
def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 1d675d84..089b7edd 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -55,7 +55,7 @@ class MockViewWithResults(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] @@ -88,7 +88,7 @@ class MockViewWithSimilarity(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ @@ -115,7 +115,7 @@ class MockViewWithSimilarity2(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ @@ -300,7 +300,7 @@ async def test_ask_view_selection_single_view() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} + assert result.metadata == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_multiple_views() -> None: @@ -321,7 +321,7 @@ async def test_ask_view_selection_multiple_views() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} + assert result.metadata == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_no_views() -> None: diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py index 137581b6..4a4dbb86 100644 --- a/tests/unit/test_fallback_collection.py +++ b/tests/unit/test_fallback_collection.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional from unittest.mock import AsyncMock, Mock import pytest @@ -8,6 +8,7 @@ from dbally.audit import CLIEventHandler, EventTracker, OtelEventHandler from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection, ViewExecutionResult +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms import LLM from dbally.llms.clients import LLMOptions @@ -37,13 +38,14 @@ async def ask( self, query: str, llm: LLM, - event_tracker: EventTracker, + event_tracker: Optional[EventTracker], n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, + contexts: Optional[Iterable[Context]] = None, ) -> ViewExecutionResult: return ViewExecutionResult( - results=[{"mock_result": 
"fallback_result"}], context={"mock_context": "fallback_context"} + results=[{"mock_result": "fallback_result"}], metadata={"mock_context": "fallback_context"} ) @@ -53,7 +55,7 @@ class MockView1(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def get_iql_generator(self, *_, **__) -> MockIQLGenerator: raise UnsupportedQueryError @@ -107,7 +109,7 @@ async def test_fallback_collection(base_collection: Collection, fallback_collect base_collection.set_fallback(fallback_collection) result = await base_collection.ask("Mock fallback question") assert result.results == [{"mock_result": "fallback_result"}] - assert result.context == {"mock_context": "fallback_context"} + assert result.metadata == {"mock_context": "fallback_context"} def test_get_all_event_handlers_no_fallback(): diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 3a21a1fe..1f7728c8 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -6,6 +6,7 @@ async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", methods=[], + contexts=[], examples=[], ) formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) @@ -36,6 +37,7 @@ async def test_iql_prompt_format_few_shots_injected() -> None: prompt_format = IQLGenerationPromptFormat( question="", methods=[], + contexts=[], examples=examples, ) formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) @@ -68,6 +70,7 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() prompt_format = IQLGenerationPromptFormat( question="", methods=[], + contexts=[], examples=examples, ) formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 0defc8e1..a4ff0f52 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -7,6 +7,7 @@ from dbally import decorators from dbally.audit.event_tracker import EventTracker +from dbally.collection.results import ViewExecutionResult from dbally.iql import IQLAggregationQuery, IQLError, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.views.methods_base import MethodsBaseView @@ -20,7 +21,7 @@ async def apply_filters(self, filters: IQLFiltersQuery) -> None: async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... - def execute(self, dry_run: bool = False): + def execute(self, dry_run: bool = False) -> ViewExecutionResult: ... 
@decorators.view_filter() @@ -85,6 +86,7 @@ async def test_iql_generation( filters=filters, aggregations=aggregations, examples=examples, + contexts=[], llm=llm, event_tracker=event_tracker, ) @@ -96,11 +98,13 @@ async def test_iql_generation( mock_filters_parse.assert_called_once_with( source=llm_responses[1], allowed_functions=filters, + allowed_contexts=[], event_tracker=event_tracker, ) mock_aggregation_parse.assert_called_once_with( source=llm_responses[3], allowed_functions=aggregations, + allowed_contexts=[], event_tracker=event_tracker, ) @@ -150,6 +154,7 @@ async def test_iql_generation_error_escalation_after_max_retires( filters=filters, aggregations=aggregations, examples=examples, + contexts=[], llm=llm, event_tracker=event_tracker, n_retries=3, @@ -211,6 +216,7 @@ async def test_iql_generation_response_after_max_retries( filters=filters, aggregations=aggregations, examples=examples, + contexts=[], llm=llm, event_tracker=event_tracker, n_retries=3, diff --git a/tests/unit/test_nl_responder.py b/tests/unit/test_nl_responder.py index e23fe3d1..3f5815a0 100644 --- a/tests/unit/test_nl_responder.py +++ b/tests/unit/test_nl_responder.py @@ -22,7 +22,7 @@ def event_tracker() -> EventTracker: @pytest.fixture def answer() -> ViewExecutionResult: - return ViewExecutionResult(results=[{"id": 1, "name": "Mock name"}], context={"sql": "Mock SQL"}) + return ViewExecutionResult(results=[{"id": 1, "name": "Mock name"}], metadata={"sql": "Mock SQL"}) @pytest.mark.asyncio diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 57c0b68a..e870e2dc 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -1,15 +1,25 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name - -from typing import List, Literal, Tuple +from dataclasses import dataclass +from typing import List, Literal, Tuple, Union from dbally.collection.results import ViewExecutionResult +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from dbally.views.methods_base import MethodsBaseView +@dataclass +class CallerContext(Context): + """ + Mock class for testing context. + """ + + current_year: Literal["2023", "2024"] + + class MockMethodsBase(MethodsBaseView): """ Mock class for testing the MethodsBaseView @@ -22,7 +32,9 @@ def method_foo(self, idx: int) -> None: """ @view_filter() - def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: List[Tuple[str, int]]) -> str: + def method_bar( + self, cities: List[str], year: Union[Literal["2023", "2024"], CallerContext], pairs: List[Tuple[str, int]] + ) -> str: return f"hello {cities} in {year} of {pairs}" @view_aggregation() @@ -32,7 +44,9 @@ def method_baz(self) -> None: """ @view_aggregation() - def method_qux(self, ages: List[int], names: List[str]) -> str: + def method_qux( + self, ages: List[int], years: Union[Literal["2023", "2024"], CallerContext], names: List[str] + ) -> str: return f"hello {ages} and {names}" async def apply_filters(self, filters: IQLFiltersQuery) -> None: @@ -42,7 +56,7 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... 
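# Reviewer note (not part of the patch): the `allowed_contexts=[]` argument asserted
# above is how the IQL parser resolves the CONTEXT keyword. A sketch of the intended
# behaviour, distilled from test_iql_filter_context_parser earlier in this diff;
# import paths are assumed to match the existing parser tests:
import asyncio
from dataclasses import dataclass
from typing import Union

from dbally.context import Context
from dbally.iql import IQLFiltersQuery
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping


@dataclass
class TestCustomContext(Context):
    city: str


async def demo() -> None:
    ctx = TestCustomContext(city="cracow")
    parsed = await IQLFiltersQuery.parse(
        "filter_by_city(CONTEXT)",
        allowed_functions=[
            ExposedFunction(
                name="filter_by_city",
                description="",
                parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])],
            ),
        ],
        allowed_contexts=[ctx],
    )
    # CONTEXT is expected to resolve to the instance whose type matches the
    # parameter annotation, as in the parser tests above.
    assert parsed.root.arguments == [ctx]


asyncio.run(demo())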
def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) def test_list_filters() -> None: @@ -60,11 +74,12 @@ def test_list_filters() -> None: assert method_bar.description == "" assert method_bar.parameters == [ MethodParamWithTyping("cities", List[str]), - MethodParamWithTyping("year", Literal["2023", "2024"]), + MethodParamWithTyping("year", Union[Literal["2023", "2024"], CallerContext]), MethodParamWithTyping("pairs", List[Tuple[str, int]]), ] assert ( - str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'], pairs: List[Tuple[str, int]])" + str(method_bar) + == "method_bar(cities: List[str], year: Literal['2023', '2024'] | Context, pairs: List[Tuple[str, int]])" ) @@ -83,6 +98,7 @@ def test_list_aggregations() -> None: assert method_qux.description == "" assert method_qux.parameters == [ MethodParamWithTyping("ages", List[int]), + MethodParamWithTyping("years", Union[Literal["2023", "2024"], CallerContext]), MethodParamWithTyping("names", List[str]), ] - assert str(method_qux) == "method_qux(ages: List[int], names: List[str])" + assert str(method_qux) == "method_qux(ages: List[int], years: Literal['2023', '2024'] | Context, names: List[str])" diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 029fe30f..b9af69f2 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -85,9 +85,9 @@ async def test_filter_or() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON - assert result.context["filter_mask"].tolist() == [True, False, True, False, True] - assert result.context["groupbys"] is None - assert result.context["aggregations"] is None + assert result.metadata["filter_mask"].tolist() == [True, False, True, False, True] + assert result.metadata["groupbys"] is None + assert result.metadata["aggregations"] is None async def test_filter_and() -> None: @@ -102,9 +102,9 @@ async def test_filter_and() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 - assert result.context["filter_mask"].tolist() == [False, True, False, False, False] - assert result.context["groupbys"] is None - assert result.context["aggregations"] is None + assert result.metadata["filter_mask"].tolist() == [False, True, False, False, False] + assert result.metadata["groupbys"] is None + assert result.metadata["aggregations"] is None async def test_filter_not() -> None: @@ -119,9 +119,9 @@ async def test_filter_not() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 - assert result.context["filter_mask"].tolist() == [True, False, True, True, True] - assert result.context["groupbys"] is None - assert result.context["aggregations"] is None + assert result.metadata["filter_mask"].tolist() == [True, False, True, True, True] + assert result.metadata["groupbys"] is None + assert result.metadata["aggregations"] is None async def test_aggregation() -> None: @@ -138,9 +138,9 @@ async def test_aggregation() -> None: assert result.results == [ {"index": "name_count", "name": 5}, ] - assert result.context["filter_mask"] is None - assert result.context["groupbys"] is None - assert result.context["aggregations"] == [Aggregation(column="name", function="count")] + assert result.metadata["filter_mask"] is 
None + assert result.metadata["groupbys"] is None + assert result.metadata["aggregations"] == [Aggregation(column="name", function="count")] async def test_aggregtion_with_groupby() -> None: @@ -159,9 +159,9 @@ async def test_aggregtion_with_groupby() -> None: {"city": "London", "age_mean": 32.5}, {"city": "Paris", "age_mean": 32.5}, ] - assert result.context["filter_mask"] is None - assert result.context["groupbys"] == "city" - assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] + assert result.metadata["filter_mask"] is None + assert result.metadata["groupbys"] == "city" + assert result.metadata["aggregations"] == [Aggregation(column="age", function="mean")] async def test_filters_and_aggregtion() -> None: @@ -181,6 +181,6 @@ async def test_filters_and_aggregtion() -> None: await mock_view.apply_aggregation(query) result = mock_view.execute() assert result.results == [{"city": "Paris", "age_mean": 32.5}] - assert result.context["filter_mask"].tolist() == [False, True, False, True, False] - assert result.context["groupbys"] == "city" - assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] + assert result.metadata["filter_mask"].tolist() == [False, True, False, True, False] + assert result.metadata["groupbys"] == "city" + assert result.metadata["aggregations"] == [Aggregation(column="age", function="mean")] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 571e6a70..3b621a54 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -1,15 +1,22 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name import re +from dataclasses import dataclass +from typing import Union import sqlalchemy -from dbally.iql import IQLFiltersQuery -from dbally.iql._query import IQLAggregationQuery +from dbally.context import Context +from dbally.iql import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView +@dataclass +class SomeTestContext(Context): + age: int + + class MockSqlAlchemyView(SqlAlchemyBaseView): """ Mock class for testing the SqlAlchemyBaseView @@ -29,8 +36,14 @@ def method_foo(self, idx: int) -> sqlalchemy.ColumnElement: async def method_bar(self, city: str, year: int) -> sqlalchemy.ColumnElement: return sqlalchemy.literal(f"hello {city} in {year}") + @view_filter() + async def method_baz(self, age: Union[int, SomeTestContext]) -> sqlalchemy.ColumnElement: + if isinstance(age, SomeTestContext): + return sqlalchemy.literal(age.age) + return sqlalchemy.literal(age) + @view_aggregation() - def method_baz(self) -> sqlalchemy.Select: + def method_agg(self) -> sqlalchemy.Select: """ Some documentation string """ @@ -56,7 +69,7 @@ async def test_filter_sql_generation() -> None: allowed_functions=mock_view.list_filters(), ) await mock_view.apply_filters(query) - sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"]) assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020'" @@ -68,11 +81,11 @@ async def test_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) query = await IQLAggregationQuery.parse( - "method_baz()", + "method_agg()", 
allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) - sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"]) assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 GROUP BY 'baz'" @@ -89,9 +102,9 @@ async def test_filter_and_aggregation_sql_generation() -> None: ) await mock_view.apply_filters(query) query = await IQLAggregationQuery.parse( - "method_baz()", + "method_agg()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) - sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"]) assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 WHERE 1 AND 'hello London in 2020' GROUP BY 'baz'" diff --git a/tests/unit/views/text2sql/test_view.py b/tests/unit/views/text2sql/test_view.py index 91b7b50f..f5c51247 100644 --- a/tests/unit/views/text2sql/test_view.py +++ b/tests/unit/views/text2sql/test_view.py @@ -61,7 +61,7 @@ async def test_text2sql_view(sample_db: Engine): response = await collection.ask("Show me customers from New York") - assert response.context["sql"] == llm_response["sql"] + assert response.metadata["sql"] == llm_response["sql"] assert response.results == [ {"id": 1, "name": "Alice", "city": "New York"}, {"id": 3, "name": "Charlie", "city": "New York"},