diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 96e830b6..a36f5ccc 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -29,8 +29,10 @@ jobs:
         run: python3 -m pip install .
         env:
           CONDUCTOR_PYTHON_VERSION: v0.0.0+test.unit
+      - name: Install pytest
+        run: python3 -m pip install pytest==7.1.2
       - name: Run Unit Tests
-        run: python3 -m unittest discover --verbose --start-directory=./tests/unit
+        run: python3 -m pytest tests/unit/
   integration-tests:
     runs-on: ubuntu-latest
     steps:
diff --git a/src/conductor/client/automator/task_handler.py b/src/conductor/client/automator/task_handler.py
index a12b2e5e..295ae0bd 100644
--- a/src/conductor/client/automator/task_handler.py
+++ b/src/conductor/client/automator/task_handler.py
@@ -5,7 +5,8 @@
 from conductor.client.worker.worker_interface import WorkerInterface
 from multiprocessing import Process
 from typing import List
 import logging
+import threading
 
 logger = logging.getLogger(
     Configuration.get_logging_formatted_name(
@@ -16,19 +18,22 @@ class TaskHandler:
     def __init__(
-        self,
-        workers: List[WorkerInterface],
-        configuration: Configuration = None,
-        metrics_settings: MetricsSettings = None
-    ):
+            self,
+            workers: List[WorkerInterface],
+            configuration: Configuration = None,
+            metrics_settings: MetricsSettings = None
+    ) -> None:
         if not isinstance(workers, list):
             workers = [workers]
-        self.__create_task_runner_processes(
-            workers, configuration, metrics_settings
-        )
-        self.__create_metrics_provider_process(
-            metrics_settings
-        )
+        self.configuration = configuration
+        self.metrics_settings = metrics_settings
+
+        self._task_runner = {}
+        self._task_runner_thread = {}
+        self._task_runner_mutex = threading.Lock()
+
+        self.start_worker(*workers)
+        self.__create_metrics_provider_process()
         logger.info('Created all processes')
 
     def __enter__(self):
@@ -38,53 +43,51 @@ def __exit__(self, exc_type, exc_value, traceback):
         self.stop_processes()
 
     def stop_processes(self) -> None:
-        self.__stop_task_runner_processes()
         self.__stop_metrics_provider_process()
 
     def start_processes(self) -> None:
-        self.__start_task_runner_processes()
         self.__start_metrics_provider_process()
         logger.info('Started all processes')
 
     def join_processes(self) -> None:
-        self.__join_task_runner_processes()
+        self.__join_workers()
         self.__join_metrics_provider_process()
         logger.info('Joined all processes')
 
-    def __create_metrics_provider_process(self, metrics_settings: MetricsSettings) -> None:
-        if metrics_settings == None:
+    def __create_metrics_provider_process(self) -> None:
+        if self.metrics_settings is None:
             self.metrics_provider_process = None
             return
         self.metrics_provider_process = Process(
             target=MetricsCollector.provide_metrics,
-            args=(metrics_settings,)
+            args=(self.metrics_settings,)
         )
         logger.info('Created MetricsProvider process')
 
-    def __create_task_runner_processes(
-        self,
-        workers: List[WorkerInterface],
-        configuration: Configuration,
-        metrics_settings: MetricsSettings
-    ) -> None:
-        self.task_runner_processes = []
+    def start_worker(self, *workers: WorkerInterface) -> None:
         for worker in workers:
-            self.__create_task_runner_process(
-                worker, configuration, metrics_settings
-            )
-        logger.info('Created TaskRunner processes')
+            self.__start_worker(worker)
+        logger.info('Started all workers')
 
-    def __create_task_runner_process(
-        self,
-        worker: WorkerInterface,
-        configuration: Configuration,
-        metrics_settings: MetricsSettings
-    ) -> None:
-        task_runner = TaskRunner(worker, configuration, metrics_settings)
-        process = Process(
-            target=task_runner.run
-        )
-        self.task_runner_processes.append(process)
+    def __start_worker(self, worker: WorkerInterface):
+        task_name = worker.get_task_definition_name()
+        with self._task_runner_mutex:
+            if task_name in self._task_runner:
+                raise Exception(f'worker already started for {task_name}')
+            task_runner = TaskRunner(
+                configuration=self.configuration,
+                task_definition_name=worker.task_definition_name,
+                batch_size=worker.batch_size,
+                polling_interval=worker.polling_interval,
+                worker_execution_function=worker.execute,
+                worker_id=worker.get_identity(),
+                domain=worker.get_domain(),
+                metrics_settings=self.metrics_settings
+            )
+            self._task_runner[task_name] = task_runner
+            task_runner_thread = threading.Thread(target=task_runner.run)
+            self._task_runner_thread[task_name] = task_runner_thread
+            task_runner_thread.start()
 
     def __start_metrics_provider_process(self):
         if self.metrics_provider_process == None:
@@ -92,29 +95,21 @@ def __start_metrics_provider_process(self):
         self.metrics_provider_process.start()
         logger.info('Started MetricsProvider process')
 
-    def __start_task_runner_processes(self):
-        for task_runner_process in self.task_runner_processes:
-            task_runner_process.start()
-        logger.info('Started TaskRunner processes')
-
     def __join_metrics_provider_process(self):
         if self.metrics_provider_process == None:
             return
         self.metrics_provider_process.join()
         logger.info('Joined MetricsProvider processes')
 
-    def __join_task_runner_processes(self):
-        for task_runner_process in self.task_runner_processes:
-            task_runner_process.join()
-        logger.info('Joined TaskRunner processes')
+    def __join_workers(self):
+        with self._task_runner_mutex:
+            for thread in self._task_runner_thread.values():
+                thread.join()
+        logger.info('Joined all workers')
 
     def __stop_metrics_provider_process(self):
         self.__stop_process(self.metrics_provider_process)
 
-    def __stop_task_runner_processes(self):
-        for task_runner_process in self.task_runner_processes:
-            self.__stop_process(task_runner_process)
-
     def __stop_process(self, process: Process):
         if process == None:
             return
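For orientation, here is a minimal sketch of how the reworked TaskHandler is meant to be driven. The names come from this diff; the task name and worker function are placeholders, and the server address is assumed to be the Configuration default:

    from conductor.client.automator.task_handler import TaskHandler
    from conductor.client.configuration.configuration import Configuration
    from conductor.client.worker.worker import Worker

    def greet(input_data):
        # an unannotated function receives task.input_data, and its return
        # value becomes the output_data of a COMPLETED TaskResult
        return {'greeting': 'hello'}

    # worker threads now start inside TaskHandler.__init__ via start_worker();
    # start_processes()/stop_processes() only manage the metrics provider process
    with TaskHandler(
        workers=[Worker(task_definition_name='greet', worker_execution_function=greet)],
        configuration=Configuration(),
    ) as task_handler:
        task_handler.start_processes()
        task_handler.join_processes()  # blocks while the polling threads run
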
diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py
index 1cbaa50f..49cbf709 100644
--- a/src/conductor/client/automator/task_runner.py
+++ b/src/conductor/client/automator/task_runner.py
@@ -4,195 +4,396 @@
 from conductor.client.http.api.task_resource_api import TaskResourceApi
 from conductor.client.http.models.task import Task
 from conductor.client.http.models.task_result import TaskResult
+from conductor.client.http.models.task_result_status import TaskResultStatus
 from conductor.client.telemetry.metrics_collector import MetricsCollector
-from conductor.client.worker.worker_interface import WorkerInterface
+from copy import deepcopy
+from typing import Callable, List
 import logging
+import multiprocessing
 import sys
+import threading
 import time
 import traceback
 
-logger = logging.getLogger(
+_logger = logging.getLogger(
     Configuration.get_logging_formatted_name(
         __name__
     )
 )
 
+_TASK_UPDATE_RETRY_ATTEMPTS_LIMIT = 3
+_BATCH_POLL_ERROR_RETRY_INTERVAL = 0.1  # 100ms
+_BATCH_POLL_NO_AVAILABLE_WORKER_RETRY_INTERVAL = 0.001  # 1ms
 
-class TaskRunner:
-    def __init__(
-        self,
-        worker: WorkerInterface,
-        configuration: Configuration = None,
-        metrics_settings: MetricsSettings = None
-    ):
-        if not isinstance(worker, WorkerInterface):
-            raise Exception('Invalid worker')
-        self.worker = worker
-        if not isinstance(configuration, Configuration):
-            configuration = Configuration()
-        self.configuration = configuration
-        self.metrics_collector = None
-        if metrics_settings is not None:
-            self.metrics_collector = MetricsCollector(
-                metrics_settings
-            )
-
-    def run(self) -> None:
-        if self.configuration != None:
-            self.configuration.apply_logging_config()
-        while True:
-            self.run_once()
-
-    def run_once(self) -> None:
-        task = self.__poll_task()
-        if task != None:
-            task_result = self.__execute_task(task)
-            self.__update_task(task_result)
-        self.__wait_for_polling_interval()
-
-    def __poll_task(self) -> Task:
-        task_definition_name = self.worker.get_task_definition_name()
-        if self.metrics_collector is not None:
-            self.metrics_collector.increment_task_poll(
-                task_definition_name
-            )
-        logger.debug(f'Polling task for: {task_definition_name}')
-        try:
-            start_time = time.time()
-            domain = self.worker.get_domain()
-            if domain != None:
-                task = self.__get_task_resource_api().poll(
-                    tasktype=task_definition_name,
-                    workerid=self.worker.get_identity(),
-                    domain=self.worker.get_domain(),
-                )
-            else:
-                task = self.__get_task_resource_api().poll(
-                    tasktype=task_definition_name,
-                    workerid=self.worker.get_identity(),
-                )
-            finish_time = time.time()
-            time_spent = finish_time - start_time
-            if self.metrics_collector is not None:
-                self.metrics_collector.record_task_poll_time(
-                    task_definition_name, time_spent
-                )
-        except Exception as e:
-            if self.metrics_collector is not None:
-                self.metrics_collector.increment_task_poll_error(
-                    task_definition_name, type(e)
-                )
-            logger.info(
-                f'Failed to poll task for: {task_definition_name}, reason: {traceback.format_exc()}'
-            )
-            return None
-        if task != None:
-            logger.debug(
-                f'Polled task: {task_definition_name}, worker_id: {self.worker.get_identity()}'
-            )
-        return task
 
+def _batch_poll(
+    task_resource_api: TaskResourceApi,
+    task_name: str,
+    batch_size: int,
+    timeout: float,
+    worker_id: str = None,
+    domain: str = None,
+    metrics_collector: MetricsCollector = None,
+) -> List[Task]:
+    if batch_size < 1:
+        return None
+    _logger.debug(
+        f'Polling for the next {batch_size} task(s) with name {task_name}'
+    )
+    kwargs = {
+        'count': batch_size,
+        'timeout': timeout,
+    }
+    if domain is not None:
+        kwargs['domain'] = domain
+    if worker_id is not None:
+        kwargs['workerid'] = worker_id
+    try:
+        start_time = time.time()
+        tasks = task_resource_api.batch_poll(
+            tasktype=task_name,
+            **kwargs,
+        )
+        time_spent = time.time() - start_time
+    except Exception as e:
+        if metrics_collector is not None:
+            metrics_collector.increment_task_poll_error(
+                task_name, type(e)
+            )
+        _logger.info(
+            f'Failed to poll task for: {task_name}, reason: {traceback.format_exc()}'
+        )
+        return None
+    if metrics_collector is not None:
+        metrics_collector.increment_task_poll(
+            task_name
+        )
+        metrics_collector.record_task_poll_time(
+            task_name, time_spent
+        )
+    if tasks is not None:
+        _logger.debug(
+            'Polled {} task(s) of type {} with worker_id {} and domain {}'.format(
+                len(tasks), task_name, worker_id, domain
+            )
+        )
+    return tasks
+
+
+def _worker_process_daemon(
+    task_resource_api: TaskResourceApi,
+    task: Task,
+    worker_execution_function: Callable[[Task], TaskResult],
+    worker_id: str = None,
+    metrics_collector: MetricsCollector = None,
+):
+    # apply_logging_config()
+    task_result = _execute_task(
+        task,
+        worker_execution_function,
+        worker_id,
+        metrics_collector
+    )
+    _update_task(
+        task.task_def_name,
+        task_result,
+        task_resource_api,
+        metrics_collector
+    )
 
-    def __execute_task(self, task: Task) -> TaskResult:
-        if not isinstance(task, Task):
-            return None
-        task_definition_name = self.worker.get_task_definition_name()
-        logger.debug(
-            'Executing task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}'.format(
-                task_id=task.task_id,
-                workflow_instance_id=task.workflow_instance_id,
-                task_definition_name=task_definition_name
-            )
-        )
-        try:
-            start_time = time.time()
-            task_result = self.worker.execute(task)
-            finish_time = time.time()
-            time_spent = finish_time - start_time
-            if self.metrics_collector is not None:
-                self.metrics_collector.record_task_execute_time(
-                    task_definition_name,
-                    time_spent
-                )
-                self.metrics_collector.record_task_result_payload_size(
-                    task_definition_name,
-                    sys.getsizeof(task_result)
-                )
-            logger.debug(
-                'Executed task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}'.format(
-                    task_id=task.task_id,
-                    workflow_instance_id=task.workflow_instance_id,
-                    task_definition_name=task_definition_name
-                )
-            )
-        except Exception as e:
-            if self.metrics_collector is not None:
-                self.metrics_collector.increment_task_execution_error(
-                    task_definition_name, type(e)
-                )
-            task_result = TaskResult(
-                task_id=task.task_id,
-                workflow_instance_id=task.workflow_instance_id,
-                worker_id=self.worker.get_identity()
-            )
-            task_result.status = 'FAILED'
-            task_result.reason_for_incompletion = str(e)
-            logger.info(
-                'Failed to execute task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}, reason: {reason}'.format(
-                    task_id=task.task_id,
-                    workflow_instance_id=task.workflow_instance_id,
-                    task_definition_name=task_definition_name,
-                    reason=traceback.format_exc()
-                )
-            )
-        return task_result
 
+def _execute_task(
+    task: Task,
+    worker_execution_function: Callable[[Task], TaskResult],
+    worker_id: str = None,
+    metrics_collector: MetricsCollector = None,
+) -> TaskResult:
+    task_name = task.task_def_name
+    _logger.debug(
+        'Executing task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}'.format(
+            task_id=task.task_id,
+            workflow_instance_id=task.workflow_instance_id,
+            task_definition_name=task_name
+        )
+    )
+    try:
+        start_time = time.time()
+        task_result = worker_execution_function(task)
+        time_spent = time.time() - start_time
+    except Exception as e:
+        if metrics_collector is not None:
+            metrics_collector.increment_task_execution_error(
+                task_name, type(e)
+            )
+        failed_task_result = TaskResult(
+            task_id=task.task_id,
+            workflow_instance_id=task.workflow_instance_id,
+            worker_id=worker_id
+        )
+        failed_task_result.status = TaskResultStatus.FAILED
+        failed_task_result.reason_for_incompletion = str(e)
+        _logger.info(
+            'Failed to execute task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}, reason: {reason}'.format(
+                task_id=task.task_id,
+                workflow_instance_id=task.workflow_instance_id,
+                task_definition_name=task_name,
+                reason=traceback.format_exc()
+            )
+        )
+        return failed_task_result
+    if metrics_collector is not None:
+        metrics_collector.record_task_execute_time(
+            task_name,
+            time_spent
+        )
+    _logger.debug(
+        'Executed task, id: {}, workflow_instance_id: {}, task_definition_name: {}'.format(
+            task.task_id, task.workflow_instance_id, task_name
+        )
+    )
+    return task_result
 
-    def __update_task(self, task_result: TaskResult):
-        if not isinstance(task_result, TaskResult):
-            return None
-        task_definition_name = self.worker.get_task_definition_name()
-        logger.debug(
-            'Updating task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}'.format(
-                task_id=task_result.task_id,
-                workflow_instance_id=task_result.workflow_instance_id,
-                task_definition_name=task_definition_name
-            )
-        )
-        try:
-            response = self.__get_task_resource_api().update_task(
-                body=task_result
-            )
-        except Exception as e:
-            if self.metrics_collector is not None:
-                self.metrics_collector.increment_task_update_error(
-                    task_definition_name, type(e)
-                )
-            logger.info(
-                'Failed to update task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}, reason: {reason}'.format(
-                    task_id=task_result.task_id,
-                    workflow_instance_id=task_result.workflow_instance_id,
-                    task_definition_name=task_definition_name,
-                    reason=traceback.format_exc()
-                )
-            )
-            return None
-        logger.debug(
-            'Updated task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}, response: {response}'.format(
-                task_id=task_result.task_id,
-                workflow_instance_id=task_result.workflow_instance_id,
-                task_definition_name=task_definition_name,
-                response=response
-            )
-        )
-        return response
 
+def _update_task(
+    task_name: str,
+    task_result: TaskResult,
+    task_resource_api: TaskResourceApi,
+    metrics_collector: MetricsCollector = None,
+) -> None:
+    _logger.debug(
+        'Updating task, id: {task_id}, workflow_instance_id: {workflow_instance_id}, task_definition_name: {task_definition_name}'.format(
+            task_id=task_result.task_id,
+            workflow_instance_id=task_result.workflow_instance_id,
+            task_definition_name=task_name
+        )
+    )
+    for attempt in range(_TASK_UPDATE_RETRY_ATTEMPTS_LIMIT + 1):
+        if attempt > 0:
+            # sleeps for [10s, 20s, 30s] between failed attempts
+            time.sleep(attempt * 10)
+        try:
+            start_time = time.time()
+            response = task_resource_api.update_task(
+                body=task_result
+            )
+            time_spent = time.time() - start_time
+            break
+        except Exception as e:
+            if metrics_collector is not None:
+                metrics_collector.increment_task_update_error(
+                    task_name, type(e)
+                )
+            _logger.debug(
+                'Failed to update task, id: {}, workflow_instance_id: {}, task_definition_name: {}, reason: {}'.format(
+                    task_result.task_id,
+                    task_result.workflow_instance_id,
+                    task_name,
+                    traceback.format_exc()
+                )
+            )
+    else:
+        # all retry attempts failed, nothing to record or return
+        return None
+    if metrics_collector is not None:
+        metrics_collector.record_task_result_payload_size(
+            task_name,
+            sys.getsizeof(task_result)
+        )
+        metrics_collector.record_task_update_time(
+            task_name,
+            time_spent
+        )
+    _logger.debug(
+        'Updated task, id: {}, workflow_instance_id: {}, task_definition_name: {}, response: {}'.format(
+            task_result.task_id,
+            task_result.workflow_instance_id,
+            task_name,
+            response
+        )
+    )
+    return response
 
 
+class TaskRunner:
+    def __init__(
+        self,
+        configuration: Configuration,
+        task_definition_name: str,
+        batch_size: int,
+        polling_interval: float,
+        worker_execution_function: Callable[[Task], TaskResult],
+        worker_id: str = None,
+        domain: str = None,
+        metrics_settings: MetricsSettings = None
+    ):
+        self.configuration = configuration
+        self._task_resource_api = TaskResourceApi(
+            ApiClient(configuration)
+        )
+
+        self._task_name = task_definition_name
+
+        self._batch_size_mutex = threading.Lock()
+        self.batch_size = batch_size
+
+        self._poll_interval_mutex = threading.Lock()
+        self.poll_interval = polling_interval
+
+        self._worker_execution_function_mutex = threading.Lock()
+        self.worker_execution_function = worker_execution_function
+
+        self._worker_id_mutex = threading.Lock()
+        self.worker_id = worker_id
+
+        self._domain_mutex = threading.Lock()
+        self.domain = domain
+
+        self._running_workers_mutex = threading.Lock()
+        self._running_workers = {}  # {key=pid, value=process}
+
+        self._paused_worker_mutex = threading.Lock()
+        self._paused_worker = False
+
+        self.metrics_collector = None
+        if metrics_settings is not None:
+            self.metrics_collector = MetricsCollector(
+                metrics_settings
+            )
+
+    def __start_worker(self, task: Task) -> None:
+        worker_process = multiprocessing.Process(
+            target=_worker_process_daemon,
+            args=(
+                self._task_resource_api,
+                task,
+                self.worker_execution_function,
+                self.worker_id,
+                self.metrics_collector
+            )
+        )
+        # start first: the pid is only assigned once the process is running
+        worker_process.start()
+        with self._running_workers_mutex:
+            self._running_workers[worker_process.pid] = worker_process
+        _logger.debug(
+            'Started worker for task {} with task_id {} - pid: {}'.format(
+                self._task_name,
+                task.task_id,
+                worker_process.pid
+            )
+        )
+        worker_monitor_thread = threading.Thread(
+            target=self.__monitor_running_worker,
+            args=(
+                worker_process,
+                task.task_id,
+            )
+        )
+        worker_monitor_thread.start()
 
+    def __monitor_running_worker(self, worker_process: multiprocessing.Process, task_id: str) -> None:
+        worker_process.join()
+        with self._running_workers_mutex:
+            del self._running_workers[worker_process.pid]
+        _logger.debug(
+            'Finished worker for task {} with task_id {} - pid: {}'.format(
+                self._task_name,
+                task_id,
+                worker_process.pid
+            )
+        )
 
-    def __wait_for_polling_interval(self) -> None:
-        polling_interval = self.worker.get_polling_interval_in_seconds()
-        logger.debug(f'Sleep for {polling_interval} seconds')
-        time.sleep(polling_interval)
-
-    def __get_task_resource_api(self) -> TaskResourceApi:
-        return TaskResourceApi(
-            ApiClient(
-                configuration=self.configuration
-            )
-        )
+    def run(self) -> None:
+        while True:
+            try:
+                self.run_once()
+            except Exception as e:
+                if self.metrics_collector is not None:
+                    self.metrics_collector.increment_uncaught_exception()
+                _logger.debug(
+                    f'Exception raised while running worker for task: {self._task_name}. Reason: {str(e)}'
+                )
 
+    def run_once(self) -> None:
+        if self.is_worker_paused():
+            time.sleep(_BATCH_POLL_ERROR_RETRY_INTERVAL)
+            return
+        available_workers = self.batch_size - self.running_workers
+        if available_workers < 1:
+            time.sleep(_BATCH_POLL_NO_AVAILABLE_WORKER_RETRY_INTERVAL)
+            return
+        tasks = _batch_poll(
+            task_resource_api=self._task_resource_api,
+            task_name=self._task_name,
+            batch_size=available_workers,
+            timeout=self.poll_interval,
+            worker_id=self.worker_id,
+            domain=self.domain,
+            metrics_collector=self.metrics_collector,
+        )
+        if tasks is not None:
+            for task in tasks:
+                self.__start_worker(task)
+        time.sleep(self.poll_interval)
+
+    @property
+    def batch_size(self) -> int:
+        with self._batch_size_mutex:
+            return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, batch_size: int) -> None:
+        with self._batch_size_mutex:
+            self._batch_size = deepcopy(batch_size)
+
+    @property
+    def poll_interval(self) -> float:
+        with self._poll_interval_mutex:
+            return self._poll_interval
+
+    @poll_interval.setter
+    def poll_interval(self, poll_interval: float) -> None:
+        with self._poll_interval_mutex:
+            self._poll_interval = deepcopy(poll_interval)
+
+    @property
+    def worker_execution_function(self) -> Callable[[Task], TaskResult]:
+        with self._worker_execution_function_mutex:
+            return self._worker_execution_function
+
+    @worker_execution_function.setter
+    def worker_execution_function(self, worker_execution_function: Callable[[Task], TaskResult]) -> None:
+        with self._worker_execution_function_mutex:
+            self._worker_execution_function = deepcopy(
+                worker_execution_function)
+
+    @property
+    def worker_id(self) -> str:
+        with self._worker_id_mutex:
+            return self._worker_id
+
+    @worker_id.setter
+    def worker_id(self, worker_id: str) -> None:
+        with self._worker_id_mutex:
+            self._worker_id = deepcopy(worker_id)
+
+    @property
+    def domain(self) -> str:
+        with self._domain_mutex:
+            return self._domain
+
+    @domain.setter
+    def domain(self, domain: str) -> None:
+        with self._domain_mutex:
+            self._domain = deepcopy(domain)
+
+    @property
+    def running_workers(self) -> int:
+        with self._running_workers_mutex:
+            return len(self._running_workers)
+
+    def resume_worker(self) -> None:
+        with self._paused_worker_mutex:
+            self._paused_worker = False
+
+    def pause_worker(self) -> None:
+        with self._paused_worker_mutex:
+            self._paused_worker = True
+
+    def is_worker_paused(self) -> bool:
+        with self._paused_worker_mutex:
+            return self._paused_worker
diff --git a/src/conductor/client/configuration/settings/metrics_settings.py b/src/conductor/client/configuration/settings/metrics_settings.py
index 64377fe4..57badca0 100644
--- a/src/conductor/client/configuration/settings/metrics_settings.py
+++ b/src/conductor/client/configuration/settings/metrics_settings.py
@@ -32,5 +32,6 @@ def __set_dir(self, dir: str) -> None:
             os.mkdir(dir)
         except Exception as e:
             logger.warning(
-                'Failed to create metrics temporary folder, reason: ', e)
+                f'Failed to create metrics temporary folder, reason: {str(e)}'
+            )
         self.directory = dir
diff --git a/src/conductor/client/telemetry/metrics_collector.py b/src/conductor/client/telemetry/metrics_collector.py
index 74348841..42130faa 100644
--- a/src/conductor/client/telemetry/metrics_collector.py
+++ b/src/conductor/client/telemetry/metrics_collector.py
@@ -194,6 +194,16 @@ def record_task_execute_time(self, task_type: str, time_spent: float) -> None:
             value=time_spent
         )
 
+    def record_task_update_time(self, task_type: str, time_spent: float) -> None:
+        self.__record_gauge(
+            name=MetricName.TASK_UPDATE_TIME,
+            documentation=MetricDocumentation.TASK_UPDATE_TIME,
+            labels={
+                MetricLabel.TASK_TYPE: task_type
+            },
+            value=time_spent
+        )
+
     def __increment_counter(
         self,
         name: MetricName,
diff --git a/src/conductor/client/telemetry/model/metric_documentation.py b/src/conductor/client/telemetry/model/metric_documentation.py
index 9f63f5d5..937dcee0 100644
--- a/src/conductor/client/telemetry/model/metric_documentation.py
+++ b/src/conductor/client/telemetry/model/metric_documentation.py
@@ -14,6 +14,7 @@ class MetricDocumentation(str, Enum):
     TASK_POLL_TIME = "Time to poll for a batch of tasks"
     TASK_RESULT_SIZE = "Records output payload size of a task"
     TASK_UPDATE_ERROR = "Task status cannot be updated back to server"
+    TASK_UPDATE_TIME = "Time to update a task"
     THREAD_UNCAUGHT_EXCEPTION = "thread_uncaught_exceptions"
-    WORKFLOW_START_ERROR = "Counter for workflow start errors"
     WORKFLOW_INPUT_SIZE = "Records input payload size of a workflow"
+    WORKFLOW_START_ERROR = "Counter for workflow start errors"
diff --git a/src/conductor/client/telemetry/model/metric_name.py b/src/conductor/client/telemetry/model/metric_name.py
index 1301434b..a440fd74 100644
--- a/src/conductor/client/telemetry/model/metric_name.py
+++ b/src/conductor/client/telemetry/model/metric_name.py
@@ -14,6 +14,7 @@ class MetricName(str, Enum):
     TASK_POLL_TIME = "task_poll_time"
     TASK_RESULT_SIZE = "task_result_size"
     TASK_UPDATE_ERROR = "task_update_error"
+    TASK_UPDATE_TIME = "task_update_time"
     THREAD_UNCAUGHT_EXCEPTION = "thread_uncaught_exceptions"
     WORKFLOW_INPUT_SIZE = "workflow_input_size"
     WORKFLOW_START_ERROR = "workflow_start_error"
diff --git a/src/conductor/client/worker/worker.py b/src/conductor/client/worker/worker.py
index 99dfa27b..21f1eeaf 100644
--- a/src/conductor/client/worker/worker.py
+++ b/src/conductor/client/worker/worker.py
@@ -1,21 +1,18 @@
-from copy import deepcopy
 from conductor.client.http.models.task import Task
 from conductor.client.http.models.task_result import TaskResult
 from conductor.client.http.models.task_result_status import TaskResultStatus
 from conductor.client.worker.worker_interface import WorkerInterface
+from copy import deepcopy
 from typing import Any, Callable, Union
 from typing_extensions import Self
 import inspect
 
-ExecuteTaskFunction = Callable[
-    [
-        Union[Task, object]
-    ],
-    Union[TaskResult, object]
-]
+WorkerInput = Union[Task, Any]
+WorkerOutput = Union[TaskResult, Any]
+WorkerExecutionFunction = Callable[[WorkerInput], WorkerOutput]
 
 
-def is_callable_input_parameter_a_task(callable: ExecuteTaskFunction, object_type: Any) -> bool:
+def is_callable_input_parameter_of_type(callable: WorkerExecutionFunction, object_type: Any) -> bool:
     parameters = inspect.signature(callable).parameters
     if len(parameters) != 1:
         return False
@@ -23,43 +20,46 @@ def is_callable_input_parameter_a_task(callable: ExecuteTaskFunction, object_typ
     return parameter.annotation == object_type
 
 
-def is_callable_return_value_of_type(callable: ExecuteTaskFunction, object_type: Any) -> bool:
+def is_callable_return_value_of_type(callable: WorkerExecutionFunction, object_type: Any) -> bool:
     return_annotation = inspect.signature(callable).return_annotation
     return return_annotation == object_type
 
 
 class Worker(WorkerInterface):
-    def __init__(self,
-                 task_definition_name: str,
-                 execute_function: ExecuteTaskFunction,
-                 poll_interval: float = None,
-                 domain: str = None,
-                 ) -> Self:
+    def __init__(
+        self,
+        task_definition_name: str,
+        worker_execution_function: WorkerExecutionFunction,
+        poll_interval: float = None,
+        domain: str = None,
+    ) -> None:
         super().__init__(task_definition_name)
         if poll_interval == None:
-            poll_interval = super().get_polling_interval_in_seconds()
-        self.poll_interval = deepcopy(poll_interval)
+            self.poll_interval = super().get_polling_interval_in_seconds()
+        else:
+            self.poll_interval = deepcopy(poll_interval)
         if domain == None:
-            domain = super().get_domain()
-        self.domain = deepcopy(domain)
-        self.execute_function = deepcopy(execute_function)
+            self.domain = super().get_domain()
+        else:
+            self.domain = deepcopy(domain)
+        self.worker_execution_function = deepcopy(worker_execution_function)
 
     def execute(self, task: Task) -> TaskResult:
-        execute_function_input = None
-        if self._is_execute_function_input_parameter_a_task:
-            execute_function_input = task
+        worker_execution_function_input = None
+        if self._is_worker_execution_function_input_parameter_a_task:
+            worker_execution_function_input = task
         else:
-            execute_function_input = task.input_data
-        if self._is_execute_function_return_value_a_task_result:
-            execute_function_output = self.execute_function(
-                execute_function_input)
-            if type(execute_function_output) == TaskResult:
-                execute_function_output.task_id = task.task_id
-                execute_function_output.workflow_instance_id = task.workflow_instance_id
-            return execute_function_output
+            worker_execution_function_input = task.input_data
+        if self._is_worker_execution_function_return_value_a_task_result:
+            worker_execution_function_output = self.worker_execution_function(
+                worker_execution_function_input)
+            if type(worker_execution_function_output) == TaskResult:
+                worker_execution_function_output.task_id = task.task_id
+                worker_execution_function_output.workflow_instance_id = task.workflow_instance_id
+            return worker_execution_function_output
         task_result = self.get_task_result_from_task(task)
         task_result.status = TaskResultStatus.COMPLETED
-        task_result.output_data = self.execute_function(task)
+        task_result.output_data = self.worker_execution_function(
+            worker_execution_function_input)
         return task_result
 
     def get_polling_interval_in_seconds(self) -> float:
@@ -69,17 +69,17 @@ def get_domain(self) -> str:
         return self.domain
 
     @property
-    def execute_function(self) -> ExecuteTaskFunction:
-        return self._execute_function
+    def worker_execution_function(self) -> WorkerExecutionFunction:
+        return self._worker_execution_function
 
-    @execute_function.setter
-    def execute_function(self, execute_function: ExecuteTaskFunction) -> None:
-        self._execute_function = execute_function
-        self._is_execute_function_input_parameter_a_task = is_callable_input_parameter_a_task(
-            callable=execute_function,
+    @worker_execution_function.setter
+    def worker_execution_function(self, worker_execution_function: WorkerExecutionFunction) -> None:
+        self._worker_execution_function = worker_execution_function
+        self._is_worker_execution_function_input_parameter_a_task = is_callable_input_parameter_of_type(
+            callable=worker_execution_function,
+            object_type=Task,
+        )
-        self._is_execute_function_return_value_a_task_result = is_callable_return_value_of_type(
-            callable=execute_function,
+        self._is_worker_execution_function_return_value_a_task_result = is_callable_return_value_of_type(
+            callable=worker_execution_function,
             object_type=TaskResult,
         )
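The signature-based dispatch above is easy to miss, so here is a sketch of the two supported worker shapes (hypothetical functions, not part of this diff): Worker.execute inspects the annotations of the wrapped callable to decide what to pass in and how to wrap what comes out.

    from conductor.client.http.models.task import Task
    from conductor.client.http.models.task_result import TaskResult
    from conductor.client.http.models.task_result_status import TaskResultStatus

    # annotated Task -> TaskResult: called with the raw Task, and the returned
    # TaskResult is used as-is (task_id/workflow_instance_id are filled in)
    def full_control(task: Task) -> TaskResult:
        task_result = TaskResult(
            task_id=task.task_id,
            workflow_instance_id=task.workflow_instance_id,
        )
        task_result.status = TaskResultStatus.COMPLETED
        task_result.output_data = {'echo': task.input_data}
        return task_result

    # no annotations: called with task.input_data, and the return value becomes
    # the output_data of an automatically built COMPLETED TaskResult
    def simple(input_data):
        return {'sum': input_data['a'] + input_data['b']}
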
diff --git a/src/conductor/client/worker/worker_interface.py b/src/conductor/client/worker/worker_interface.py
index 3a58400f..4db75829 100644
--- a/src/conductor/client/worker/worker_interface.py
+++ b/src/conductor/client/worker/worker_interface.py
@@ -1,12 +1,20 @@
 from conductor.client.http.models.task import Task
 from conductor.client.http.models.task_result import TaskResult
+from copy import deepcopy
 import abc
 import socket
 
 
 class WorkerInterface(abc.ABC):
-    def __init__(self, task_definition_name: str):
+    def __init__(
+        self,
+        task_definition_name: str,
+        batch_size: int = None,
+        polling_interval: float = None,
+    ) -> None:
         self.task_definition_name = task_definition_name
+        self.batch_size = batch_size
+        self.polling_interval = polling_interval
 
     @abc.abstractmethod
     def execute(self, task: Task) -> TaskResult:
@@ -34,7 +43,7 @@ def get_polling_interval_in_seconds(self) -> float:
         :return: float
                  Default: 100ms
         """
-        return 0.1
+        return self.polling_interval
 
     def get_task_definition_name(self) -> str:
         """
@@ -64,3 +73,41 @@ def get_domain(self) -> str:
         :return: str
         """
         return None
+
+    @property
+    def batch_size(self) -> int:
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, batch_size: int = None) -> None:
+        if batch_size is None:
+            batch_size = 1
+        if not isinstance(batch_size, int):
+            raise Exception('Batch size must be of integer type')
+        if batch_size < 1:
+            raise Exception('Batch size must have a positive value')
+        self._batch_size = batch_size
+
+    @property
+    def task_definition_name(self) -> str:
+        return self._task_definition_name
+
+    @task_definition_name.setter
+    def task_definition_name(self, task_definition_name: str) -> None:
+        if not isinstance(task_definition_name, str):
+            raise Exception('Task definition name must be of string type')
+        if task_definition_name is None or task_definition_name == '':
+            raise Exception('Task definition name must not be empty')
+        self._task_definition_name = deepcopy(task_definition_name)
+
+    @property
+    def polling_interval(self) -> float:
+        return self._polling_interval
+
+    @polling_interval.setter
+    def polling_interval(self, polling_interval: float = None) -> None:
+        if polling_interval is None:
+            polling_interval = 0.1
+        if not isinstance(polling_interval, (int, float)):
+            raise Exception('Polling interval must be a number')
+        self._polling_interval = deepcopy(polling_interval)
diff --git a/src/conductor/client/workflow/conductor_workflow.py b/src/conductor/client/workflow/conductor_workflow.py
index 11e4c8df..6ca0ba11 100644
--- a/src/conductor/client/workflow/conductor_workflow.py
+++ b/src/conductor/client/workflow/conductor_workflow.py
@@ -150,7 +150,7 @@ def input_parameters(self, input_parameters: List[str]) -> Self:
     # Register the workflow definition with the server. If overwrite is set, the definition on the server will be
     # overwritten. When not set, the call fails if there is any change in the workflow definition between the server
     # and what is being registered.
-    def register(self, overwrite: bool):
+    def register(self, overwrite: bool = None):
         return self._executor.register_workflow(
             overwrite=overwrite,
             workflow=self.to_workflow_def(),
diff --git a/tests/integration/workflow/test_workflow_execution.py b/tests/integration/workflow/test_workflow_execution.py
index d2fd0c2f..99bda2cb 100644
--- a/tests/integration/workflow/test_workflow_execution.py
+++ b/tests/integration/workflow/test_workflow_execution.py
@@ -3,8 +3,8 @@
 from conductor.client.http.models import RerunWorkflowRequest
 from conductor.client.http.models import StartWorkflowRequest
 from conductor.client.http.models import TaskDef
-from conductor.client.worker.worker import ExecuteTaskFunction
 from conductor.client.worker.worker import Worker
+from conductor.client.worker.worker import WorkerExecutionFunction
 from conductor.client.workflow.conductor_workflow import ConductorWorkflow
 from conductor.client.workflow.executor.workflow_executor import WorkflowExecutor
 from conductor.client.workflow.task.simple_task import SimpleTask
@@ -23,21 +23,8 @@
 
 
 def run_workflow_execution_tests(configuration: Configuration, workflow_executor: WorkflowExecutor):
-    task_handler = TaskHandler(
-        workers=[
-            ClassWorker(TASK_NAME),
-            ClassWorkerWithDomain(TASK_NAME),
-            generate_worker(worker_with_generic_input_and_generic_output),
-            generate_worker(worker_with_generic_input_and_task_result_output),
-            generate_worker(worker_with_task_input_and_generic_output),
-            generate_worker(worker_with_task_input_and_task_result_output),
-        ],
-        configuration=configuration
-    )
-    task_handler.start_processes()
-    test_get_workflow_by_correlation_ids(workflow_executor)
-    test_workflow_registration(workflow_executor)
     test_workflow_execution(
+        configuration=configuration,
         workflow_quantity=10,
         workflow_name=WORKFLOW_NAME,
         workflow_executor=workflow_executor,
@@ -47,7 +34,6 @@
         workflow_executor,
         workflow_quantity=10,
     )
-    task_handler.stop_processes()
 
 
 def generate_tasks_defs():
@@ -65,35 +51,29 @@
     return [python_simple_task_from_code]
 
 
-def test_get_workflow_by_correlation_ids(workflow_executor: WorkflowExecutor):
-    ids = workflow_executor.get_by_correlation_ids(
-        workflow_name=WORKFLOW_NAME,
-        correlation_ids=[
-            '2', '5', '33', '4', '32', '7', '34', '1', '3', '6', '1440'
-        ]
-    )
-    assert ids != None
-
-
 def test_workflow_methods(
     workflow_executor: WorkflowExecutor,
     workflow_quantity: int,
 ) -> None:
-    task = SimpleTask(
-        'python_integration_test_abc1asjdkajskdjsad',
-        'python_integration_test_abc1asjdkajskdjsad'
-    )
-    workflow_executor.metadata_client.register_task_def(
-        [task.to_workflow_task()]
+    _ = workflow_executor.get_by_correlation_ids(
+        workflow_name=WORKFLOW_NAME,
+        correlation_ids=[
+            '2', '5', '33', '4', '32', '7', '34', '1', '3', '6', '1440'
+        ]
     )
-    workflow_name = 'python_integration_test_abc1asjdk'
     workflow = ConductorWorkflow(
         executor=workflow_executor,
-        name=workflow_name,
+        name='python_integration_test_abc1asjdk',
         description='Python workflow example from code',
         version=1234,
     ).add(
-        task
+        SimpleTask(
+            'python_integration_test_abc1asjdkajskdjsad',
+            'python_integration_test_abc1asjdkajskdjsad'
+        )
+    )
+    workflow_executor.metadata_client.register_task_def(
+        workflow.to_workflow_def().tasks
     )
     workflow_executor.register_workflow(
         workflow.to_workflow_def(),
@@ -101,7 +81,9 @@
     )
     start_workflow_requests = [''] * workflow_quantity
     for i in range(workflow_quantity):
-        start_workflow_requests[i] = StartWorkflowRequest(name=workflow_name)
+        start_workflow_requests[i] = StartWorkflowRequest(
+            name=workflow.name
+        )
     workflow_ids = workflow_executor.start_workflows(
         *start_workflow_requests
     )
@@ -125,36 +107,38 @@
     )
 
 
-def test_workflow_registration(workflow_executor: WorkflowExecutor):
-    workflow = generate_workflow(workflow_executor)
-    try:
-        workflow_executor.metadata_client.unregister_workflow_def_with_http_info(
-            workflow.name, workflow.version
-        )
-    except:
-        pass
-    assert workflow.register(overwrite=True) == None
-    assert workflow_executor.register_workflow(
-        workflow.to_workflow_def(), overwrite=True
-    ) == None
-
-
 def test_workflow_execution(
+    configuration: Configuration,
     workflow_quantity: int,
     workflow_name: str,
     workflow_executor: WorkflowExecutor,
     workflow_completion_timeout: float,
 ) -> None:
+    task_handler = TaskHandler(
+        workers=[
+            ClassWorker(TASK_NAME),
+            # ClassWorkerWithDomain(TASK_NAME),
+            # _generate_worker(worker_with_generic_input_and_generic_output),
+            # _generate_worker(worker_with_generic_input_and_task_result_output),
+            # _generate_worker(worker_with_task_input_and_generic_output),
+            # _generate_worker(worker_with_task_input_and_task_result_output),
+        ],
+        configuration=configuration
+    )
+    task_handler.start_processes()
+    workflow = _generate_workflow(workflow_executor)
+    _register_workflow(workflow, workflow_executor)
     start_workflow_requests = [''] * workflow_quantity
     for i in range(workflow_quantity):
         start_workflow_requests[i] = StartWorkflowRequest(name=workflow_name)
     workflow_ids = workflow_executor.start_workflows(*start_workflow_requests)
     sleep(workflow_completion_timeout)
     for workflow_id in workflow_ids:
-        validate_workflow_status(workflow_id, workflow_executor)
+        _validate_workflow_status(workflow_id, workflow_executor)
+    task_handler.stop_processes()
 
 
-def generate_workflow(workflow_executor: WorkflowExecutor) -> ConductorWorkflow:
+def _generate_workflow(workflow_executor: WorkflowExecutor) -> ConductorWorkflow:
     return ConductorWorkflow(
         executor=workflow_executor,
         name=WORKFLOW_NAME,
@@ -169,7 +153,29 @@
     )
 
 
-def validate_workflow_status(workflow_id: str, workflow_executor: WorkflowExecutor) -> None:
+def _register_workflow(workflow: ConductorWorkflow, workflow_executor: WorkflowExecutor) -> None:
+    for task in workflow.to_workflow_def().tasks:
+        try:
+            workflow_executor.metadata_client.unregister_task_def(
+                task.task_reference_name
+            )
+        except Exception as e:
+            if '404' not in str(e):
+                raise
+    workflow_executor.metadata_client.register_task_def(
+        workflow.to_workflow_def().tasks
+    )
+    workflow_executor.metadata_client.unregister_workflow_def_with_http_info(
+        workflow.name, workflow.version
+    )
+    workflow.register()
+    workflow_executor.register_workflow(
+        workflow=workflow.to_workflow_def(),
+        overwrite=True
+    )
+
+
+def _validate_workflow_status(workflow_id: str, workflow_executor: WorkflowExecutor) -> None:
     workflow = workflow_executor.get_workflow(
         workflow_id=workflow_id,
         include_tasks=False,
@@ -183,7 +189,7 @@
     assert workflow_status.status == 'COMPLETED'
 
 
-def generate_worker(execute_function: ExecuteTaskFunction) -> Worker:
+def _generate_worker(execute_function: WorkerExecutionFunction) -> Worker:
     return Worker(
         task_definition_name=TASK_NAME,
-        execute_function=execute_function,
+        worker_execution_function=execute_function,
diff --git
 a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py
index 41694aa2..78431b41 100644
--- a/tests/unit/test_configuration.py
+++ b/tests/unit/test_configuration.py
@@ -1,48 +1,38 @@
 from conductor.client.configuration.configuration import Configuration
 from conductor.client.http.api_client import ApiClient
 import base64
-import unittest
-
-
-class TestConfiguration(unittest.TestCase):
-    def test_initialization_default(self):
-        configuration = Configuration()
-        self.assertEqual(
-            configuration.host,
-            'http://localhost:8080/api/'
-        )
-
-    def test_initialization_with_base_url(self):
-        configuration = Configuration(
-            base_url='https://play.orkes.io'
-        )
-        self.assertEqual(
-            configuration.host,
-            'https://play.orkes.io/api/'
-        )
-
-    def test_initialization_with_server_api_url(self):
-        configuration = Configuration(
-            server_api_url='https://play.orkes.io/api/'
-        )
-        self.assertEqual(
-            configuration.host,
-            'https://play.orkes.io/api/'
-        )
-
-    def test_initialization_with_basic_auth_server_api_url(self):
-        configuration = Configuration(
-            server_api_url="https://user:password@play.orkes.io/api/"
-        )
-        basic_auth = "user:password"
-        expected_host = f"https://{basic_auth}@play.orkes.io/api/"
-        self.assertEqual(
-            configuration.host, expected_host,
-        )
-        token = "Basic " + \
-            base64.b64encode(bytes(basic_auth, "utf-8")).decode("utf-8")
-        api_client = ApiClient(configuration)
-        self.assertEqual(
-            api_client.default_headers,
-            {"Accept-Encoding": "gzip", "authorization": token},
-        )
+
+
+def test_initialization_default():
+    configuration = Configuration()
+    assert configuration.host == 'http://localhost:8080/api/'
+
+
+def test_initialization_with_base_url():
+    configuration = Configuration(
+        base_url='https://play.orkes.io'
+    )
+    assert configuration.host == 'https://play.orkes.io/api/'
+
+
+def test_initialization_with_server_api_url():
+    configuration = Configuration(
+        server_api_url='https://play.orkes.io/api/'
+    )
+    assert configuration.host == 'https://play.orkes.io/api/'
+
+
+def test_initialization_with_basic_auth_server_api_url():
+    configuration = Configuration(
+        server_api_url="https://user:password@play.orkes.io/api/"
+    )
+    basic_auth = "user:password"
+    expected_host = f"https://{basic_auth}@play.orkes.io/api/"
+    assert configuration.host == expected_host
+    token = "Basic " + \
+        base64.b64encode(bytes(basic_auth, "utf-8")).decode("utf-8")
+    api_client = ApiClient(configuration)
+    assert api_client.default_headers == {
+        "Accept-Encoding": "gzip",
+        "authorization": token
+    }
diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index aa9a52dd..8cd144c5 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -1,34 +1,20 @@
 from conductor.client.configuration.settings.metrics_settings import MetricsSettings
-import logging
-import unittest
 
 
-class TestMetricsCollection(unittest.TestCase):
-    def setUp(self):
-        logging.disable(logging.CRITICAL)
+def test_default_initialization():
+    metrics_settings = MetricsSettings()
+    assert metrics_settings.file_name == 'metrics.log'
+    assert metrics_settings.update_interval == 0.1
 
-    def tearDown(self):
-        logging.disable(logging.NOTSET)
 
-    def test_default_initialization(self):
-        metrics_settings = MetricsSettings()
-        self.assertEqual(metrics_settings.file_name, 'metrics.log')
-        self.assertEqual(metrics_settings.update_interval, 0.1)
-
-    def test_default_initialization_with_parameters(self):
-        expected_directory = '/a/b'
-        expected_file_name = 'some_name.txt'
-        expected_update_interval = 0.5
-        metrics_settings = MetricsSettings(
-            directory=expected_directory,
-            file_name=expected_file_name,
-            update_interval=expected_update_interval,
-        )
-        self.assertEqual(
-            metrics_settings.file_name,
-            expected_file_name
-        )
-        self.assertEqual(
-            metrics_settings.update_interval,
-            expected_update_interval
-        )
+def test_default_initialization_with_parameters():
+    expected_directory = '/a/b'
+    expected_file_name = 'some_name.txt'
+    expected_update_interval = 0.5
+    metrics_settings = MetricsSettings(
+        directory=expected_directory,
+        file_name=expected_file_name,
+        update_interval=expected_update_interval,
+    )
+    assert metrics_settings.file_name == expected_file_name
+    assert metrics_settings.update_interval == expected_update_interval
diff --git a/tests/unit/test_task_handler.py b/tests/unit/test_task_handler.py
index 1e9ce2d8..9d63d4a6 100644
--- a/tests/unit/test_task_handler.py
+++ b/tests/unit/test_task_handler.py
@@ -2,42 +2,12 @@
 from conductor.client.automator.task_runner import TaskRunner
 from conductor.client.configuration.configuration import Configuration
 from tests.unit.resources.workers import ClassWorker
-from unittest.mock import Mock
-from unittest.mock import patch
-import multiprocessing
-import unittest
+import pytest
 
 
-class PickableMock(Mock):
-    def __reduce__(self):
-        return (Mock, ())
-
-
-class TestTaskHandler(unittest.TestCase):
-    def test_initialization_with_invalid_workers(self):
-        expected_exception = Exception('Invalid worker list')
-        with self.assertRaises(Exception) as context:
-            TaskHandler(
-                configuration=Configuration(),
-                workers=ClassWorker()
-            )
-            self.assertEqual(expected_exception, context.exception)
-
-    def test_start_processes(self):
-        with patch.object(TaskRunner, 'run', PickableMock(return_value=None)):
-            with _get_valid_task_handler() as task_handler:
-                task_handler.start_processes()
-                self.assertEqual(len(task_handler.task_runner_processes), 1)
-                for process in task_handler.task_runner_processes:
-                    self.assertTrue(
-                        isinstance(process, multiprocessing.Process)
-                    )
-
-
-def _get_valid_task_handler():
-    return TaskHandler(
-        configuration=Configuration(),
-        workers=[
-            ClassWorker('task')
-        ]
-    )
+def test_initialization_with_invalid_workers():
+    with pytest.raises(TypeError):
+        TaskHandler(
+            configuration=Configuration(),
+            workers=ClassWorker()
+        )
diff --git a/tests/unit/test_task_runner.py b/tests/unit/test_task_runner.py
index 2abfb327..1ee573c4 100644
--- a/tests/unit/test_task_runner.py
+++ b/tests/unit/test_task_runner.py
@@ -1,197 +1,28 @@
 from conductor.client.automator.task_runner import TaskRunner
 from conductor.client.configuration.configuration import Configuration
-from conductor.client.http.api.task_resource_api import TaskResourceApi
-from conductor.client.http.models.task import Task
-from conductor.client.http.models.task_result import TaskResult
-from conductor.client.http.models.task_result_status import TaskResultStatus
 from tests.unit.resources.workers import ClassWorker
-from tests.unit.resources.workers import FaultyExecutionWorker
-from unittest.mock import patch
-import logging
-import time
-import unittest
+import pytest
 
 
-class TestTaskRunner(unittest.TestCase):
-    TASK_ID = 'VALID_TASK_ID'
-    WORKFLOW_INSTANCE_ID = 'VALID_WORKFLOW_INSTANCE_ID'
-    UPDATE_TASK_RESPONSE = 'VALID_UPDATE_TASK_RESPONSE'
+TASK_ID = 'VALID_TASK_ID'
+WORKFLOW_INSTANCE_ID = 'VALID_WORKFLOW_INSTANCE_ID'
+UPDATE_TASK_RESPONSE = 'VALID_UPDATE_TASK_RESPONSE'
 
-    def setUp(self):
-        logging.disable(logging.CRITICAL)
 
-    def tearDown(self):
-        logging.disable(logging.NOTSET)
+def test_initialization_without_configuration():
+    worker = _get_valid_worker()
+    TaskRunner(
+        configuration=None,
+        task_definition_name=worker.get_task_definition_name(),
+        batch_size=worker.batch_size,
+        polling_interval=worker.polling_interval,
+        worker_execution_function=worker.execute,
+    )
 
-    def test_initialization_with_invalid_configuration(self):
-        expected_exception = Exception('Invalid configuration')
-        with self.assertRaises(Exception) as context:
-            TaskRunner(
-                configuration=None,
-                worker=self.__get_valid_worker()
-            )
-            self.assertEqual(expected_exception, context.exception)
 
-    def test_initialization_with_invalid_worker(self):
-        expected_exception = Exception('Invalid worker')
-        with self.assertRaises(Exception) as context:
-            TaskRunner(
-                configuration=Configuration("http://localhost:8080/api"),
-                worker=None
-            )
-            self.assertEqual(expected_exception, context.exception)
+def test_initialization_with_missing_required_arguments():
+    with pytest.raises(TypeError):
+        TaskRunner(
+            configuration=Configuration("http://localhost:8080/api")
+        )
 
-    def test_run_once(self):
-        expected_time = self.__get_valid_worker().get_polling_interval_in_seconds()
-        with patch.object(
-            TaskResourceApi,
-            'poll',
-            return_value=self.__get_valid_task()
-        ):
-            with patch.object(
-                TaskResourceApi,
-                'update_task',
-                return_value=self.UPDATE_TASK_RESPONSE
-            ):
-                task_runner = self.__get_valid_task_runner()
-                start_time = time.time()
-                task_runner.run_once()
-                finish_time = time.time()
-                spent_time = finish_time - start_time
-                self.assertGreater(spent_time, expected_time)
-
-    def test_poll_task(self):
-        expected_task = self.__get_valid_task()
-        with patch.object(
-            TaskResourceApi,
-            'poll',
-            return_value=self.__get_valid_task()
-        ):
-            task_runner = self.__get_valid_task_runner()
-            task = task_runner._TaskRunner__poll_task()
-            self.assertEqual(task, expected_task)
-
-    def test_poll_task_with_faulty_task_api(self):
-        expected_task = None
-        with patch.object(
-            TaskResourceApi,
-            'poll',
-            side_effect=Exception()
-        ):
-            task_runner = self.__get_valid_task_runner()
-            task = task_runner._TaskRunner__poll_task()
-            self.assertEqual(task, expected_task)
-
-    def test_execute_task_with_invalid_task(self):
-        task_runner = self.__get_valid_task_runner()
-        task_result = task_runner._TaskRunner__execute_task(None)
-        self.assertEqual(task_result, None)
-
-    def test_execute_task_with_faulty_execution_worker(self):
-        worker = FaultyExecutionWorker('task')
-        expected_task_result = TaskResult(
-            task_id=self.TASK_ID,
-            workflow_instance_id=self.WORKFLOW_INSTANCE_ID,
-            worker_id=worker.get_identity(),
-            status=TaskResultStatus.FAILED,
-            reason_for_incompletion='faulty execution'
-        )
-        task_runner = TaskRunner(
-            configuration=Configuration(),
-            worker=worker
-        )
-        task = self.__get_valid_task()
-        task_result = task_runner._TaskRunner__execute_task(task)
-        self.assertEqual(task_result, expected_task_result)
-
-    def test_execute_task(self):
-        expected_task_result = self.__get_valid_task_result()
-        worker = self.__get_valid_worker()
-        task_runner = TaskRunner(
-            configuration=Configuration(),
-            worker=worker
-        )
-        task = self.__get_valid_task()
-        task_result = task_runner._TaskRunner__execute_task(task)
-        self.assertEqual(task_result, expected_task_result)
-
-    def test_update_task_with_invalid_task_result(self):
-        expected_response = None
-        task_runner = self.__get_valid_task_runner()
-        response = task_runner._TaskRunner__update_task(None)
-        self.assertEqual(response, expected_response)
-
-    def test_update_task_with_faulty_task_api(self):
-        expected_response = None
-        with patch.object(
-            TaskResourceApi,
-            'update_task',
-            side_effect=Exception()
-        ):
-            task_runner = self.__get_valid_task_runner()
-            task_result = self.__get_valid_task_result()
-            response = task_runner._TaskRunner__update_task(task_result)
-            self.assertEqual(response, expected_response)
-
-    def test_update_task(self):
-        expected_response = self.UPDATE_TASK_RESPONSE
-        with patch.object(
-            TaskResourceApi,
-            'update_task',
-            return_value=self.UPDATE_TASK_RESPONSE
-        ):
-            task_runner = self.__get_valid_task_runner()
-            task_result = self.__get_valid_task_result()
-            response = task_runner._TaskRunner__update_task(task_result)
-            self.assertEqual(response, expected_response)
-
-    def test_wait_for_polling_interval_with_faulty_worker(self):
-        expected_exception = Exception(
-            "Failed to get polling interval"
-        )
-        with patch.object(
-            ClassWorker,
-            'get_polling_interval_in_seconds',
-            side_effect=expected_exception
-        ):
-            task_runner = self.__get_valid_task_runner()
-            with self.assertRaises(Exception) as context:
-                task_runner._TaskRunner__wait_for_polling_interval()
-                self.assertEqual(expected_exception, context.exception)
-
-    def test_wait_for_polling_interval(self):
-        expected_time = self.__get_valid_worker().get_polling_interval_in_seconds()
-        task_runner = self.__get_valid_task_runner()
-        start_time = time.time()
-        task_runner._TaskRunner__wait_for_polling_interval()
-        finish_time = time.time()
-        spent_time = finish_time - start_time
-        self.assertGreater(spent_time, expected_time)
-
-    def __get_valid_task_runner(self):
-        return TaskRunner(
-            configuration=Configuration(),
-            worker=self.__get_valid_worker()
-        )
-
-    def __get_valid_task(self):
-        return Task(
-            task_id=self.TASK_ID,
-            workflow_instance_id=self.WORKFLOW_INSTANCE_ID
-        )
-
-    def __get_valid_task_result(self):
-        return TaskResult(
-            task_id=self.TASK_ID,
-            workflow_instance_id=self.WORKFLOW_INSTANCE_ID,
-            worker_id=self.__get_valid_worker().get_identity(),
-            status=TaskResultStatus.COMPLETED,
-            output_data={
-                'worker_style': 'class',
-                'secret_number': 1234,
-                'is_it_true': False,
-            }
-        )
-
-    def __get_valid_worker(self):
-        return ClassWorker('task')
+
+
+def _get_valid_worker():
+    return ClassWorker('task')
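As a closing note, a sketch of what the new mutex-guarded TaskRunner properties make possible: tuning a live poller from another thread. The construction mirrors what TaskHandler.__start_worker does; the task name and worker function are placeholders.

    from conductor.client.automator.task_runner import TaskRunner
    from conductor.client.configuration.configuration import Configuration

    runner = TaskRunner(
        configuration=Configuration(),
        task_definition_name='greet',
        batch_size=1,
        polling_interval=0.1,
        worker_execution_function=lambda task: {'greeting': 'hello'},
    )
    # each property access takes the corresponding lock, so another thread
    # can retune the poller while run() is looping
    runner.batch_size = 10   # poll up to 10 tasks per cycle from now on
    runner.pause_worker()    # stop polling; in-flight workers still finish
    runner.resume_worker()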