diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 30268d91..38b0a6b1 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -23,11 +23,11 @@ from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING, Hashable from uuid import uuid4 import logging -from sympy import Symbol, Expr, Basic +from sympy import Symbol, Expr, Basic, Tuple logger = logging.getLogger(__name__) @@ -545,12 +545,24 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": for step in base.steps: step.set_workflow(base, with_arguments=True) - results = sorted(set((w.result for w in workflows if w.has_result))) + hashable_workflows: list[Workflow] = [w for w in workflows if isinstance(w.result, Hashable)] + if len(hashable_workflows) != len(workflows): + raise NotImplementedError("Some results are not hashable.") + + def _get_order(result: None | StepReference | Iterable[StepReference]) -> str: + if result is None: + return "" + if isinstance(result, StepReference): + return result.id + return "|".join(r for r in result) + + + results = sorted(set({w.result for w in hashable_workflows if w.has_result}), key=lambda r: _get_order(r)) if len(results) == 1: result = results[0] else: - results = sorted({r if isinstance(r, tuple | list) else (r,) for r in results}) - result = sum(map(list, results), []) + list_results = [r if isinstance(r, tuple | list | Tuple) else (r,) for r in results] + result = sum(map(list, list_results), []) if result is not None and result != []: unify_workflows(result, base, set_only=True)