Skip to content

Commit

Permalink
chore: Replace all Unions with | for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
KamenDimitrov97 committed Sep 2, 2024
1 parent a7c0d8d commit c32469e
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 88 deletions.
2 changes: 1 addition & 1 deletion docs/renderer_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ from dewret.utils import Raw, BasicType
from dewret.workflow import Lazy
from dewret.workflow import Reference, Workflow, Step, Task

RawType = typing.Union[BasicType, list[str], list["RawType"], dict[str, "RawType"]]
RawType = BasicType | list[str] | list["RawType"] | dict[str, "RawType"]
```

## To run this example:
Expand Down
95 changes: 79 additions & 16 deletions src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,37 @@
import base64
from attrs import define
from functools import lru_cache
from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Callable, cast, runtime_checkable
from typing import (
Generic,
TypeVar,
Protocol,
Iterator,
Unpack,
TypedDict,
NotRequired,
Generator,
Any,
get_args,
get_origin,
Annotated,
Callable,
cast,
runtime_checkable,
)
from contextlib import contextmanager
from contextvars import ContextVar
from sympy import Expr, Symbol, Basic
import copy

BasicType = str | float | bool | bytes | int | None
RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]]
RawType = BasicType | list["RawType"] | dict[str, "RawType"]
FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...]
ExprType = FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...] # type: ignore
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore

U = TypeVar("U")
T = TypeVar("T")


def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]:
"""Discovers and removes annotations from a parent type.
Expand All @@ -53,14 +70,17 @@ def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]:
metadata += list(parent_metadata)
return parent_type, tuple(metadata)


RenderConfiguration = dict[str, RawType]


class WorkflowProtocol(Protocol):
"""Expected structure for a workflow.
We do not expect various workflow implementations, but this allows us to define the
interface expected by the core classes.
"""

def remap(self, name: str) -> str:
"""Perform any name-changing for steps, etc. in the workflow.
Expand All @@ -82,8 +102,10 @@ def simplify_ids(self, infix: list[str] | None = None) -> None:
"""
...


class BaseRenderModule(Protocol):
"""Common routines for all renderer modules."""

@staticmethod
def default_config() -> dict[str, RawType]:
"""Retrieve default settings.
Expand All @@ -94,29 +116,41 @@ def default_config() -> dict[str, RawType]:
"""
...


@runtime_checkable
class RawRenderModule(BaseRenderModule, Protocol):
"""Render module that returns raw text."""
def render_raw(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str]:

def render_raw(
self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
) -> dict[str, str]:
"""Turn a workflow into flat strings.
Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
"""
...


@runtime_checkable
class StructuredRenderModule(BaseRenderModule, Protocol):
"""Render module that returns JSON/YAML-serializable structures."""
def render(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]:

def render(
self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
) -> dict[str, dict[str, RawType]]:
"""Turn a workflow into a serializable structure.
Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
"""
...


class RenderCall(Protocol):
"""Callable that will render out workflow(s)."""
def __call__(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str] | dict[str, RawType]:

def __call__(
self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
) -> dict[str, str] | dict[str, RawType]:
"""Take a workflow and turn it into a set of serializable (sub)workflows.
Args:
Expand All @@ -127,11 +161,13 @@ def __call__(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) ->
"""
...


class UnevaluatableError(Exception):
"""Signposts that a user has tried to treat a reference as the real (runtime) value.
For example, by comparing to a concrete integer or value, etc.
"""

...


Expand All @@ -142,13 +178,15 @@ class ConstructConfiguration:
Holds configuration that may be relevant to `construst(...)` calls or, realistically,
anything prior to rendering. It should hold generic configuration that is renderer-independent.
"""

flatten_all_nested: bool = False
allow_positional_args: bool = False
allow_plain_dict_fields: bool = False
field_separator: str = "/"
field_index_types: str = "int"
simplify_ids: bool = False


class ConstructConfigurationTypedDict(TypedDict):
"""Basic configuration of the construction process.
Expand All @@ -157,26 +195,33 @@ class ConstructConfigurationTypedDict(TypedDict):
**THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.**
"""

flatten_all_nested: NotRequired[bool]
allow_positional_args: NotRequired[bool]
allow_plain_dict_fields: NotRequired[bool]
field_separator: NotRequired[str]
field_index_types: NotRequired[str]
simplify_ids: NotRequired[bool]


@define
class GlobalConfiguration:
"""Overall configuration structure.
Having a single configuration dict allows us to manage only one ContextVar.
"""

construct: ConstructConfiguration
render: dict[str, RawType]


CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration")


@contextmanager
def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iterator[ContextVar[GlobalConfiguration]]:
def set_configuration(
**kwargs: Unpack[ConstructConfigurationTypedDict],
) -> Iterator[ContextVar[GlobalConfiguration]]:
"""Sets the construct-time configuration.
This is a context manager, so that a setting can be temporarily overridden and automatically restored.
Expand All @@ -186,8 +231,11 @@ def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iter
setattr(CONFIGURATION.get().construct, key, value)
yield CONFIGURATION


@contextmanager
def set_render_configuration(kwargs: dict[str, RawType]) -> Iterator[ContextVar[GlobalConfiguration]]:
def set_render_configuration(
kwargs: dict[str, RawType],
) -> Iterator[ContextVar[GlobalConfiguration]]:
"""Sets the render-time configuration.
This is a context manager, so that a setting can be temporarily overridden and automatically restored.
Expand All @@ -198,6 +246,7 @@ def set_render_configuration(kwargs: dict[str, RawType]) -> Iterator[ContextVar[
CONFIGURATION.get().render.update(**kwargs)
yield CONFIGURATION


@contextmanager
def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]:
"""Prepares and tidied up the configuration for applying settings.
Expand All @@ -207,7 +256,9 @@ def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]:
try:
previous = CONFIGURATION.get()
except LookupError:
previous = GlobalConfiguration(construct=ConstructConfiguration(), render=default_renderer_config())
previous = GlobalConfiguration(
construct=ConstructConfiguration(), render=default_renderer_config()
)
CONFIGURATION.set(previous)
previous = copy.deepcopy(previous)

Expand All @@ -230,12 +281,15 @@ def default_renderer_config() -> RenderConfiguration:
"""
try:
# We have to use a cast as we do not know if the renderer module is valid.
render_module = cast(BaseRenderModule, importlib.import_module("__renderer_mod__"))
render_module = cast(
BaseRenderModule, importlib.import_module("__renderer_mod__")
)
default_config: Callable[[], RenderConfiguration] = render_module.default_config
except ImportError:
return {}
return default_config()


@lru_cache
def default_construct_config() -> ConstructConfiguration:
"""Gets the default construct-time configuration.
Expand All @@ -254,6 +308,7 @@ def default_construct_config() -> ConstructConfiguration:
field_index_types="int",
)


def get_configuration(key: str) -> RawType:
"""Retrieve the configuration or (silently) return the default.
Expand All @@ -267,10 +322,11 @@ def get_configuration(key: str) -> RawType:
Returns: (preferably) a JSON/YAML-serializable construct.
"""
try:
return getattr(CONFIGURATION.get().construct, key) # type: ignore
return getattr(CONFIGURATION.get().construct, key) # type: ignore
except LookupError:
# TODO: Not sure what the best way to typehint this is.
return getattr(ConstructConfiguration(), key) # type: ignore
return getattr(ConstructConfiguration(), key) # type: ignore


def get_render_configuration(key: str) -> RawType:
"""Retrieve configuration for the active renderer.
Expand All @@ -288,6 +344,7 @@ def get_render_configuration(key: str) -> RawType:
except LookupError:
return default_renderer_config().get(key)


class WorkflowComponent:
"""Base class for anything directly tied to an individual `Workflow`.
Expand All @@ -312,12 +369,12 @@ def __init__(self, *args: Any, workflow: WorkflowProtocol, **kwargs: Any):

@property
def __workflow__(self) -> WorkflowProtocol:
"""Workflow to which this reference applies."""
"""Workflow to which this reference applies."""
return self.__workflow_real__

@__workflow__.setter
def __workflow__(self, workflow: WorkflowProtocol) -> None:
"""Workflow to which this reference applies."""
"""Workflow to which this reference applies."""
self.__workflow_real__ = workflow


Expand Down Expand Up @@ -370,7 +427,9 @@ def __type__(self) -> type:

def _raise_unevaluatable_error(self) -> None:
"""Convenience method to consistently formulate an UnevaluatableError for this reference."""
raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.")
raise UnevaluatableError(
f"This reference, {self.__name__}, cannot be evaluated during construction."
)

def __eq__(self, other: object) -> Any:
"""Test equality at construct-time, if sensible.
Expand Down Expand Up @@ -431,11 +490,13 @@ def __str__(self) -> str:
"""
return self.__name__


class IterableMixin(Reference[U]):
"""Functionality for iterating over references to give new references."""

__fixed_len__: int | None = None

def __init__(self, typ: type[U] | None=None, **kwargs: Any):
def __init__(self, typ: type[U] | None = None, **kwargs: Any):
"""Extract length, if available from type.
Captures types of the form (e.g.) `tuple[int, float]` and records the length
Expand Down Expand Up @@ -503,13 +564,15 @@ def __getitem__(self, attr: str | int) -> "Reference[U] | Any":
"""
return super().__getitem__(attr)


class IteratedGenerator(Generic[U]):
"""Sentinel value for capturing that an iteration has occured without performing it.
Allows us to lazily evaluate a loop, for instance, in the renderer. This may be relevant
if the renderer wishes to postpone iteration to runtime, and simply record it is required,
rather than evaluating the iterator.
"""

__wrapped__: Reference[U]

def __init__(self, to_wrap: Reference[U]):
Expand Down
Loading

0 comments on commit c32469e

Please sign in to comment.