Skip to content

Commit

Permalink
fix: make result ordering consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
philtweir committed Aug 25, 2024
1 parent 1dc5692 commit 3871ef4
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/dewret/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from attrs import has as attr_has, resolve_types, fields as attrs_fields
from dataclasses import is_dataclass, fields as dataclass_fields
from collections import Counter, OrderedDict
from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING
from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING, Hashable
from uuid import uuid4
import logging

from sympy import Symbol, Expr, Basic
from sympy import Symbol, Expr, Basic, Tuple

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -545,12 +545,24 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow":
for step in base.steps:
step.set_workflow(base, with_arguments=True)

results = sorted(set((w.result for w in workflows if w.has_result)))
hashable_workflows: list[Workflow] = [w for w in workflows if isinstance(w.result, Hashable)]
if len(hashable_workflows) != len(workflows):
raise NotImplementedError("Some results are not hashable.")

def _get_order(result: None | StepReference | Iterable[StepReference]) -> str:
if result is None:
return ""
if isinstance(result, StepReference):
return result.id
return "|".join(r for r in result)


results = sorted(set({w.result for w in hashable_workflows if w.has_result}), key=lambda r: _get_order(r))
if len(results) == 1:
result = results[0]
else:
results = sorted({r if isinstance(r, tuple | list) else (r,) for r in results})
result = sum(map(list, results), [])
list_results = [r if isinstance(r, tuple | list | Tuple) else (r,) for r in results]
result = sum(map(list, list_results), [])

if result is not None and result != []:
unify_workflows(result, base, set_only=True)
Expand Down

0 comments on commit 3871ef4

Please sign in to comment.