Skip to content

Commit

Permalink
docs: tidyup docstrings and structures
Browse files Browse the repository at this point in the history
  • Loading branch information
philtweir committed Aug 24, 2024
1 parent 5c2d7a2 commit 4aef9c4
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 84 deletions.
6 changes: 4 additions & 2 deletions src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import re
import yaml
from typing import Any
import sys
from types import ModuleType
import click
import json

Expand Down Expand Up @@ -90,9 +90,11 @@ def render(
key, val = arg.split(":", 1)
kwargs[key] = json.loads(val)

render_module: Path | RawRenderModule | StructuredRenderModule
render_module: Path | ModuleType
if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)):
render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}")
if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule):
raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.")
elif renderer.startswith("@"):
render_module = Path(renderer[1:])
else:
Expand Down
70 changes: 60 additions & 10 deletions src/dewret/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@
import sys
import importlib
from functools import lru_cache
from types import FunctionType
from dataclasses import dataclass
from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType, get_type_hints
from types import FunctionType, ModuleType
from typing import Any, TypeVar, Annotated, Callable, get_origin, get_args, Mapping, get_type_hints

T = TypeVar("T")
AtRender = Annotated[T, "AtRender"]
Fixed = Annotated[T, "Fixed"]

class FunctionAnalyser:
"""Convenience class for analysing a function with reduced duplication of effort.
Attributes:
_fn: the wrapped callable
_annotations: stored annotations for the function.
"""
_fn: Callable[..., Any]
_annotations: dict[str, Any]

def __init__(self, fn: Callable[..., Any]):
"""Set the function.
If `fn` is a class, it takes the constructor, and if it is a method, it takes
the `__func__` attribute.
"""
self.fn = (
fn.__init__
if inspect.isclass(fn) else
Expand All @@ -25,11 +35,20 @@ def __init__(self, fn: Callable[..., Any]):
)

@property
def return_type(self):
def return_type(self) -> type:
"""Return type of the callable."""
return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"]

@staticmethod
def _typ_has(typ: type, annotation: type) -> bool:
"""Check if the type has an annotation.
Args:
typ: type to check.
annotation: the Annotated to compare against.
Returns: True if the type has the given annotation, otherwise False.
"""
if not hasattr(annotation, "__metadata__"):
return False
if (origin := get_origin(typ)):
Expand All @@ -40,33 +59,52 @@ def _typ_has(typ: type, annotation: type) -> bool:
return False

def get_all_module_names(self):
return sys.modules[self.fn.__module__].__annotations__
"""Find all of the annotations within this module."""
return get_type_hints(sys.modules[self.fn.__module__], include_extras=True)

def get_all_imported_names(self):
"""Find all of the annotations that were imported into this module."""
return self._get_all_imported_names(sys.modules[self.fn.__module__])

@staticmethod
@lru_cache
def _get_all_imported_names(mod):
def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str]]:
"""Get all of the names with this module, and their original locations.
Args:
mod: a module in the `sys.modules`.
Returns:
A dict whose keys are the known names in the current module, where the Callable lives,
and whose values are pairs of the module and the remote name.
"""
ast_tree = ast.parse(inspect.getsource(mod))
imported_names = {}
for node in ast.walk(ast_tree):
if isinstance(node, ast.ImportFrom):
for name in node.names:
imported_names[name.asname or name.name] = (
importlib.import_module("".join(["."] * node.level) + node.module, package=mod.__package__),
importlib.import_module("".join(["."] * node.level) + (node.module or ""), package=mod.__package__),
name.name
)
return imported_names

@property
def free_vars(self):
def free_vars(self) -> dict[str, Any]:
"""Get the free variables for this Callable."""
if self.fn.__code__ and self.fn.__closure__:
return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__)))
return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False))
return {}

def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None:
all_annotations: dict[str, type] = {}
"""Retrieve the annotations for this argument.
Args:
arg: name of the argument.
exhaustive: True if we should search outside the function itself, into the module globals, and into imported modules.
Returns: annotation if available, else None.
"""
typ: type | None = None
if (typ := self.fn.__annotations__.get(arg)):
if isinstance(typ, str):
Expand All @@ -84,14 +122,25 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | No
return typ

def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool:
"""Check if the named argument has the given annotation.
Args:
arg: argument to retrieve.
annotation: Annotated to search for.
exhaustive: whether to check the globals and other modules.
Returns: True if the Annotated is present in this argument's annotation.
"""
typ = self.get_argument_annotation(arg, exhaustive)
return bool(typ and self._typ_has(typ, annotation))

def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool:
"""Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`."""
return self.argument_has(arg, AtRender, exhaustive)

@property
def globals(self) -> Mapping[str, Any]:
"""Get the globals for this Callable."""
try:
fn_tuple = inspect.getclosurevars(self.fn)
fn_globals = dict(fn_tuple.globals)
Expand All @@ -102,6 +151,7 @@ def globals(self) -> Mapping[str, Any]:
return fn_globals

def with_new_globals(self, new_globals: dict[str, Any]) -> Callable[..., Any]:
"""Create a Callable that will run the current Callable with new globals."""
code = self.fn.__code__
fn_name = self.fn.__name__
all_globals = dict(self.globals)
Expand Down
5 changes: 1 addition & 4 deletions src/dewret/backends/backend_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

from dask.delayed import delayed, DelayedLeaf
from dask.config import config
import contextvars
from functools import partial
from typing import Protocol, runtime_checkable, Any, cast
from concurrent.futures import Executor, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from dewret.workflow import Workflow, Lazy, StepReference, Target


Expand Down Expand Up @@ -102,7 +100,6 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread
workflow: `Workflow` in which to record the execution.
task: `dask.delayed` function, wrapped by dewret, that we wish to compute.
"""

# def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed:
# # We need isinstance to reassure type-checker.
# if isinstance(task, list) or isinstance(task, tuple):
Expand Down
18 changes: 9 additions & 9 deletions src/dewret/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass
from abc import abstractmethod, abstractstaticmethod
from abc import abstractmethod
import importlib
import base64
from attrs import define
from functools import lru_cache
from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Literal, Callable, cast, runtime_checkable
from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Callable, cast, runtime_checkable
from contextlib import contextmanager
from contextvars import ContextVar
from sympy import Expr, Symbol, Basic
Expand Down Expand Up @@ -129,7 +129,7 @@ class ConstructConfiguration:
allow_plain_dict_fields: bool = False
field_separator: str = "/"
field_index_types: str = "int"
simplify_ids: bool = True
simplify_ids: bool = False

class ConstructConfigurationTypedDict(TypedDict):
"""Basic configuration of the construction process.
Expand All @@ -139,12 +139,12 @@ class ConstructConfigurationTypedDict(TypedDict):
**THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.**
"""
flatten_all_nested: bool
allow_positional_args: bool
allow_plain_dict_fields: bool
field_separator: str
field_index_types: str
simplify_ids: bool
flatten_all_nested: NotRequired[bool]
allow_positional_args: NotRequired[bool]
allow_plain_dict_fields: NotRequired[bool]
field_separator: NotRequired[str]
field_index_types: NotRequired[str]
simplify_ids: NotRequired[bool]

@define
class GlobalConfiguration:
Expand Down
2 changes: 1 addition & 1 deletion src/dewret/renderers/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ class OutputsDefinition:

@classmethod
def from_results(
cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any]]
cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any], ...]
) -> "OutputsDefinition":
"""Takes a mapping of results into a CWL structure.
Expand Down
Loading

0 comments on commit 4aef9c4

Please sign in to comment.