diff --git a/docs/quickstart.md b/docs/quickstart.md index 34a5be6c..c5d04e13 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -97,14 +97,14 @@ and backends, as well as bespoke serialization or formatting. >>> >>> result = increment(num=3) >>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> cwl = render(workflow)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 3 - label: increment-1-num + label: num type: int outputs: out: diff --git a/docs/workflows.md b/docs/workflows.md index 94b05efb..84d78fe7 100644 --- a/docs/workflows.md +++ b/docs/workflows.md @@ -63,19 +63,15 @@ In code, this would be: ... left=double(num=increment(num=23)), ... right=mod10(num=increment(num=23)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 23 - label: increment-1-num - type: int - increment-2-num: - default: 23 - label: increment-2-num + label: num type: int outputs: out: @@ -86,7 +82,7 @@ steps: double-1: in: num: - source: increment-2/out + source: increment-1/out out: - out run: double @@ -97,13 +93,6 @@ steps: out: - out run: increment - increment-2: - in: - num: - source: increment-2-num - out: - - out - run: increment mod10-1: in: num: @@ -157,8 +146,8 @@ This duplication can be avoided by explicitly indicating that the parameters are ... left=double(num=increment(num=num)), ... right=mod10(num=increment(num=num)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -232,8 +221,8 @@ For example: ... return (num + INPUT_NUM) % INPUT_NUM >>> >>> result = rotate(num=5) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -244,7 +233,7 @@ inputs: type: int rotate-1-num: default: 5 - label: rotate-1-num + label: num type: int outputs: out: @@ -284,22 +273,24 @@ As code: ```python >>> import sys >>> import yaml ->>> from dewret.tasks import task, construct, nested_task +>>> from dewret.core import set_configuration +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> INPUT_NUM = 3 >>> @task() ... def rotate(num: int) -> int: -... """Rotate an integer.""" -... return (num + INPUT_NUM) % INPUT_NUM +... """Rotate an integer.""" +... return (num + INPUT_NUM) % INPUT_NUM >>> ->>> @nested_task() +>>> @workflow() ... def double_rotate(num: int) -> int: -... """Rotate an integer twice.""" -... return rotate(num=rotate(num=num)) +... """Rotate an integer twice.""" +... return rotate(num=rotate(num=num)) >>> ->>> result = double_rotate(num=3) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> with set_configuration(flatten_all_nested=True): +... result = double_rotate(num=3) +... wkflw = construct(result, simplify_ids=True) +... cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -315,7 +306,7 @@ inputs: outputs: out: label: out - outputSource: rotate-2/out + outputSource: rotate-1/out type: int steps: rotate-1: @@ -323,7 +314,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: num + source: rotate-2/out out: - out run: rotate @@ -332,7 +323,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1/out + source: num out: - out run: rotate @@ -349,7 +340,7 @@ For example, the following code renders the same workflow as in the previous exa ```python -@nested_task() +@workflow() def double_rotate(num: int) -> int: """Rotate an integer twice.""" unused_var = increment(num=num) @@ -409,19 +400,15 @@ As code: ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -447,28 +434,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -510,19 +479,15 @@ Here, we show the same example with `dataclasses`. ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -548,28 +513,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -589,7 +536,7 @@ dewret will produce multiple output workflows that reference each other. >>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -611,21 +558,21 @@ dewret will produce multiple output workflows that reference each other. ... spades=random.randint(max_cards_per_suit), ... diamonds=random.randint(max_cards_per_suit) ... ) ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... ) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -667,7 +614,7 @@ as a second term. >>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -689,33 +636,25 @@ as a second term. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... ) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) ->>> yaml.dump(subworkflows["red_total-1"], sys.stdout, indent=2) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw) +>>> yaml.dump(cwl["red_total-1"], sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 -inputs: - shuffle-1-1-max_cards_per_suit: - default: 13 - label: shuffle-1-1-max_cards_per_suit - type: int - shuffle-1-2-max_cards_per_suit: - default: 13 - label: shuffle-1-2-max_cards_per_suit - type: int +inputs: {} outputs: out: label: out @@ -725,25 +664,7 @@ steps: shuffle-1-1: in: max_cards_per_suit: - source: shuffle-1-1-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle - shuffle-1-2: - in: - max_cards_per_suit: - source: shuffle-1-2-max_cards_per_suit + default: 13 out: clubs: label: clubs @@ -761,7 +682,7 @@ steps: sum-1-1: in: left: - source: shuffle-1-2/hearts + source: shuffle-1-1/hearts right: source: shuffle-1-1/diamonds out: @@ -783,7 +704,7 @@ Below is the default output, treating `Pack` as a task. ```python >>> import sys >>> import yaml ->>> from dewret.tasks import subworkflow, factory, nested_task, construct, task +>>> from dewret.tasks import workflow, factory, workflow, construct, task >>> from attrs import define >>> from dewret.renderers.cwl import render >>> @define @@ -799,39 +720,39 @@ Below is the default output, treating `Pack` as a task. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @nested_task() -... def black_total(pack: PackResult): +>>> @workflow() +... def black_total(pack: PackResult) -> int: ... return sum( ... left=pack.spades, ... right=pack.clubs ... ) >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) ->>> workflow = construct(black_total(pack=pack), simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(black_total(pack=pack), simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: PackResult-1-clubs: default: 13 - label: PackResult-1-clubs + label: clubs type: int PackResult-1-diamonds: default: 13 - label: PackResult-1-diamonds + label: diamonds type: int PackResult-1-hearts: default: 13 - label: PackResult-1-hearts + label: hearts type: int PackResult-1-spades: default: 13 - label: PackResult-1-spades + label: spades type: int outputs: out: label: out - outputSource: sum-1/out + outputSource: black_total-1/out type: int steps: PackResult-1: @@ -858,15 +779,13 @@ steps: label: spades type: int run: PackResult - sum-1: + black_total-1: in: - left: - source: PackResult-1/spades - right: - source: PackResult-1/clubs + pack: + source: PackResult-1/out out: - out - run: sum + run: black_total ``` @@ -876,7 +795,7 @@ types are allowed. ```python >>> import sys >>> import yaml ->>> from dewret.tasks import task, factory, nested_task, construct +>>> from dewret.tasks import task, factory, workflow, construct >>> from attrs import define >>> from dewret.renderers.cwl import render >>> @define @@ -891,36 +810,36 @@ types are allowed. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @nested_task() -... def black_total(pack: PackResult): +>>> @workflow() +... def black_total(pack: PackResult) -> int: ... return sum( ... left=pack.spades, ... right=pack.clubs ... ) >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) ->>> workflow = construct(black_total(pack=pack), simplify_ids=True) ->>> cwl = render(workflow, allow_complex_types=True, factories_as_params=True) +>>> wkflw = construct(black_total(pack=pack), simplify_ids=True) +>>> cwl = render(wkflw, allow_complex_types=True, factories_as_params=True)["black_total-1"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: - PackResult-1: - label: PackResult-1 + pack: + label: pack type: record outputs: out: label: out - outputSource: sum-1/out + outputSource: sum-1-1/out type: int steps: - sum-1: + sum-1-1: in: left: - source: PackResult-1/spades + source: pack/spades right: - source: PackResult-1/clubs + source: pack/clubs out: - out run: sum -``` \ No newline at end of file +``` diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 74db06f6..44897289 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -56,8 +56,17 @@ def __init__(self, fn: Callable[..., Any]): @property def return_type(self) -> type: - """Return type of the callable.""" - return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"] + """Return type of the callable. + + Returns: expected type of the return value. + + Raises: + ValueError: if the return value does not appear to be type-hinted. + """ + hints = get_type_hints(inspect.unwrap(self.fn), include_extras=True) + if "return" not in hints: + raise ValueError(f"Could not find type-hint for return value of {self.fn}") + return hints["return"] @staticmethod def _typ_has(typ: type, annotation: type) -> bool: diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index c59e3bfb..30268d91 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -465,7 +465,7 @@ def find_factories(self) -> dict[str, FactoryCall]: def find_parameters( self, include_factory_calls: bool = True - ) -> set[ParameterReference]: + ) -> set[Parameter]: """Crawl steps for parameter references. As the workflow does not hold its own list of parameters, this @@ -477,7 +477,7 @@ def find_parameters( _, references = expr_to_references( step.arguments for step in self.steps if (include_factory_calls or not isinstance(step, FactoryCall)) ) - return {ref for ref in references if isinstance(ref, ParameterReference)} + return {ref._.parameter for ref in references if isinstance(ref, ParameterReference)} @property def indexed_steps(self) -> dict[str, BaseStep]: @@ -545,7 +545,7 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": for step in base.steps: step.set_workflow(base, with_arguments=True) - results = sorted(set((w.result for w in workflows if w.result))) + results = sorted(set((w.result for w in workflows if w.has_result))) if len(results) == 1: result = results[0] else: @@ -587,9 +587,9 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: param_counter = Counter[str]() name_to_original: dict[str, str] = {} for name, param in { - pr._.parameter.__name__: pr._.parameter - for pr in self.find_parameters() - if isinstance(pr, ParameterReference) + param.__name__: param + for param in self.find_parameters() + if isinstance(param, Parameter) }.items(): if param.__original_name__ != name: param_counter[param.__original_name__] += 1 @@ -1000,7 +1000,7 @@ def __init__( def _to_param_ref(value): if isinstance(value, Parameter): - return value.make_parameter(workflow=workflow) + return value.make_reference(workflow=workflow) expression, refs = expr_to_references(value, remap=_to_param_ref) for ref in refs: diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 18495029..5ef7c93d 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -76,9 +76,9 @@ def test_can_get_field_reference_from_parameter(): my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param)) wkflw = construct(result, simplify_ids=True) - param_references = {(str(p), p.__type__) for p in wkflw.find_parameters()} + params = {(str(p), p.__type__) for p in wkflw.find_parameters()} - assert param_references == {("my_param/left", int), ("my_param", MyDataclass), ("my_param/right/left", int)} + assert params == {("my_param", MyDataclass)} rendered = render(wkflw, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 5cd86ff4..7339ce25 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -6,6 +6,7 @@ from dewret.tasks import construct, workflow, task, factory, set_configuration from dewret.renderers.cwl import render from dewret.workflow import param +from attrs import define from ._lib.extra import increment, sum, pi @@ -542,3 +543,72 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: - out run: to_int """)) + +@define +class PackResult: + hearts: int + clubs: int + spades: int + diamonds: int + +def test_combining_attrs_and_factories(): + Pack = factory(PackResult) + + @task() + def sum(left: int, right: int) -> int: + return left + right + + @workflow() + def black_total(pack: PackResult) -> int: + return sum( + left=pack.spades, + right=pack.clubs + ) + pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) + wkflw = construct(black_total(pack=pack), simplify_ids=True) + cwl = render(wkflw, allow_complex_types=True, factories_as_params=True) + assert cwl["__root__"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + PackResult-1: + label: PackResult-1 + type: record + outputs: + out: + label: out + outputSource: black_total-1/out + type: int + steps: + black_total-1: + in: + pack: + source: PackResult-1/out + out: + - out + run: black_total + """) + + assert cwl["black_total-1"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + pack: + label: pack + type: record + outputs: + out: + label: out + outputSource: sum-1-1/out + type: int + steps: + sum-1-1: + in: + left: + source: pack/spades + right: + source: pack/clubs + out: + - out + run: sum + """)