From 4aef9c494a1c68a6d7b60c1c660748653ee9ba97 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 22:56:53 +0100 Subject: [PATCH] docs: tidyup docstrings and structures --- src/dewret/__main__.py | 6 +- src/dewret/annotations.py | 70 ++++++++++++++++++---- src/dewret/backends/backend_dask.py | 5 +- src/dewret/core.py | 18 +++--- src/dewret/renderers/cwl.py | 2 +- src/dewret/tasks.py | 90 ++++++++++++++++------------- src/dewret/workflow.py | 21 ++----- tests/test_nested.py | 2 +- 8 files changed, 130 insertions(+), 84 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 26b49461..1816c30b 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -27,7 +27,7 @@ import re import yaml from typing import Any -import sys +from types import ModuleType import click import json @@ -90,9 +90,11 @@ def render( key, val = arg.split(":", 1) kwargs[key] = json.loads(val) - render_module: Path | RawRenderModule | StructuredRenderModule + render_module: Path | ModuleType if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)): render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") + if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule): + raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.") elif renderer.startswith("@"): render_module = Path(renderer[1:]) else: diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index c561c28b..de46b2c6 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -3,19 +3,29 @@ import sys import importlib from functools import lru_cache -from types import FunctionType -from dataclasses import dataclass -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType, get_type_hints +from types import FunctionType, ModuleType +from typing import Any, TypeVar, Annotated, Callable, get_origin, get_args, Mapping, get_type_hints T = TypeVar("T") AtRender = Annotated[T, "AtRender"] Fixed = Annotated[T, "Fixed"] class FunctionAnalyser: + """Convenience class for analysing a function with reduced duplication of effort. + + Attributes: + _fn: the wrapped callable + _annotations: stored annotations for the function. + """ _fn: Callable[..., Any] _annotations: dict[str, Any] def __init__(self, fn: Callable[..., Any]): + """Set the function. + + If `fn` is a class, it takes the constructor, and if it is a method, it takes + the `__func__` attribute. + """ self.fn = ( fn.__init__ if inspect.isclass(fn) else @@ -25,11 +35,20 @@ def __init__(self, fn: Callable[..., Any]): ) @property - def return_type(self): + def return_type(self) -> type: + """Return type of the callable.""" return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"] @staticmethod def _typ_has(typ: type, annotation: type) -> bool: + """Check if the type has an annotation. + + Args: + typ: type to check. + annotation: the Annotated to compare against. + + Returns: True if the type has the given annotation, otherwise False. + """ if not hasattr(annotation, "__metadata__"): return False if (origin := get_origin(typ)): @@ -40,33 +59,52 @@ def _typ_has(typ: type, annotation: type) -> bool: return False def get_all_module_names(self): - return sys.modules[self.fn.__module__].__annotations__ + """Find all of the annotations within this module.""" + return get_type_hints(sys.modules[self.fn.__module__], include_extras=True) def get_all_imported_names(self): + """Find all of the annotations that were imported into this module.""" return self._get_all_imported_names(sys.modules[self.fn.__module__]) @staticmethod @lru_cache - def _get_all_imported_names(mod): + def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str]]: + """Get all of the names with this module, and their original locations. + + Args: + mod: a module in the `sys.modules`. + + Returns: + A dict whose keys are the known names in the current module, where the Callable lives, + and whose values are pairs of the module and the remote name. + """ ast_tree = ast.parse(inspect.getsource(mod)) imported_names = {} for node in ast.walk(ast_tree): if isinstance(node, ast.ImportFrom): for name in node.names: imported_names[name.asname or name.name] = ( - importlib.import_module("".join(["."] * node.level) + node.module, package=mod.__package__), + importlib.import_module("".join(["."] * node.level) + (node.module or ""), package=mod.__package__), name.name ) return imported_names @property - def free_vars(self): + def free_vars(self) -> dict[str, Any]: + """Get the free variables for this Callable.""" if self.fn.__code__ and self.fn.__closure__: - return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__))) + return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False)) return {} def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None: - all_annotations: dict[str, type] = {} + """Retrieve the annotations for this argument. + + Args: + arg: name of the argument. + exhaustive: True if we should search outside the function itself, into the module globals, and into imported modules. + + Returns: annotation if available, else None. + """ typ: type | None = None if (typ := self.fn.__annotations__.get(arg)): if isinstance(typ, str): @@ -84,14 +122,25 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | No return typ def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool: + """Check if the named argument has the given annotation. + + Args: + arg: argument to retrieve. + annotation: Annotated to search for. + exhaustive: whether to check the globals and other modules. + + Returns: True if the Annotated is present in this argument's annotation. + """ typ = self.get_argument_annotation(arg, exhaustive) return bool(typ and self._typ_has(typ, annotation)) def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool: + """Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`.""" return self.argument_has(arg, AtRender, exhaustive) @property def globals(self) -> Mapping[str, Any]: + """Get the globals for this Callable.""" try: fn_tuple = inspect.getclosurevars(self.fn) fn_globals = dict(fn_tuple.globals) @@ -102,6 +151,7 @@ def globals(self) -> Mapping[str, Any]: return fn_globals def with_new_globals(self, new_globals: dict[str, Any]) -> Callable[..., Any]: + """Create a Callable that will run the current Callable with new globals.""" code = self.fn.__code__ fn_name = self.fn.__name__ all_globals = dict(self.globals) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 36e7fb21..918ba242 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -19,10 +19,8 @@ from dask.delayed import delayed, DelayedLeaf from dask.config import config -import contextvars -from functools import partial from typing import Protocol, runtime_checkable, Any, cast -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from dewret.workflow import Workflow, Lazy, StepReference, Target @@ -102,7 +100,6 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. """ - # def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: # # We need isinstance to reassure type-checker. # if isinstance(task, list) or isinstance(task, tuple): diff --git a/src/dewret/core.py b/src/dewret/core.py index 85f7ef59..266630f1 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from abc import abstractmethod, abstractstaticmethod +from abc import abstractmethod import importlib import base64 from attrs import define from functools import lru_cache -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Literal, Callable, cast, runtime_checkable +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Callable, cast, runtime_checkable from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol, Basic @@ -129,7 +129,7 @@ class ConstructConfiguration: allow_plain_dict_fields: bool = False field_separator: str = "/" field_index_types: str = "int" - simplify_ids: bool = True + simplify_ids: bool = False class ConstructConfigurationTypedDict(TypedDict): """Basic configuration of the construction process. @@ -139,12 +139,12 @@ class ConstructConfigurationTypedDict(TypedDict): **THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.** """ - flatten_all_nested: bool - allow_positional_args: bool - allow_plain_dict_fields: bool - field_separator: str - field_index_types: str - simplify_ids: bool + flatten_all_nested: NotRequired[bool] + allow_positional_args: NotRequired[bool] + allow_plain_dict_fields: NotRequired[bool] + field_separator: NotRequired[str] + field_index_types: NotRequired[str] + simplify_ids: NotRequired[bool] @define class GlobalConfiguration: diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 233f3a67..23b4214e 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -556,7 +556,7 @@ class OutputsDefinition: @classmethod def from_results( - cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any]] + cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any], ...] ) -> "OutputsDefinition": """Takes a mapping of results into a CWL structure. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index c6bb30ff..31d55120 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -32,32 +32,29 @@ import inspect import importlib import sys -from typing import TypedDict, NotRequired, Unpack from enum import Enum from functools import cached_property from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar, cast, Generator +from typing import Any, ParamSpec, TypeVar, cast, Generator, Unpack, Literal from types import TracebackType from attrs import has as attrs_has -from dataclasses import dataclass, is_dataclass +from dataclasses import is_dataclass import traceback from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_firm, make_traceback, is_expr, is_raw_type +from .utils import is_firm, make_traceback, is_expr from .workflow import ( expr_to_references, unify_workflows, UNSET, Reference, - StepReference, Workflow, Lazy, LazyEvaluation, Target, LazyFactory, - merge_workflows, Parameter, param, Task, @@ -65,8 +62,7 @@ ) from .backends._base import BackendModule from .annotations import FunctionAnalyser -from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfiguration -import ast +from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfigurationTypedDict Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -146,13 +142,17 @@ def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow Args: task: the task to evaluate. __workflow__: workflow within which this exists. + thread_pool: existing pool of threads to run this in, or None. **kwargs: any arguments to pass to the task. """ result = self.backend.run(__workflow__, task, thread_pool=thread_pool, **kwargs) - result, collected_workflow = unify_workflows(result, __workflow__) + new_result, collected_workflow = unify_workflows(result, __workflow__) + + if collected_workflow is None: + raise RuntimeError("A new workflow could not be found") # Then we set the result to be the whole thing - collected_workflow.set_result(result) + collected_workflow.set_result(new_result) return collected_workflow.result def unwrap(self, task: Lazy) -> Target: @@ -193,9 +193,8 @@ def ensure_lazy(self, task: Any) -> Lazy | None: def __call__( self, task: Any, - simplify_ids: bool = False, __workflow__: Workflow | None = None, - **kwargs: ConstructConfiguration, + **kwargs: Unpack[ConstructConfigurationTypedDict], ) -> Workflow: """Execute the lazy evalution. @@ -217,6 +216,7 @@ def _initializer(): thread_pool = ThreadPoolExecutor(initializer=_initializer) result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) + simplify_ids = bool(get_configuration("simplify_ids")) return Workflow.from_result(result, simplify_ids=simplify_ids) @@ -387,12 +387,12 @@ def _fn( **kwargs: Param.kwargs, ) -> RetType: configuration = None - allow_positional_args = get_configuration("allow_positional_args") + allow_positional_args = bool(get_configuration("allow_positional_args")) try: # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. - if args and not get_configuration("allow_positional_args"): + if args and not allow_positional_args: raise TypeError( f""" Calling {fn.__name__}: Arguments must _always_ be named, @@ -409,9 +409,9 @@ def add_numbers(left: int, right: int): # Ensure that the passed arguments are, at least, a Python-match for the signature. sig = inspect.signature(fn) positional_args = {key: False for key in kwargs} - for arg, (key, _) in zip(args, sig.parameters.items()): + for arg, (key, _) in zip(args, sig.parameters.items(), strict=False): if isinstance(arg, IteratedGenerator): - for inner_arg, (key, _) in zip(arg, sig.parameters.items()): + for inner_arg, (key, _) in zip(arg, sig.parameters.items(), strict=False): if key in positional_args: continue kwargs[key] = inner_arg @@ -430,8 +430,9 @@ def _to_param_ref(value): val, kw_refs = expr_to_references(val, remap=_to_param_ref) refs += kw_refs kwargs[key] = val - workflows = [ - reference.__workflow__ + # Not realistically going to be other than Workflow. + workflows: list[Workflow] = [ + cast(Workflow, reference.__workflow__) for reference in refs if hasattr(reference, "__workflow__") and reference.__workflow__ is not None @@ -439,7 +440,7 @@ def _to_param_ref(value): if __workflow__ is not None: workflows.insert(0, __workflow__) if workflows: - workflow = merge_workflows(*workflows) + workflow = Workflow.assimilate(*workflows) else: workflow = Workflow() @@ -452,17 +453,20 @@ def _to_param_ref(value): elif is_firm(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. - tethered = ( + tethered: Literal[False] | None = ( False if nested and ( flatten_nested or get_configuration("flatten_all_nested") ) else None ) - kwargs[var] = param( - var, - value, - tethered=tethered, - autoname=tethered is not False, - typ=analyser.get_argument_annotation(var) or UNSET + kwargs[var] = cast( + Parameter, + param( + var, + value, + tethered=tethered, + autoname=tethered is not False, + typ=analyser.get_argument_annotation(var) or UNSET + ) ).make_reference(workflow=workflow) original_kwargs = dict(kwargs) fn_globals = analyser.globals @@ -515,13 +519,16 @@ def {fn.__name__}(...) -> ...: (attrs_has(value) or is_dataclass(value)) and not inspect.isclass(value) ): - kwargs[var] = param( - var, - default=value, - tethered=False, - typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET + kwargs[var] = cast( + Parameter, + param( + var, + default=value, + tethered=False, + typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET + ) ).make_reference(workflow=workflow) - elif is_expr(value) and expr_to_references(value)[1] is not []: + elif is_expr(value) and (expr_refs := expr_to_references(value)) and len(expr_refs[1]) != 0: kwargs[var] = value elif nested: raise NotImplementedError( @@ -542,13 +549,16 @@ def {fn.__name__}(...) -> ...: else: nested_workflow = Workflow(name=fn.__name__) nested_globals: Param.kwargs = { - var: param( - var, - default=value.__default__ if hasattr(value, "__default__") else UNSET, - typ=( - value.__type__ - ), - tethered=nested_workflow + var: cast( + Parameter, + param( + var, + default=value.__default__ if hasattr(value, "__default__") else UNSET, + typ=( + value.__type__ + ), + tethered=nested_workflow + ) ).make_reference(workflow=nested_workflow) if isinstance(value, Reference) else value for var, value in kwargs.items() } @@ -589,7 +599,7 @@ def {fn.__name__}(...) -> ...: configuration.__exit__(None, None, None) _fn.__step_expression__ = True # type: ignore - _fn.__original__ = fn + _fn.__original__ = fn # type: ignore return LazyEvaluation(_fn) return _task diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 83d65624..eb97c9cd 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -502,9 +502,10 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": This could happen if the hashing function is flawed or some Python magic to do with Targets being passed. - Argument: - left: workflow to use as base - right: workflow to combine on top + Args: + workflow_args: workflows to use as base + + j """ workflows = sorted((w for w in set(workflow_args)), key=lambda w: w.id) base = workflows[0] @@ -1587,20 +1588,6 @@ def __iter__(self) -> Generator[Reference, None, None]: # We cast this so that we can treat a step iterator as if it really loops over results. yield cast(Reference, IteratedGenerator(self)) -def merge_workflows(*workflows: Workflow) -> Workflow: - """Combine several workflows into one. - - Merges a series of workflows by combining steps and tasks. - - Argument: - *workflows: series of workflows to combine. - - Returns: - One workflow with all steps. - """ - return Workflow.assimilate(*workflows) - - def is_task(task: Lazy) -> bool: """Decide whether this is a task. diff --git a/tests/test_nested.py b/tests/test_nested.py index 99ccf1ca..c05f85c2 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,7 +2,7 @@ import pytest import math from dewret.workflow import param -from dewret.tasks import construct, task, factory +from dewret.tasks import construct from dewret.renderers.cwl import render from ._lib.extra import reverse_list, max_list