diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 481052f7..1f35de15 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,7 +1,9 @@ import re +import typing_extensions as type_ext from dataclasses import dataclass from typing import _GenericAlias # type: ignore from typing import Optional, Sequence, Type, Union +from inspect import isclass from dbally.context.context import BaseCallerContext from dbally.similarity import AbstractSimilarityIndex @@ -12,15 +14,22 @@ def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: Parses the type of a method parameter and returns a string representation of it. Args: - param_type: type of the parameter + param_type: Type of the parameter. Returns: - str: string representation of the type + A string representation of the type. """ - if param_type in {int, float, str, bool, list, dict, set, tuple}: + + # TODO consider using hasattr() to ensure correctness of the IF's below + if isclass(param_type): return param_type.__name__ + + if type_ext.get_origin(param_type) is Union: + args_str_repr = ', '.join(parse_param_type(arg) for arg in type_ext.get_args(param_type)) + return f"Union[{args_str_repr}]" + if param_type.__module__ == "typing": - return re.sub(r"\btyping\.", "", str(param_type)) + return param_type._name return str(param_type)