diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 45446cc5..6e10aafa 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -3,20 +3,20 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, List, Optional, Type, TypeVar +from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation -from dbally.context.context import BaseCallerContext, CustomContextsList class Collection: @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -177,6 +177,8 @@ async def ask( the natural response will be included in the answer llm_options: options to use for the LLM client. If provided, these options will be merged with the default options provided to the LLM client, prioritizing option values other than NOT_GIVEN + contexts: An iterable (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. Returns: ExecutionResult object representing the result of the query execution. @@ -217,7 +219,7 @@ async def ask( n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, - contexts=contexts + contexts=contexts, ) end_time_view = time.monotonic() diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 7fea4e9b..113b0ca2 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -1,15 +1,17 @@ -import typing_extensions as type_ext - -from typing import Sequence, Tuple, Optional, Type, Any, Union from inspect import isclass +from typing import Any, Optional, Sequence, Tuple, Type, Union + +import typing_extensions as type_ext from dbally.context.context import BaseCallerContext from dbally.views.exposed_functions import MethodParamWithTyping +ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]] + def _extract_params_and_context( filter_method_: type_ext.Callable, hidden_args: Sequence[str] -) -> Tuple[Sequence[MethodParamWithTyping], Optional[Type[BaseCallerContext]]]: +) -> Tuple[Sequence[MethodParamWithTyping], ContextClass]: """ Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext @@ -17,9 +19,10 @@ class is returned. Args: filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) + hidden_args: method arguments that should not be extracted Returns: - A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. + The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. """ @@ -52,6 +55,16 @@ class is returned. def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: + """ + Verifies whether a method argument allows contextualization based on the type hints attached to a method signature. + + Args: + arg: MethodParamWithTyping container preserving information about the method argument + + Returns: + Verification result. + """ + if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext): return False diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index d3af935d..baf86e93 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,30 +1,59 @@ import ast +from typing import Iterable -from typing import Sequence, TypeVar -from typing_extensions import Self from pydantic import BaseModel +from typing_extensions import Self, TypeAlias from dbally.context.exceptions import ContextNotAvailableError - -CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) -CustomContextsList = Sequence[CustomContext] # TODO confirm the naming +# CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) +CustomContext: TypeAlias = "BaseCallerContext" class BaseCallerContext(BaseModel): """ - Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. - LLM will always return `BaseCallerContext()` when the context is required and this call will be later substitue by a proper subclass instance selected based on the filter method signature (type hints). + Pydantic-based record class. Base class for contexts that are used to pass additional knowledge about + the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. + LLM will always return `BaseCallerContext()` when the context is required and this call will be + later substituted by a proper subclass instance selected based on the filter method signature (type hints). """ @classmethod - def select_context(cls, contexts: CustomContextsList) -> Self: + def select_context(cls, contexts: Iterable[CustomContext]) -> Self: + """ + Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being + an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context + class by its right instance. + + Args: + contexts: A sequence of objects, each being an instance of a different BaseCallerContext subclass. + + Returns: + An instance of the same BaseCallerContext subclass this method is caller from. + + Raises: + ContextNotAvailableError: If the sequence of context objects passed as argument is empty. + """ + if not contexts: - raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") + raise ContextNotAvailableError( + "The LLM detected that the context is required to execute the query +\ + and the filter signature allows contextualization while the context was not provided." + ) - # this method is called from the subclass of BaseCallerContext pointing the right type of custom context - return next(filter(lambda obj: isinstance(obj, cls), contexts)) + # TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore` + return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore @classmethod def is_context_call(cls, node: ast.expr) -> bool: + """ + Verifies whether an AST node indicates context substitution. + + Args: + node: An AST node (expression) to verify: + + Returns: + Verification result. + """ + return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == cls.__name__ diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py index 91482d5d..0efa1473 100644 --- a/src/dbally/context/exceptions.py +++ b/src/dbally/context/exceptions.py @@ -3,17 +3,22 @@ class BaseContextException(Exception, ABC): """ - A base exception for all specification context-related exception. + A base (abstract) exception for all specification context-related exception. """ - pass class ContextNotAvailableError(Exception): - pass + """ + An exception inheriting from BaseContextException pointining that no sufficient context information + was provided by the user while calling view.ask(). + """ class ContextualisationNotAllowed(Exception): - pass + """ + An exception inheriting from BaseContextException pointining that the filter method signature + does not allow to provide an additional context. + """ # WORKAROUND - traditional inhertiance syntax is not working in context of abstract Exceptions diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index b6a1648a..fb9c57be 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,8 +1,10 @@ import ast - -from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Type +from typing import Any, Iterable, List, Mapping, Optional, Union from dbally.audit.event_tracker import EventTracker +from dbally.context._utils import _does_arg_allow_context +from dbally.context.context import BaseCallerContext, CustomContext +from dbally.context.exceptions import ContextualisationNotAllowed from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -12,35 +14,48 @@ IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type -from dbally.context.context import BaseCallerContext, CustomContextsList -from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed -from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context -from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction +from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping class IQLProcessor: """ Parses IQL string to tree structure. + + Attributes: + source: Raw LLM response containing IQL filter calls. + allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View. + contexts: A sequence (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. """ + source: str allowed_functions: Mapping[str, "ExposedFunction"] - contexts: CustomContextsList + contexts: Iterable[CustomContext] _event_tracker: EventTracker - def __init__( self, source: str, - allowed_functions: List["ExposedFunction"], - contexts: Optional[CustomContextsList] = None, - event_tracker: Optional[EventTracker] = None + allowed_functions: Iterable[ExposedFunction], + contexts: Optional[Iterable[CustomContext]] = None, + event_tracker: Optional[EventTracker] = None, ) -> None: + """ + IQLProcessor class constructor. + + Args: + source: Raw LLM response containing IQL filter calls. + allowed_functions: An interable (typically a list) of all filters implemented for a certain View. + contexts: An iterable (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. + even_tracker: An EvenTracker instance. + """ + self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} self.contexts = contexts or [] self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: """ Process IQL string to root IQL.Node. @@ -89,7 +104,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: if not isinstance(func, ast.Name): raise IQLUnsupportedSyntaxError(node, self.source, context="FunctionCall") - if func.id not in self.allowed_functions: # TODO add context class constructors to self.allowed_functions + if func.id not in self.allowed_functions: raise IQLFunctionNotExists(func, self.source) func_def = self.allowed_functions[func.id] @@ -117,9 +132,8 @@ def _parse_arg( self, arg: ast.expr, arg_spec: Optional[MethodParamWithTyping] = None, - parent_func_def: Optional[ExposedFunction] = None + parent_func_def: Optional[ExposedFunction] = None, ) -> Any: - if isinstance(arg, ast.List): return [self._parse_arg(x) for x in arg.elts] @@ -129,10 +143,16 @@ def _parse_arg( raise IQLArgumentParsingError(arg, self.source) if parent_func_def.context_class is None: - raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.") + raise ContextualisationNotAllowed( + "The LLM detected that the context is required +\ + to execute the query while the filter signature does not allow it at all." + ) if not _does_arg_allow_context(arg_spec): - raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.") + raise ContextualisationNotAllowed( + f"The LLM detected that the context is required +\ + to execute the query while the filter signature does allow it for `{arg_spec.name}` argument." + ) return parent_func_def.context_class.select_context(self.contexts) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 6a610c24..cc090ad6 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, List, Optional, Type +from typing import TYPE_CHECKING, Iterable, List, Optional + from typing_extensions import Self +from dbally.context.context import CustomContext + from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor -from dbally.context.context import BaseCallerContext, CustomContextsList if TYPE_CHECKING: from dbally.views.structured import ExposedFunction @@ -30,7 +32,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> Self: """ Parse IQL string to IQLQuery object. @@ -39,6 +41,8 @@ async def parse( source: IQL string that needs to be parsed allowed_functions: list of IQL functions that are allowed for this query event_tracker: EventTracker object to track events + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: IQLQuery object """ diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 848b17ac..7b993ef5 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,9 +1,9 @@ -import typing_extensions as type_ext - from dataclasses import dataclass from typing import _GenericAlias # type: ignore from typing import Any, Callable, Dict, Literal, Optional, Type, Union +import typing_extensions as type_ext + @dataclass class _ValidationResult: @@ -67,8 +67,10 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> Returns: _ValidationResult instance """ - actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias - if actual_type is None: # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary + actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type + # typing.Union is an instance of _GenericAlias + if actual_type is None: + # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary actual_type = required_type.__origin__ if actual_type is Union: @@ -77,7 +79,8 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> if res.valid: return _ValidationResult(True) - return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") # typing.Union does not have __name__ property + # typing.Union does not have __name__ property, thus using repr() is necessary + return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") custom_type_checker = TYPE_VALIDATOR.get(actual_type) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 1946c258..8018f6e1 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,6 +1,7 @@ -from typing import List, Optional +from typing import Iterable, List, Optional from dbally.audit.event_tracker import EventTracker +from dbally.context.context import CustomContext from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM @@ -8,7 +9,6 @@ from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -from dbally.context.context import CustomContextsList ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" @@ -43,8 +43,8 @@ async def generate_iql( examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - contexts: Optional[CustomContextsList] = None - ) -> Optional[IQLQuery]: + contexts: Optional[Iterable[CustomContext]] = None, + ) -> IQLQuery: """ Generates IQL in text form using LLM. @@ -55,6 +55,8 @@ async def generate_iql( examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: Generated IQL query. @@ -77,12 +79,12 @@ async def generate_iql( iql = formatted_prompt.response_parser(response) # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( - source=iql, - allowed_functions=filters, - event_tracker=event_tracker, - contexts=contexts + source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts ) except IQLError as exc: - # TODO handle the possibility of variable `response` being not initialized while runnning the following line + # TODO handle the possibility of variable `response` being not initialized + # while runnning the following line formatted_prompt = formatted_prompt.add_assistant_message(response) formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + + # TODO handle the situation when all retries fails and the return defaults to None diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 43b69dbd..e2292b56 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -1,15 +1,17 @@ import abc -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, Iterable, List, Optional, Tuple + +from typing_extensions import TypeAlias from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex -from dbally.context.context import BaseCallerContext, CustomContextsList -IndexLocation = Tuple[str, str, str] +IndexLocation: TypeAlias = Tuple[str, str, str] class BaseView(metaclass=abc.ABCMeta): @@ -27,7 +29,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. @@ -39,6 +41,8 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: The result of the query. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index c6d400d2..481052f7 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,10 +1,10 @@ import re from dataclasses import dataclass from typing import _GenericAlias # type: ignore -from typing import Sequence, Optional, Union, Type +from typing import Optional, Sequence, Type, Union -from dbally.similarity import AbstractSimilarityIndex from dbally.context.context import BaseCallerContext +from dbally.similarity import AbstractSimilarityIndex def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 7f24f00e..4fbb4bef 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,6 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -103,6 +104,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the SQL query from the natural language query and @@ -115,6 +117,7 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: Currently not used. Returns: The result of the query. diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 25baa957..a2a2bd9e 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -3,11 +3,11 @@ import textwrap from typing import Any, Callable, List, Tuple +from dbally.context._utils import _extract_params_and_context from dbally.iql import syntax from dbally.views import decorators -from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping +from dbally.views.exposed_functions import ExposedFunction from dbally.views.structured import BaseStructuredView -from dbally.context._utils import _extract_params_and_context class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): @@ -43,7 +43,7 @@ def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction] name=method_name, description=textwrap.dedent(method.__doc__).strip() if method.__doc__ else "", parameters=params, - context_class=context_class + context_class=context_class, ) ) return methods diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index eda1a48c..99f6955a 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,16 +1,15 @@ import abc from collections import defaultdict -from typing import Dict, List, Optional, Type +from typing import Dict, Iterable, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context.context import CustomContext from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction -from dbally.context.context import BaseCallerContext, CustomContextsList from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -42,7 +41,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -55,6 +54,8 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: The result of the query. @@ -71,7 +72,7 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, - contexts=contexts + contexts=contexts, ) await self.apply_filters(iql)