Skip to content

Commit

Permalink
feat(eager): add eager evaluation and immediate task evaluation/execu…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
philtweir committed Oct 18, 2024
1 parent 4c1a3fe commit 67938b6
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 19 deletions.
15 changes: 15 additions & 0 deletions src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@
default=False,
help="Pretty-print output where possible.",
)
@click.option(
"--eager",
is_flag=True,
show_default=True,
default=False,
help="Eagerly evaluate tasks at render-time for debugging purposes.",
)
@click.option(
"--backend",
type=click.Choice(list(Backend.__members__)),
Expand Down Expand Up @@ -81,6 +88,7 @@ def render(
task: str,
arguments: list[str],
pretty: bool,
eager: bool,
backend: Backend,
construct_args: str,
renderer: str,
Expand Down Expand Up @@ -154,6 +162,13 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
workflow = load_module_or_package(pkg, workflow_py)
task_fn = getattr(workflow, task)

if eager:
construct_kwargs["eager"] = True
with set_configuration(**construct_kwargs):
output = task_fn(**kwargs)
print(output)
return

try:
with (
set_configuration(**construct_kwargs),
Expand Down
2 changes: 1 addition & 1 deletion src/dewret/backends/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class BackendModule(Protocol):
"""
lazy: LazyFactory

def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]:
def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy, ...], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]:
"""Execute a lazy task for this `Workflow`.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/dewret/backends/backend_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def is_lazy(task: Any) -> bool:

def run(
workflow: Workflow | None,
task: Lazy | list[Lazy] | tuple[Lazy],
task: Lazy | list[Lazy] | tuple[Lazy, ...],
thread_pool: ThreadPoolExecutor | None = None,
**kwargs: Any,
) -> Any:
Expand Down
2 changes: 2 additions & 0 deletions src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class ConstructConfiguration:
field_separator: str = "/"
field_index_types: str = "int"
simplify_ids: bool = False
eager: bool = False


class ConstructConfigurationTypedDict(TypedDict):
Expand All @@ -203,6 +204,7 @@ class ConstructConfigurationTypedDict(TypedDict):
field_separator: NotRequired[str]
field_index_types: NotRequired[str]
simplify_ids: NotRequired[bool]
eager: NotRequired[bool]


@define
Expand Down
30 changes: 26 additions & 4 deletions src/dewret/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

from .utils import is_firm, make_traceback, is_expr
from .workflow import (
execute_step,
expr_to_references,
unify_workflows,
UNSET,
Expand Down Expand Up @@ -142,7 +143,7 @@ def make_lazy(self) -> LazyFactory:
"""
return self.backend.lazy

def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any:
def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy, ...], __workflow__: Workflow, thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any:
"""Evaluate a single task for a known workflow.
Args:
Expand Down Expand Up @@ -230,11 +231,28 @@ def _initializer() -> None:
lazy = _manager.make_lazy
ensure_lazy = _manager.ensure_lazy
unwrap = _manager.unwrap
evaluate = _manager.evaluate
construct = _manager

def evaluate(task: Any, *args: Any, execute: bool = False, **kwargs: Any) -> Any:
"""Get a result of a task, either as a value or lazily.
Args:
task: task to evaluate
*args: other arguments to the evaluator
execute: whether or not to evaluate to obtain the final result
**kwargs: other arguments to the evaluator
Returns:
Structure of lazy evaluations if lazy, else actual result.
"""
if execute:
return execute_step(task, *args, **kwargs)
else:
return _manager.evaluate(task, *args, **kwargs)

"""An alias pointing to an instance of the TaskManager class.
Used for constructing a set of tasks into a dewret workflow instance.
"""
construct = _manager


class TaskException(Exception):
Expand Down Expand Up @@ -313,7 +331,8 @@ def factory(fn: Callable[..., RetType]) -> Callable[..., RetType]:
Args:
fn: a callable to create the entity.
"""
return task(is_factory=True)(fn)
ret = task(is_factory=True)(fn)
return ret

# Workaround for PyCharm
factory: Callable[[Callable[..., RetType]], Callable[..., RetType]] = factory
Expand Down Expand Up @@ -398,6 +417,9 @@ def _fn(
__traceback__: TracebackType | None = None,
**kwargs: Param.kwargs,
) -> RetType:
if get_configuration("eager"):
return fn(*args, **kwargs)

configuration = None
allow_positional_args = bool(get_configuration("allow_positional_args"))

Expand Down
53 changes: 52 additions & 1 deletion src/dewret/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,13 +1151,23 @@ def _generate_id(self) -> str:

return f"{self.task}-{hasher(comp_tup)}"

def __call__(self, **additional_kwargs: Any) -> Any:
"""Evaluate this step eagerly.
Args:
**additional_kwargs: any extra/overriding arguments to the step.
"""
raise NotImplementedError("No eager evaluation for this BaseStep type")


class NestedStep(BaseStep):
"""Calling out to a subworkflow.
Type of BaseStep to call a subworkflow, which holds a reference to it.
"""

task: Workflow

def __init__(
self,
workflow: Workflow,
Expand Down Expand Up @@ -1205,11 +1215,31 @@ def return_type(self) -> Any:
raise RuntimeError("Can only use a subworkflow if the reference exists.")
return self.__subworkflow__.result_type

def __call__(self, **additional_kwargs: Any) -> Any:
"""Evaluate this nested step, by eagerly evaluating the subworkflow result.
Args:
**additional_kwargs: any extra/overriding arguments to the subworkflow result step.
"""
kwargs: dict[str, Any] = dict(self.arguments)
kwargs.update(additional_kwargs)
return execute_step(self.__subworkflow__.result, **kwargs)


class Step(BaseStep):
"""Regular step."""

...
task: Task

def __call__(self, **additional_kwargs: Any) -> Any:
"""Evaluate this step, by eagerly evaluating the result.
Args:
**additional_kwargs: any extra/overriding arguments to the step.
"""
kwargs: dict[str, Any] = dict(self.arguments)
kwargs.update(additional_kwargs)
return self.task.target(**kwargs)


class FactoryCall(Step):
Expand Down Expand Up @@ -1606,6 +1636,27 @@ def __make_reference__(self, **kwargs: Any) -> "StepReference[U]":
kwargs["workflow"] = self.__workflow__
return self._.step.make_reference(**kwargs)

def execute_step(task: Any, **kwargs: Any) -> Any:
"""Evaluate a single task for a known workflow.
Args:
task: the task to evaluate.
**kwargs: any arguments to pass to the task.
"""
if isinstance(task, list):
return [execute_step(t, **kwargs) for t in task]
elif isinstance(task, tuple):
return tuple(execute_step(t, **kwargs) for t in task)

if not isinstance(task, StepReference):
raise TypeError(
f"Attempted to execute a task that is not step-like (Step/Workflow): {type(task)}"
)

result = task._.step()

return result

class IterableStepReference(IterableMixin[U], StepReference[U]):
"""Iterable form of a step reference."""
def __iter__(self) -> Generator[Reference[U], None, None]:
Expand Down
12 changes: 12 additions & 0 deletions tests/_lib/extra.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from attrs import define
from dewret.tasks import task, workflow

from .other import nothing
Expand All @@ -6,6 +7,15 @@
test: float = nothing


@define
class PackResult:
"""A class representing the counts of card suits in a deck, including hearts, clubs, spades, and diamonds."""

hearts: int
clubs: int
spades: int
diamonds: int

@workflow()
def try_nothing() -> int:
"""Check that we can see AtRender in another module."""
Expand Down Expand Up @@ -69,3 +79,5 @@ def reverse_list(to_sort: list[int | float]) -> list[int | float]:
@task()
def max_list(lst: list[int | float]) -> int | float:
return max(lst)


78 changes: 78 additions & 0 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import math
from sympy import Expr, Symbol as S

from dewret.tasks import (
workflow,
factory,
task,
evaluate,
)
from dewret.core import set_configuration
from ._lib.extra import (
pi,
PackResult
)

def test_basic_eager_execution() -> None:
"""Check whether we can run a simple flow without lazy-evaluation.
Will skip dask delayeds and execute during construction.
"""
result = pi()

# Execute this step immediately.
output = evaluate(result, execute=True)
assert output == math.pi

with set_configuration(eager=True):
output = pi()

assert output == math.pi

def test_eager_execution_of_a_workflow() -> None:
"""Check whether we can run a workflow without lazy-evaluation.
Will skip dask delayeds and execute during construction.
"""
@workflow()
def pair_pi() -> tuple[float, float]:
return pi(), pi()

# Execute this step immediately.
with set_configuration(flatten_all_nested=True):
result = pair_pi()
output = evaluate(result, execute=True)

assert output == (math.pi, math.pi)

with set_configuration(eager=True):
output = pair_pi()

assert output == (math.pi, math.pi)


def test_eager_execution_of_a_rich_workflow() -> None:
"""Ensures that a workflow with both tasks and workflows can be eager-evaluated."""
Pack = factory(PackResult)

@task()
def sum(left: int, right: int) -> int:
return left + right

@workflow()
def black_total(pack: PackResult) -> int:
return sum(left=pack.spades, right=pack.clubs)

pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13)

output_sympy: Expr = evaluate(black_total(pack=pack), execute=True)
clubs, spades = output_sympy.free_symbols
output: int = output_sympy.subs({spades: 13, clubs: 13})

assert output == 26

with set_configuration(eager=True):
pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13)
output = black_total(pack=pack)

assert output == 26
13 changes: 1 addition & 12 deletions tests/test_subworkflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from dewret.core import set_configuration
from dewret.renderers.cwl import render
from dewret.workflow import param
from attrs import define

from ._lib.extra import increment, sum, pi
from ._lib.extra import increment, sum, pi, PackResult

CONSTANT: int = 3

Expand Down Expand Up @@ -559,16 +558,6 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None:
)


@define
class PackResult:
"""A class representing the counts of card suits in a deck, including hearts, clubs, spades, and diamonds."""

hearts: int
clubs: int
spades: int
diamonds: int


def test_combining_attrs_and_factories() -> None:
"""Check combining attributes from a dataclass with factory-produced instances."""
Pack = factory(PackResult)
Expand Down

0 comments on commit 67938b6

Please sign in to comment.