Skip to content

Commit

Permalink
added missing docstrings; fixed type hints; fixed issues detected by …
Browse files Browse the repository at this point in the history
…pylint; run pre-commit auto refactor
  • Loading branch information
ds-jakub-cierocki committed Jul 4, 2024
1 parent 9ba89e5 commit 5fd802f
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 72 deletions.
10 changes: 6 additions & 4 deletions src/dbally/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import textwrap
import time
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.audit.events import RequestEnd, RequestStart
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult
from dbally.context.context import CustomContext
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
from dbally.similarity.index import AbstractSimilarityIndex
from dbally.view_selection.base import ViewSelector
from dbally.views.base import BaseView, IndexLocation
from dbally.context.context import BaseCallerContext, CustomContextsList


class Collection:
Expand Down Expand Up @@ -157,7 +157,7 @@ async def ask(
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[CustomContextsList] = None
contexts: Optional[Iterable[CustomContext]] = None,
) -> ExecutionResult:
"""
Ask question in a text form and retrieve the answer based on the available views.
Expand All @@ -177,6 +177,8 @@ async def ask(
the natural response will be included in the answer
llm_options: options to use for the LLM client. If provided, these options will be merged with the default
options provided to the LLM client, prioritizing option values other than NOT_GIVEN
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
Returns:
ExecutionResult object representing the result of the query execution.
Expand Down Expand Up @@ -217,7 +219,7 @@ async def ask(
n_retries=self.n_retries,
dry_run=dry_run,
llm_options=llm_options,
contexts=contexts
contexts=contexts,
)
end_time_view = time.monotonic()

Expand Down
23 changes: 18 additions & 5 deletions src/dbally/context/_utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import typing_extensions as type_ext

from typing import Sequence, Tuple, Optional, Type, Any, Union
from inspect import isclass
from typing import Any, Optional, Sequence, Tuple, Type, Union

import typing_extensions as type_ext

from dbally.context.context import BaseCallerContext
from dbally.views.exposed_functions import MethodParamWithTyping

ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]]


def _extract_params_and_context(
filter_method_: type_ext.Callable, hidden_args: Sequence[str]
) -> Tuple[Sequence[MethodParamWithTyping], Optional[Type[BaseCallerContext]]]:
) -> Tuple[Sequence[MethodParamWithTyping], ContextClass]:
"""
Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format.
Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext
class is returned.
Args:
filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator)
hidden_args: method arguments that should not be extracted
Returns:
A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping.
The first field contains the list of arguments, each encapsulated as MethodParamWithTyping.
The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified.
"""

Expand Down Expand Up @@ -52,6 +55,16 @@ class is returned.


def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool:
"""
Verifies whether a method argument allows contextualization based on the type hints attached to a method signature.
Args:
arg: MethodParamWithTyping container preserving information about the method argument
Returns:
Verification result.
"""

if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext):
return False

Expand Down
51 changes: 40 additions & 11 deletions src/dbally/context/context.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,59 @@
import ast
from typing import Iterable

from typing import Sequence, TypeVar
from typing_extensions import Self
from pydantic import BaseModel
from typing_extensions import Self, TypeAlias

from dbally.context.exceptions import ContextNotAvailableError


CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True)
CustomContextsList = Sequence[CustomContext] # TODO confirm the naming
# CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True)
CustomContext: TypeAlias = "BaseCallerContext"


class BaseCallerContext(BaseModel):
"""
Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing.
LLM will always return `BaseCallerContext()` when the context is required and this call will be later substitue by a proper subclass instance selected based on the filter method signature (type hints).
Pydantic-based record class. Base class for contexts that are used to pass additional knowledge about
the caller environment to the filters. It is not made abstract for the convinience of IQL parsing.
LLM will always return `BaseCallerContext()` when the context is required and this call will be
later substituted by a proper subclass instance selected based on the filter method signature (type hints).
"""

@classmethod
def select_context(cls, contexts: CustomContextsList) -> Self:
def select_context(cls, contexts: Iterable[CustomContext]) -> Self:
"""
Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being
an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context
class by its right instance.
Args:
contexts: A sequence of objects, each being an instance of a different BaseCallerContext subclass.
Returns:
An instance of the same BaseCallerContext subclass this method is caller from.
Raises:
ContextNotAvailableError: If the sequence of context objects passed as argument is empty.
"""

if not contexts:
raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.")
raise ContextNotAvailableError(
"The LLM detected that the context is required to execute the query +\
and the filter signature allows contextualization while the context was not provided."
)

# this method is called from the subclass of BaseCallerContext pointing the right type of custom context
return next(filter(lambda obj: isinstance(obj, cls), contexts))
# TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore`
return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore

@classmethod
def is_context_call(cls, node: ast.expr) -> bool:
"""
Verifies whether an AST node indicates context substitution.
Args:
node: An AST node (expression) to verify:
Returns:
Verification result.
"""

return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == cls.__name__
13 changes: 9 additions & 4 deletions src/dbally/context/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@

class BaseContextException(Exception, ABC):
"""
A base exception for all specification context-related exception.
A base (abstract) exception for all specification context-related exception.
"""
pass


class ContextNotAvailableError(Exception):
pass
"""
An exception inheriting from BaseContextException pointining that no sufficient context information
was provided by the user while calling view.ask().
"""


class ContextualisationNotAllowed(Exception):
pass
"""
An exception inheriting from BaseContextException pointining that the filter method signature
does not allow to provide an additional context.
"""


# WORKAROUND - traditional inhertiance syntax is not working in context of abstract Exceptions
Expand Down
54 changes: 37 additions & 17 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import ast

from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Type
from typing import Any, Iterable, List, Mapping, Optional, Union

from dbally.audit.event_tracker import EventTracker
from dbally.context._utils import _does_arg_allow_context
from dbally.context.context import BaseCallerContext, CustomContext
from dbally.context.exceptions import ContextualisationNotAllowed
from dbally.iql import syntax
from dbally.iql._exceptions import (
IQLArgumentParsingError,
Expand All @@ -12,35 +14,48 @@
IQLUnsupportedSyntaxError,
)
from dbally.iql._type_validators import validate_arg_type
from dbally.context.context import BaseCallerContext, CustomContextsList
from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed
from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context
from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping


class IQLProcessor:
"""
Parses IQL string to tree structure.
Attributes:
source: Raw LLM response containing IQL filter calls.
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.
contexts: A sequence (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
"""

source: str
allowed_functions: Mapping[str, "ExposedFunction"]
contexts: CustomContextsList
contexts: Iterable[CustomContext]
_event_tracker: EventTracker


def __init__(
self,
source: str,
allowed_functions: List["ExposedFunction"],
contexts: Optional[CustomContextsList] = None,
event_tracker: Optional[EventTracker] = None
allowed_functions: Iterable[ExposedFunction],
contexts: Optional[Iterable[CustomContext]] = None,
event_tracker: Optional[EventTracker] = None,
) -> None:
"""
IQLProcessor class constructor.
Args:
source: Raw LLM response containing IQL filter calls.
allowed_functions: An interable (typically a list) of all filters implemented for a certain View.
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext.
even_tracker: An EvenTracker instance.
"""

self.source = source
self.allowed_functions = {func.name: func for func in allowed_functions}
self.contexts = contexts or []
self._event_tracker = event_tracker or EventTracker()


async def process(self) -> syntax.Node:
"""
Process IQL string to root IQL.Node.
Expand Down Expand Up @@ -89,7 +104,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
if not isinstance(func, ast.Name):
raise IQLUnsupportedSyntaxError(node, self.source, context="FunctionCall")

if func.id not in self.allowed_functions: # TODO add context class constructors to self.allowed_functions
if func.id not in self.allowed_functions:
raise IQLFunctionNotExists(func, self.source)

func_def = self.allowed_functions[func.id]
Expand Down Expand Up @@ -117,9 +132,8 @@ def _parse_arg(
self,
arg: ast.expr,
arg_spec: Optional[MethodParamWithTyping] = None,
parent_func_def: Optional[ExposedFunction] = None
parent_func_def: Optional[ExposedFunction] = None,
) -> Any:

if isinstance(arg, ast.List):
return [self._parse_arg(x) for x in arg.elts]

Expand All @@ -129,10 +143,16 @@ def _parse_arg(
raise IQLArgumentParsingError(arg, self.source)

if parent_func_def.context_class is None:
raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.")
raise ContextualisationNotAllowed(
"The LLM detected that the context is required +\
to execute the query while the filter signature does not allow it at all."
)

if not _does_arg_allow_context(arg_spec):
raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.")
raise ContextualisationNotAllowed(
f"The LLM detected that the context is required +\
to execute the query while the filter signature does allow it for `{arg_spec.name}` argument."
)

return parent_func_def.context_class.select_context(self.contexts)

Expand Down
10 changes: 7 additions & 3 deletions src/dbally/iql/_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TYPE_CHECKING, List, Optional, Type
from typing import TYPE_CHECKING, Iterable, List, Optional

from typing_extensions import Self

from dbally.context.context import CustomContext

from ..audit.event_tracker import EventTracker
from . import syntax
from ._processor import IQLProcessor
from dbally.context.context import BaseCallerContext, CustomContextsList

if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction
Expand All @@ -30,7 +32,7 @@ async def parse(
source: str,
allowed_functions: List["ExposedFunction"],
event_tracker: Optional[EventTracker] = None,
contexts: Optional[CustomContextsList] = None
contexts: Optional[Iterable[CustomContext]] = None,
) -> Self:
"""
Parse IQL string to IQLQuery object.
Expand All @@ -39,6 +41,8 @@ async def parse(
source: IQL string that needs to be parsed
allowed_functions: list of IQL functions that are allowed for this query
event_tracker: EventTracker object to track events
contexts: An iterable (typically a list) of context objects, each being
an instance of a subclass of BaseCallerContext.
Returns:
IQLQuery object
"""
Expand Down
13 changes: 8 additions & 5 deletions src/dbally/iql/_type_validators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import typing_extensions as type_ext

from dataclasses import dataclass
from typing import _GenericAlias # type: ignore
from typing import Any, Callable, Dict, Literal, Optional, Type, Union

import typing_extensions as type_ext


@dataclass
class _ValidationResult:
Expand Down Expand Up @@ -67,8 +67,10 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) ->
Returns:
_ValidationResult instance
"""
actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias
if actual_type is None: # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary
actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type
# typing.Union is an instance of _GenericAlias
if actual_type is None:
# workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary
actual_type = required_type.__origin__

if actual_type is Union:
Expand All @@ -77,7 +79,8 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) ->
if res.valid:
return _ValidationResult(True)

return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") # typing.Union does not have __name__ property
# typing.Union does not have __name__ property, thus using repr() is necessary
return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}")

custom_type_checker = TYPE_VALIDATOR.get(actual_type)

Expand Down
Loading

0 comments on commit 5fd802f

Please sign in to comment.