Skip to content

Commit

Permalink
fix typed dict types
Browse files Browse the repository at this point in the history
  • Loading branch information
TheWii committed Feb 5, 2024
1 parent d97738a commit 4bbf16b
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 295 deletions.
7 changes: 3 additions & 4 deletions bolt_expressions/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NbtType,
NumericNbtValue,
NbtValue,
format_type,
access_type,
convert_type,
get_dict_fields,
Expand All @@ -45,7 +46,7 @@
IrSource,
)
from .exceptions import TypeCheckDiagnostic, TypeCheckError, get_exception_chain
from .utils import format_type, get_globals
from .utils import get_globals


__all__ = [
Expand Down Expand Up @@ -238,9 +239,7 @@ def check_list_type(
raise exc from cause_exc


def check_numeric_type(
write: type[NumericNbtValue], read: Any, **flags: Any
) -> bool:
def check_numeric_type(write: type[NumericNbtValue], read: Any, **flags: Any) -> bool:
if not issubclass(read, Numeric):
raise TypeCheckError(
f'"{format_type(read)}" is not a numeric type and is not compatible with "{format_type(write)}".'
Expand Down
4 changes: 2 additions & 2 deletions bolt_expressions/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from nbtlib import Compound, Path, ListIndex, CompoundMatch, NamedKey # type: ignore

from .node import Expression, UnrollHelper
from .typing import NbtType, is_compound_type
from .utils import format_type, type_name
from .typing import NbtType, is_compound_type, format_type
from .utils import type_name

from .optimizer import (
DataTuple,
Expand Down
76 changes: 63 additions & 13 deletions bolt_expressions/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from nbtlib import End, Compound, Float, Int, List, String, Byte, Short, Long, Double, Array, NamedKey, ListIndex, CompoundMatch # type: ignore

from .utils import get_globals, type_name # type: ignore
from .utils import format_name, get_globals, type_name # type: ignore


__all__ = ["is_type", "convert_type", "literal_types"]
Expand Down Expand Up @@ -64,6 +64,48 @@
}


def format_type(t: Any, *, __refs: list[Any] | None = None) -> str:
if __refs is None:
__refs = []

circular_ref = t in __refs
__refs.append(t)

if t in (None, NoneType):
return "None"

if isinstance(t, (UnionType, _UnionGenericAlias)):
return " | ".join(format_type(x, __refs=__refs) for x in get_args(t))

if isinstance(t, GenericAlias):
origin = format_type(get_origin(t), __refs=__refs)
args = (format_type(x, __refs=__refs) for x in get_args(t))

if circular_ref:
return f"{origin}[...]"

return f"{origin}[{', '.join(args)}]"

if is_typeddict_guard(t) and t.__name__ == "__anonymous_dict__":
t = get_dict_fields(t)
if isinstance(t, dict):
t_dict = cast(dict[str, Any], t)

if circular_ref:
return "{...}"

return (
"{"
+ ", ".join(
f"{key}: {format_type(val, __refs=__refs)}"
for key, val in t_dict.items()
)
+ "}"
)

return format_name(t)


def convert_tag(value: Any) -> NbtValue | None:
match value:
case Byte() | Short() | Int() | Float() | Double() | String() | List() | Array() | Compound():
Expand Down Expand Up @@ -164,20 +206,23 @@ def unwrap_optional_type(value: Any) -> Any:


def get_dict_fields(
t: type[TypedDict] | dict[str, Any], ctx: Context | None
t: type[TypedDict] | dict[str, Any], ctx: Context | None = None
) -> dict[str, NbtType]:
globalns = None

if isinstance(t, dict):
fields = t
else:
fields = get_type_hints(t, get_globals(t, ctx))
globalns = get_globals(t, ctx)
fields = get_type_hints(t, globalns=globalns)

result: dict[str, NbtType] = {}

for key, value in fields.items():
value_type = convert_type(value)
value_type = convert_type(value, globalns=globalns)
result[key] = value_type if value_type is not None else Any

return frozendict(result) # type: ignore
return dict(result)


def is_type(value: Any, allow_dict: bool = True) -> TypeGuard[NbtType]:
Expand All @@ -190,7 +235,9 @@ def is_type(value: Any, allow_dict: bool = True) -> TypeGuard[NbtType]:
)


def convert_type(value: Any, is_origin: bool = False) -> NbtType | None:
def convert_type(
value: Any, is_origin: bool = False, globalns: dict[str, Any] | None = None
) -> NbtType | None:
if value is Any:
return Any

Expand All @@ -202,28 +249,28 @@ def convert_type(value: Any, is_origin: bool = False) -> NbtType | None:

type_dict: dict[str, NbtType] = {}
for key, val in value.items():
val_type = convert_type(val)
val_type = convert_type(val, globalns=globalns)
type_dict[key] = val_type if val_type is not None else Any

return frozendict(type_dict) # type: ignore
return TypedDict("__anonymous_dict__", type_dict) # type: ignore

if is_typeddict(value):
return value # type: ignore

if is_alias(value, Compound):
args = get_args(value)
return convert_type(dict[str, args[0]]) # type: ignore
return convert_type(dict[str, args[0]], globalns=globalns) # type: ignore

if isinstance(value, (UnionType, _UnionGenericAlias)):
args = get_args(value)
converted = tuple(convert_type(arg) for arg in args)
converted = tuple(convert_type(arg, globalns=globalns) for arg in args)

return Union[converted] # type: ignore

if isinstance(value, GenericAlias):
args = get_args(value)

converted = tuple(convert_type(arg) for arg in args)
converted = tuple(convert_type(arg, globalns=globalns) for arg in args)
origin = convert_type(value.__origin__, is_origin=True)

if isinstance(origin, type) and issubclass(origin, Compound):
Expand Down Expand Up @@ -253,6 +300,9 @@ def convert_type(value: Any, is_origin: bool = False) -> NbtType | None:
if issubclass(value, list):
return list if is_origin else list[Any] # type: ignore

if isinstance(value, str) and globalns is not None and value in globalns:
return globalns[value]

raise TypeError(f"type {type_name(value)} cannot be converted to nbt type.")


Expand All @@ -279,7 +329,7 @@ def access_typeddict(
if isinstance(accessor, NamedKey):
key = accessor.key

fields = get_type_hints(t, get_globals(t, ctx))
fields = get_dict_fields(t, ctx)

if attr_type := fields.get(key):
result = convert_type(attr_type)
Expand Down Expand Up @@ -340,7 +390,7 @@ def access_type(
args = get_args(current_type)
subtypes = tuple(access_type(arg, accessor) for arg in args)

return Union[subtypes] # type: ignore
return convert_type(Union[subtypes]) # type: ignore

if is_typeddict_guard(current_type):
return access_typeddict(current_type, accessor, ctx)
Expand Down
42 changes: 2 additions & 40 deletions bolt_expressions/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from contextlib import contextmanager
import sys
from types import GenericAlias, NoneType, UnionType
from typing import Any, Dict, _UnionGenericAlias, cast, get_args, get_origin, is_typeddict # type: ignore
from types import NoneType
from typing import Any, Dict, _UnionGenericAlias, is_typeddict # type: ignore
from bolt import Runtime
from nbtlib import Base # type: ignore
from beet import Context

__all__ = [
"type_name",
"format_type",
"identifier_generator",
"get_globals",
"assert_exception",
Expand All @@ -32,43 +31,6 @@ def format_name(t: Any) -> str:
return f"{t.__module__}.{t.__name__}"


def format_type(t: Any, *, __refs: list[Any] | None = None) -> str:
if __refs is None:
__refs = []

circular_ref = t in __refs
__refs.append(t)

if isinstance(t, (UnionType, _UnionGenericAlias)):
return " | ".join(format_type(x, __refs=__refs) for x in get_args(t))

if isinstance(t, GenericAlias):
origin = format_type(get_origin(t), __refs=__refs)
args = (format_type(x, __refs=__refs) for x in get_args(t))

if circular_ref:
return f"{origin}[...]"

return f"{origin}[{', '.join(args)}]"

if isinstance(t, dict):
t_dict = cast(dict[str, Any], t)

if circular_ref:
return "{...}"

return (
"{"
+ ", ".join(
f"{key}: {format_type(val, __refs=__refs)}"
for key, val in t_dict.items()
)
+ "}"
)

return format_name(t)


def identifier_generator(ctx: Context | None = None):
if ctx:
runtime = ctx.inject(Runtime)
Expand Down
Loading

0 comments on commit 4bbf16b

Please sign in to comment.