feat(typing): Add typing #432

Open

wants to merge 3 commits into main
Changes from 1 commit
2 changes: 1 addition & 1 deletion .editorconfig
@@ -6,4 +6,4 @@ indent_size = 4
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
max_line_length = 79
max_line_length = 140
Author
The ruff setup is at 140 characters; it doesn't make sense for the editorconfig to hold the IDE to a shorter limit.
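For reference, the ruff setting this refers to is line-length; a minimal sketch of that configuration, assuming ruff is configured in pyproject.toml (the project may configure it elsewhere, e.g. a ruff.toml):

# Assumed location; matches the new .editorconfig max_line_length of 140.
[tool.ruff]
line-length = 140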

155 changes: 81 additions & 74 deletions SpiffWorkflow/workflow.py
@@ -18,20 +18,23 @@
# 02110-1301 USA

import logging
from typing import Optional, Any

from .serializer.base import Serializer
from .specs import WorkflowSpec
from .task import Task
from .util.task import TaskState, TaskIterator, TaskFilter
from .util.compat import mutex
from .util.event import Event
from .exceptions import TaskNotFoundException, WorkflowException
from .exceptions import TaskNotFoundException

logger = logging.getLogger('spiff.workflow')


class Workflow(object):
class Workflow:
"""The instantiation of a `WorkflowSpec`.

Reprsents the state of a running workflow and its data.
Represents the state of a running workflow and its data.

Attributes:
spec (`WorkflowSpec`): the spec that describes this workflow instance
@@ -44,11 +47,15 @@ class Workflow(object):
completed_event (`Event`): an event holding callbacks to be run when the workflow completes
"""

def __init__(self, workflow_spec, deserializing=False):
def __init__(
self,
workflow_spec: WorkflowSpec,
deserializing: bool = False,
) -> None:
"""
Parameters:
workflow_spec (`WorkflowSpec`): the spec that describes this workflow
deserializing (bool): whether this workflow is being deserialized
workflow_spec: The spec that describes this workflow.
deserializing: Whether this workflow is being deserialized.
"""
self.spec = workflow_spec
self.data = {}
@@ -67,84 +74,84 @@ def __init__(self, workflow_spec, deserializing=False):
logger.info('Initialized workflow', extra=self.collect_log_extras())
self.task_tree._ready()

def is_completed(self):
def is_completed(self) -> bool:
"""Checks whether the workflow is complete.

Returns:
bool: True if the workflow has no unfinished tasks
True if the workflow has no unfinished tasks.
"""
if not self.completed:
iter = TaskIterator(self.task_tree, state=TaskState.NOT_FINISHED_MASK)
_iter = TaskIterator(self.task_tree, state=TaskState.NOT_FINISHED_MASK)
try:
next(iter)
next(_iter)
except StopIteration:
self.completed = True
return self.completed

def manual_input_required(self):
def manual_input_required(self) -> bool:
"""Checks whether the workflow requires manual input.

Returns:
bool: True if the workflow cannot proceed until manual tasks are complete
True if the workflow cannot proceed until manual tasks are complete.
"""
iter = TaskIterator(self.task_tree, state=TaskState.READY, manual=False)
_iter = TaskIterator(self.task_tree, state=TaskState.READY, manual=False)
try:
next(iter)
next(_iter)
except StopIteration:
return True
return False

def get_tasks(self, first_task=None, **kwargs):
def get_tasks(self, first_task: Task = None, **kwargs) -> list[Task]:
"""Returns a list of `Task`s that meet the conditions specified `kwargs`, starting from the root by default.

Notes:
Keyword args are passed directly to `get_tasks_iterator`
Keyword args are passed directly to `get_tasks_iterator`.

Returns:
list(`Task`): the tasks that match the filtering conditions
The tasks that match the filtering conditions.
"""
return [t for t in self.get_tasks_iterator(first_task, **kwargs)]
return [task for task in self.get_tasks_iterator(first_task, **kwargs)]

def get_next_task(self, first_task=None, **kwargs):
def get_next_task(self, first_task: Task = None, **kwargs) -> Optional[Task]:
"""Returns the next task that meets the iteration conditions, starting from the root by default.

Parameters:
first_task (`Task`): search beginning from this task
first_task: Search beginning from this task.

Notes:
Other keyword args are passed directly into `get_tasks_iterator`
Other keyword args are passed directly into `get_tasks_iterator`.

Returns:
`Task` or None: the first task that meets the conditions or None if no tasks match
The first task that meets the conditions or None if no tasks match
"""
iter = self.get_tasks_iterator(first_task, **kwargs)
_iter = self.get_tasks_iterator(first_task, **kwargs)
try:
return next(iter)
return next(_iter)
except StopIteration:
return None

def get_tasks_iterator(self, first_task=None, **kwargs):
def get_tasks_iterator(self, first_task: Task = None, **kwargs) -> TaskIterator:
"""Returns an iterator of Tasks that meet the conditions specified `kwargs`, starting from the root by default.

Parameters:
first_task (`Task`): search beginning from this task
first_task: Search beginning from this task.

Notes:
Other keyword args are passed directly into `TaskIterator`
Other keyword args are passed directly into `TaskIterator`.

Returns:
`TaskIterator`: an iterator over the matching tasks
An iterator over the matching tasks.
"""
return TaskIterator(first_task or self.task_tree, **kwargs)

def get_task_from_id(self, task_id):
def get_task_from_id(self, task_id: str) -> Task:
"""Returns the task with the given id.

Args:
task_id: the id of the task to run
task_id: The id of the task to run.

Returns:
`Task`: the task
The task.

Raises:
`TaskNotFoundException`: if the task does not exist
@@ -153,24 +160,24 @@ def get_task_from_id(self, task_id):
raise TaskNotFoundException(f'A task with id {task_id} was not found', task_spec=self.spec)
return self.tasks.get(task_id)

def run_task_from_id(self, task_id):
def run_task_from_id(self, task_id: str) -> Optional[bool]:
"""Runs the task with the given id.

Args:
task_id: the id of the task to run
task_id: The id of the task to run.
"""
task = self.get_task_from_id(task_id)
return task.run()

def run_next(self, use_last_task=True, halt_on_manual=True):
def run_next(self, use_last_task: bool = True, halt_on_manual: bool = True) -> bool:
"""Runs the next task, starting from the branch containing the last completed task by default.

Parameters:
use_last_task (bool): start with the currently running branch
halt_on_manual (bool): do not run tasks with `TaskSpec`s that have the `manual` attribute set
use_last_task: Start with the currently running branch.
halt_on_manual: Do not run tasks with `TaskSpec`s that have the `manual` attribute set.

Returns:
bool: True when a task runs sucessfully
True when a task runs successfully.
"""
first_task = self.last_task if use_last_task and self.last_task is not None else self.task_tree
task_filter = TaskFilter(
@@ -188,29 +195,29 @@ def run_next(self, use_last_task=True, halt_on_manual=True):
else:
return task.run()

def run_all(self, use_last_task=True, halt_on_manual=True):
def run_all(self, use_last_task: bool = True, halt_on_manual: bool = True) -> None:
"""Runs all possible tasks, starting from the current branch by default.

Parameters:
use_last_task (bool): start with the currently running branch
halt_on_manual (bool): do not run tasks with `TaskSpec`s that have the `manual` attribute set
use_last_task: Start with the currently running branch.
halt_on_manual: Do not run tasks with `TaskSpec`s that have the `manual` attribute set.
"""
while self.run_next(use_last_task, halt_on_manual):
pass

def update_waiting_tasks(self):
def update_waiting_tasks(self) -> None:
"""Update all tasks in the WAITING state"""
for task in TaskIterator(self.task_tree, state=TaskState.WAITING):
task.task_spec._update(task)

def cancel(self, success=False):
def cancel(self, success: bool = False) -> list[Task]:
"""Cancels all open tasks in the workflow.

Args:
success (bool): the state of the workflow
success: The state of the workflow.

Returns:
list(`Task`): the cancelled tasks
The cancelled tasks.
"""
self.success = success
self.completed = True
@@ -222,39 +229,39 @@ def cancel(self, success=False):
task.cancel()
return cancelled

def set_data(self, **kwargs):
def set_data(self, **kwargs) -> None:
"""Defines the given attribute/value pairs."""
self.data.update(kwargs)

def get_data(self, name, default=None):
def get_data(self, name: str, default: Optional[Any] = None) -> Optional[Any]:
"""Returns the value of the data field with the given name, or the given
default value if the data field does not exist.

Args:
name (str): the dictionary key to return
default (obj): a default value to return if the key does not exist
name: The dictionary key to return.
default: A default value to return if the key does not exist.

Returns:
the value of the key, or the default
The value of the key, or the default.
"""
return self.data.get(name, default)

def reset_from_task_id(self, task_id, data=None):
"""Removed all descendendants of this task and set this task to be runnable.
def reset_from_task_id(self, task_id: str, data: dict = None) -> list[Task]:
"""Removed all descendants of this task and set this task to be runnable.

Args:
task_id: the id of the task to reset to
data (dict): optionally replace the data (if None, data will be copied from the parent task)
task_id: The id of the task to reset to.
data: Optionally replace the data (if None, data will be copied from the parent task).

Returns:
list(`Task`): tasks removed from the tree
Returns:
Tasks removed from the tree.
"""
task = self.get_task_from_id(task_id)
self.last_task = task.parent
return task.reset_branch(data)

def collect_log_extras(self, dct=None):
"""Return logging details for this workflow"""
def collect_log_extras(self, dct: Optional[dict] = None) -> dict:
"""Return logging details for this workflow."""
extra = dct or {}
extra.update({
'workflow_spec': self.spec.name,
@@ -265,13 +272,13 @@ def collect_log_extras(self, dct=None):
extra.update({'tasks': [t.id for t in Workflow.get_tasks(self)]})
return extra

def _predict(self, mask=TaskState.NOT_FINISHED_MASK):
"""Predict tasks with the provided mask"""
def _predict(self, mask: TaskState = TaskState.NOT_FINISHED_MASK) -> None:
"""Predict tasks with the provided mask."""
for task in Workflow.get_tasks(self, state=TaskState.NOT_FINISHED_MASK):
task.task_spec._predict(task, mask=mask)

def _task_completed_notify(self, task):
"""Called whenever a task completes"""
def _task_completed_notify(self, task: Task) -> None:
"""Called whenever a task completes."""
self.last_task = task
if task.task_spec.name == 'End':
self._mark_complete(task)
@@ -280,29 +287,29 @@ def _task_completed_notify(self, task):
else:
self.update_waiting_tasks()

def _remove_task(self, task_id):
def _remove_task(self, task_id: str) -> None:
task = self.tasks[task_id]
for child in task.children:
self._remove_task(child.id)
task.parent._children.remove(task.id)
self.tasks.pop(task_id)

def _mark_complete(self, task):
def _mark_complete(self, task: Task) -> None:
logger.info('Workflow completed', extra=self.collect_log_extras())
self.data.update(task.data)
self.completed = True

def _get_mutex(self, name):
"""Get or create a mutex"""
def _get_mutex(self, name: str) -> mutex:
"""Get or create a mutex."""
if name not in self.locks:
self.locks[name] = mutex()
return self.locks[name]

def get_task_mapping(self):
def get_task_mapping(self) -> dict:
"""I don't know that this does.

Seriously, this returns a mapping of thread ids to tasks in that thread. It can be used to identify
tasks by branch and use this information for decision making (despite the flawed implementation
tasks by branch and use this information for decision-making (despite the flawed implementation
mechanism; IMO, this should be maintained by the workflow rather than a class attribute).
"""
task_mapping = {}
@@ -314,19 +321,19 @@ def get_task_mapping(self):
task_mapping[task.thread_id] = thread_task_mapping
return task_mapping

def get_dump(self):
def get_dump(self) -> str:
"""Returns a string representation of the task tree.

Returns:
str: a tree view of the current workflow state
A tree view of the current workflow state.
"""
return self.task_tree.get_dump()

def dump(self):
"""Print a dump of the current task tree"""
def dump(self) -> None:
"""Print a dump of the current task tree."""
print(self.task_tree.dump())

def serialize(self, serializer, **kwargs):
def serialize(self, serializer: Serializer, **kwargs) -> Any:
"""
Serializes a Workflow instance using the provided serializer.

@@ -340,7 +347,7 @@ def serialize(self, serializer, **kwargs):
return serializer.serialize_workflow(self, **kwargs)

@classmethod
def deserialize(cls, serializer, s_state, **kwargs):
def deserialize(cls, serializer: Serializer, s_state: Any, **kwargs) -> "Workflow":
"""
Deserializes a Workflow instance using the provided serializer.

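For a rough sense of how the annotated API reads at a call site, a minimal usage sketch (the spec name, the data key, and the bare WorkflowSpec with no task specs attached are placeholder assumptions, not part of this PR):

from typing import Optional

from SpiffWorkflow.specs import WorkflowSpec
from SpiffWorkflow.task import Task
from SpiffWorkflow.util.task import TaskState
from SpiffWorkflow.workflow import Workflow

# Placeholder spec; a real one would have task specs wired into it.
spec = WorkflowSpec('typing-demo')
workflow = Workflow(spec)

# run_all() is annotated -> None: it keeps calling run_next() until nothing is runnable.
workflow.run_all()

# get_next_task() is annotated Optional[Task], so the None case has to be handled.
task: Optional[Task] = workflow.get_next_task(state=TaskState.READY)
if task is not None:
    task.run()

# is_completed() -> bool and get_data() -> Optional[Any] now document themselves.
if workflow.is_completed():
    print(workflow.get_data('result', default=None))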