diff --git a/src/broker/operandi_broker/__init__.py b/src/broker/operandi_broker/__init__.py index b1568ffe..051ade4d 100644 --- a/src/broker/operandi_broker/__init__.py +++ b/src/broker/operandi_broker/__init__.py @@ -1,11 +1,13 @@ __all__ = [ - "cli", - "ServiceBroker", - "JobStatusWorker", - "Worker" + "cli", + "JobWorkerDownload", + "JobWorkerStatus", + "JobWorkerSubmit", + "ServiceBroker", ] from .cli import cli from .broker import ServiceBroker -from .job_status_worker import JobStatusWorker -from .worker import Worker +from .job_worker_download import JobWorkerDownload +from .job_worker_status import JobWorkerStatus +from .job_worker_submit import JobWorkerSubmit diff --git a/src/broker/operandi_broker/broker.py b/src/broker/operandi_broker/broker.py index 9e49eae6..7c1f2825 100644 --- a/src/broker/operandi_broker/broker.py +++ b/src/broker/operandi_broker/broker.py @@ -1,16 +1,14 @@ from logging import getLogger -from os import environ, fork -import psutil -import signal +from os import environ from time import sleep from operandi_utils import ( get_log_file_path_prefix, reconfigure_all_loggers, verify_database_uri, verify_and_parse_mq_uri) from operandi_utils.constants import LOG_LEVEL_BROKER from operandi_utils.rabbitmq.constants import ( - RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS, RABBITMQ_QUEUE_JOB_STATUSES) -from .worker import Worker -from .job_status_worker import JobStatusWorker + RABBITMQ_QUEUE_HPC_DOWNLOADS, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS, RABBITMQ_QUEUE_JOB_STATUSES) + +from .broker_utils import create_child_process, kill_workers class ServiceBroker: @@ -48,14 +46,15 @@ def run_broker(self): # A list of queues for which a worker process should be created queues = [RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS] status_queue = RABBITMQ_QUEUE_JOB_STATUSES + hpc_download_queue = RABBITMQ_QUEUE_HPC_DOWNLOADS try: for queue_name in queues: self.log.info(f"Creating a worker process to consume from queue: {queue_name}") - self.create_worker_process( - queue_name=queue_name, status_checker=False, tunnel_port_executor=22, tunnel_port_transfer=22) + self.create_worker_process(queue_name, "submit_worker") self.log.info(f"Creating a status worker process to consume from queue: {status_queue}") - self.create_worker_process( - queue_name=status_queue, status_checker=True, tunnel_port_executor=22, tunnel_port_transfer=22) + self.create_worker_process(status_queue, "status_worker") + self.log.info(f"Creating a download worker process to consume from queue: {hpc_download_queue}") + self.create_worker_process(hpc_download_queue, "download_worker") except Exception as error: self.log.error(f"Error while creating worker processes: {error}") @@ -72,7 +71,7 @@ def run_broker(self): except KeyboardInterrupt: self.log.info(f"SIGINT signal received. 
Sending SIGINT to worker processes.") # Sends SIGINT to workers - self.kill_workers() + kill_workers(self.log, self.queues_and_workers) self.log.info(f"Closing gracefully in 3 seconds!") exit(0) except Exception as error: @@ -80,77 +79,19 @@ def run_broker(self): self.log.error(f"Unexpected error: {error}") # Creates a separate worker process and append its pid if successful - def create_worker_process( - self, queue_name, tunnel_port_executor: int = 22, tunnel_port_transfer: int = 22, status_checker=False - ) -> None: + def create_worker_process(self, queue_name, worker_type: str) -> None: # If the entry for queue_name does not exist, create id if queue_name not in self.queues_and_workers: self.log.info(f"Initializing workers list for queue: {queue_name}") # Initialize the worker pids list for the queue self.queues_and_workers[queue_name] = [] - child_pid = self.__create_child_process( - queue_name=queue_name, status_checker=status_checker, tunnel_port_executor=tunnel_port_executor, - tunnel_port_transfer=tunnel_port_transfer) + child_pid = create_child_process( + self.log, self.db_url, self.rabbitmq_url, queue_name, worker_type, self.test_sbatch) # If creation of the child process was successful if child_pid: self.log.info(f"Assigning a new worker process with pid: {child_pid}, to queue: {queue_name}") # append the pid to the workers list of the queue_name (self.queues_and_workers[queue_name]).append(child_pid) - # Forks a child process - def __create_child_process( - self, queue_name, tunnel_port_executor: int = 22, tunnel_port_transfer: int = 22, status_checker=False - ) -> int: - self.log.info(f"Trying to create a new worker process for queue: {queue_name}") - try: - # TODO: Try to utilize Popen() instead of fork() - created_pid = fork() - except Exception as os_error: - self.log.error(f"Failed to create a child process, reason: {os_error}") - return 0 - if created_pid != 0: - return created_pid - try: - # Clean unnecessary data - # self.queues_and_workers = None - if status_checker: - child_worker = JobStatusWorker( - db_url=self.db_url, rabbitmq_url=self.rabbitmq_url, queue_name=queue_name, - tunnel_port_executor=tunnel_port_executor, tunnel_port_transfer=tunnel_port_transfer, - test_sbatch=self.test_sbatch) - else: - child_worker = Worker( - db_url=self.db_url, rabbitmq_url=self.rabbitmq_url, queue_name=queue_name, - tunnel_port_executor=tunnel_port_executor, tunnel_port_transfer=tunnel_port_transfer, - test_sbatch=self.test_sbatch) - child_worker.run() - exit(0) - except Exception as e: - self.log.error(f"Worker process failed for queue: {queue_name}, reason: {e}") - exit(-1) - - def _send_signal_to_worker(self, worker_pid: int, signal_type: signal): - try: - process = psutil.Process(pid=worker_pid) - process.send_signal(signal_type) - except psutil.ZombieProcess as error: - self.log.info(f"Worker process has become a zombie: {worker_pid}, {error}") - except psutil.NoSuchProcess as error: - self.log.error(f"No such worker process with pid: {worker_pid}, {error}") - except psutil.AccessDenied as error: - self.log.error(f"Access denied to the worker process with pid: {worker_pid}, {error}") - def kill_workers(self): - interrupted_pids = [] - self.log.info(f"Starting to send SIGINT to all workers") - # Send SIGINT to all workers - for queue_name in self.queues_and_workers: - self.log.info(f"Sending SIGINT to workers of queue: {queue_name}") - for worker_pid in self.queues_and_workers[queue_name]: - self._send_signal_to_worker(worker_pid=worker_pid, signal_type=signal.SIGINT) - 
interrupted_pids.append(worker_pid) - sleep(3) - self.log.info(f"Sending SIGKILL (if needed) to previously interrupted workers") - # Check whether workers exited properly - for pid in interrupted_pids: - self._send_signal_to_worker(worker_pid=pid, signal_type=signal.SIGKILL) + kill_workers(self.log, self.queues_and_workers) diff --git a/src/broker/operandi_broker/broker_utils.py b/src/broker/operandi_broker/broker_utils.py new file mode 100644 index 00000000..35af3f43 --- /dev/null +++ b/src/broker/operandi_broker/broker_utils.py @@ -0,0 +1,67 @@ +from logging import Logger +from os import fork +import psutil +import signal +from time import sleep +from typing import Dict + +from .job_worker_download import JobWorkerDownload +from .job_worker_status import JobWorkerStatus +from .job_worker_submit import JobWorkerSubmit + + +# Forks a child process +def create_child_process( + logger: Logger, db_url: str, rabbitmq_url: str, queue_name: str, worker_type: str, test_batch: bool +) -> int: + logger.info(f"Trying to create a new worker process for queue: {queue_name}") + try: + created_pid = fork() + except Exception as os_error: + logger.error(f"Failed to create a child process, reason: {os_error}") + return 0 + + if created_pid != 0: + return created_pid + try: + if worker_type == "status_worker": + child_worker = JobWorkerStatus(db_url, rabbitmq_url, queue_name) + child_worker.run(hpc_executor=True, hpc_io_transfer=True, publisher=True) + elif worker_type == "download_worker": + child_worker = JobWorkerDownload(db_url, rabbitmq_url, queue_name) + child_worker.run(hpc_executor=True, hpc_io_transfer=True, publisher=False) + else: # worker_type == "submit_worker" + child_worker = JobWorkerSubmit(db_url, rabbitmq_url, queue_name, test_batch) + child_worker.run(hpc_executor=True, hpc_io_transfer=True, publisher=False) + exit(0) + except Exception as e: + logger.error(f"Worker process failed for queue: {queue_name}, reason: {e}") + exit(-1) + + +def send_signal_to_worker(logger: Logger, worker_pid: int, signal_type: signal): + try: + process = psutil.Process(pid=worker_pid) + process.send_signal(signal_type) + except psutil.ZombieProcess as error: + logger.info(f"Worker process has become a zombie: {worker_pid}, {error}") + except psutil.NoSuchProcess as error: + logger.error(f"No such worker process with pid: {worker_pid}, {error}") + except psutil.AccessDenied as error: + logger.error(f"Access denied to the worker process with pid: {worker_pid}, {error}") + + +def kill_workers(logger: Logger, queues_and_workers: Dict): + interrupted_pids = [] + logger.info(f"Starting to send SIGINT to all workers") + # Send SIGINT to all workers + for queue_name in queues_and_workers: + logger.info(f"Sending SIGINT to workers of queue: {queue_name}") + for worker_pid in queues_and_workers[queue_name]: + send_signal_to_worker(logger, worker_pid=worker_pid, signal_type=signal.SIGINT) + interrupted_pids.append(worker_pid) + sleep(3) + logger.info(f"Sending SIGKILL (if needed) to previously interrupted workers") + # Check whether workers exited properly + for pid in interrupted_pids: + send_signal_to_worker(logger, worker_pid=pid, signal_type=signal.SIGKILL) diff --git a/src/broker/operandi_broker/job_status_worker.py b/src/broker/operandi_broker/job_status_worker.py deleted file mode 100644 index d64081b6..00000000 --- a/src/broker/operandi_broker/job_status_worker.py +++ /dev/null @@ -1,224 +0,0 @@ -from json import loads -from logging import getLogger -import signal -from os import getpid, getppid, setsid -from 
pathlib import Path -from sys import exit - -from ocrd import Resolver -from operandi_utils import reconfigure_all_loggers, get_log_file_path_prefix -from operandi_utils.constants import LOG_LEVEL_WORKER, StateJob, StateWorkspace -from operandi_utils.database import ( - DBHPCSlurmJob, DBWorkflowJob, DBWorkspace, - sync_db_increase_processing_stats, sync_db_initiate_database, sync_db_get_hpc_slurm_job, sync_db_get_workflow_job, - sync_db_get_workspace, sync_db_update_hpc_slurm_job, sync_db_update_workflow_job, sync_db_update_workspace) -from operandi_utils.hpc import NHRExecutor, NHRTransfer -from operandi_utils.rabbitmq import get_connection_consumer - - -class JobStatusWorker: - def __init__(self, db_url, rabbitmq_url, queue_name, tunnel_port_executor, tunnel_port_transfer, test_sbatch=False): - self.log = getLogger(f"operandi_broker.job_status_worker[{getpid()}].{queue_name}") - self.queue_name = queue_name - self.log_file_path = f"{get_log_file_path_prefix(module_type='worker')}_{queue_name}.log" - self.test_sbatch = test_sbatch - - self.db_url = db_url - self.rmq_url = rabbitmq_url - self.rmq_consumer = None - self.hpc_executor = None - self.hpc_io_transfer = None - - # Currently consumed message related parameters - self.current_message_delivery_tag = None - self.current_message_job_id = None - self.has_consumed_message = False - - self.tunnel_port_executor = tunnel_port_executor - self.tunnel_port_transfer = tunnel_port_transfer - - def __del__(self): - if self.rmq_consumer: - self.rmq_consumer.disconnect() - - def run(self): - try: - # Source: https://unix.stackexchange.com/questions/18166/what-are-session-leaders-in-ps - # Make the current process session leader - setsid() - # Reconfigure all loggers to the same format - reconfigure_all_loggers(log_level=LOG_LEVEL_WORKER, log_file_path=self.log_file_path) - self.log.info(f"Activating signal handler for SIGINT, SIGTERM") - signal.signal(signal.SIGINT, self.signal_handler) - signal.signal(signal.SIGTERM, self.signal_handler) - - sync_db_initiate_database(self.db_url) - self.hpc_executor = NHRExecutor() - self.log.info("HPC executor connection successful.") - self.hpc_io_transfer = NHRTransfer() - self.log.info("HPC transfer connection successful.") - - self.rmq_consumer = get_connection_consumer(rabbitmq_url=self.rmq_url) - self.log.info(f"RMQConsumer connected") - self.rmq_consumer.configure_consuming(queue_name=self.queue_name, callback_method=self.__callback) - self.log.info(f"Configured consuming from queue: {self.queue_name}") - self.log.info(f"Starting consuming from queue: {self.queue_name}") - self.rmq_consumer.start_consuming() - except Exception as e: - self.log.error(f"The worker failed, reason: {e}") - raise Exception(f"The worker failed, reason: {e}") - - def __download_results_from_hpc(self, job_dir: str, workspace_dir: str, slurm_job_id: str) -> None: - self.hpc_io_transfer.get_and_unpack_slurm_workspace( - ocrd_workspace_dir=Path(workspace_dir), workflow_job_dir=Path(job_dir), slurm_job_id=slurm_job_id) - self.log.info(f"Transferred slurm workspace from hpc path") - # Delete the result dir from the HPC home folder - # self.hpc_executor.execute_blocking(f"bash -lc 'rm -rf {hpc_slurm_workspace_path}/{workflow_job_id}'") - - def __handle_hpc_and_workflow_states( - self, hpc_slurm_job_db: DBHPCSlurmJob, workflow_job_db: DBWorkflowJob, workspace_db: DBWorkspace - ): - old_slurm_job_state = hpc_slurm_job_db.hpc_slurm_job_state - new_slurm_job_state = 
self.hpc_executor.check_slurm_job_state(slurm_job_id=hpc_slurm_job_db.hpc_slurm_job_id) - # TODO: Reconsider this - # if not new_slurm_job_state: - # return - - user_id = workspace_db.user_id - job_id = workflow_job_db.job_id - job_dir = workflow_job_db.job_dir - old_job_state = workflow_job_db.job_state - - workspace_id = workspace_db.workspace_id - workspace_dir = workspace_db.workspace_dir - - # If there has been a change of slurm job state, update it - if old_slurm_job_state != new_slurm_job_state: - self.log.info( - f"Slurm job: {hpc_slurm_job_db.hpc_slurm_job_id}, " - f"old state: {old_slurm_job_state}, " - f"new state: {new_slurm_job_state}") - sync_db_update_hpc_slurm_job(find_workflow_job_id=job_id, hpc_slurm_job_state=new_slurm_job_state) - - # Convert the slurm job state to operandi workflow job state - new_job_state = StateJob.convert_from_slurm_job(slurm_job_state=new_slurm_job_state) - - # TODO: Refactor this block of code since nothing is downloaded from the HPC when job fails. - # If there has been a change of operandi workflow state, update it - if old_job_state != new_job_state: - self.log.info(f"Workflow job id: {job_id}, old state: {old_job_state}, new state: {new_job_state}") - sync_db_update_workflow_job(find_job_id=job_id, job_state=new_job_state) - # TODO: Simplify SUCCESS and FAILED duplications - if new_job_state == StateJob.SUCCESS: - sync_db_update_workspace(find_workspace_id=workspace_id, state=StateWorkspace.TRANSFERRING_FROM_HPC) - sync_db_update_workflow_job(find_job_id=job_id, job_state=StateJob.TRANSFERRING_FROM_HPC) - self.__download_results_from_hpc( - job_dir=job_dir, workspace_dir=workspace_dir, slurm_job_id=hpc_slurm_job_db.hpc_slurm_job_id) - - # TODO: Find a better way to do the update - consider callbacks to Operandi Server - try: - workspace = Resolver().workspace_from_url( - mets_url=workspace_db.workspace_mets_path, clobber_mets=False, - mets_basename=workspace_db.mets_basename, download=False) - updated_file_groups = workspace.mets.file_groups - except Exception as error: - self.log.error(f"Failed to extract the processed file groups: {error}") - updated_file_groups = ["CORRUPTED FILE GROUPS"] - self.log.info(f"Setting new workspace state `{StateWorkspace.READY}` of workspace_id: {workspace_id}") - - db_workspace = sync_db_update_workspace( - find_workspace_id=workspace_id, state=StateWorkspace.READY, file_groups=updated_file_groups) - sync_db_update_workflow_job(find_job_id=self.current_message_job_id, job_state=StateJob.SUCCESS) - db_stats = sync_db_increase_processing_stats( - find_user_id=user_id, pages_succeed=db_workspace.pages_amount) - self.hpc_io_transfer.download_slurm_job_log_file(hpc_slurm_job_db.hpc_slurm_job_id, job_dir) - self.log.info(f"Increasing `pages_succeed` stat by {db_workspace.pages_amount}") - self.log.info(f"Total amount of `pages_succeed` stat: {db_stats.pages_succeed}") - if new_job_state == StateJob.FAILED: - self.log.info(f"Setting new workspace state `{StateWorkspace.READY}` of workspace_id: {workspace_id}") - db_workspace = sync_db_update_workspace(find_workspace_id=workspace_id, state=StateWorkspace.READY) - sync_db_update_workflow_job(find_job_id=self.current_message_job_id, job_state=StateJob.FAILED) - db_stats = sync_db_increase_processing_stats( - find_user_id=user_id, pages_failed=db_workspace.pages_amount) - self.hpc_io_transfer.download_slurm_job_log_file(hpc_slurm_job_db.hpc_slurm_job_id, job_dir) - self.log.error(f"Increasing `pages_failed` stat by {db_workspace.pages_amount}") - self.log.error(f"Total 
amount of `pages_failed` stat: {db_stats.pages_failed}") - - self.log.info(f"Latest slurm job state: {new_slurm_job_state}") - self.log.info(f"Latest workflow job state: {new_job_state}") - - def __callback(self, ch, method, properties, body): - self.log.debug(f"ch: {ch}, method: {method}, properties: {properties}, body: {body}") - self.log.debug(f"Consumed message: {body}") - - self.current_message_delivery_tag = method.delivery_tag - self.has_consumed_message = True - - # Since the workflow_message is constructed by the Operandi Server, - # it should not fail here when parsing under normal circumstances. - try: - consumed_message = loads(body) - self.log.info(f"Consumed message: {consumed_message}") - self.current_message_job_id = consumed_message["job_id"] - except Exception as error: - self.log.warning(f"Parsing the consumed message has failed: {error}") - self.__handle_message_failure(interruption=False) - return - - # Handle database related reads and set the workflow job status to RUNNING - try: - db_workflow_job = sync_db_get_workflow_job(self.current_message_job_id) - db_workspace = sync_db_get_workspace(db_workflow_job.workspace_id) - db_hpc_slurm_job = sync_db_get_hpc_slurm_job(self.current_message_job_id) - except RuntimeError as error: - self.log.warning(f"Database run-time error has occurred: {error}") - self.__handle_message_failure(interruption=False) - return - except Exception as error: - self.log.warning(f"Database related error has occurred: {error}") - self.__handle_message_failure(interruption=False) - return - - try: - self.__handle_hpc_and_workflow_states( - hpc_slurm_job_db=db_hpc_slurm_job, workflow_job_db=db_workflow_job, workspace_db=db_workspace) - except ValueError as error: - self.log.warning(f"{error}") - self.__handle_message_failure(interruption=False) - return - - self.has_consumed_message = False - self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") - ch.basic_ack(delivery_tag=method.delivery_tag) - - def __handle_message_failure(self, interruption: bool = False): - self.has_consumed_message = False - - if interruption: - # self.log.info(f"Nacking delivery tag: {self.current_message_delivery_tag}") - # self.rmq_consumer._channel.basic_nack(delivery_tag=self.current_message_delivery_tag) - # TODO: Sending ACK for now because it is hard to clean up without a mets workspace backup mechanism - self.log.info(f"Interruption ack delivery tag: {self.current_message_delivery_tag}") - self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) - return - - self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") - self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) - - # Reset the current message related parameters - self.current_message_delivery_tag = None - self.current_message_job_id = None - - # TODO: Ideally this method should be wrapped to be able - # to pass internal data from the Worker class required for the cleaning - # The arguments to this method are passed by the caller from the OS - def signal_handler(self, sig, frame): - signal_name = signal.Signals(sig).name - self.log.info(f"{signal_name} received from parent process `{getppid()}`.") - if self.has_consumed_message: - self.log.info(f"Handling the message failure due to interruption: {signal_name}") - self.__handle_message_failure(interruption=True) - self.rmq_consumer.disconnect() - self.rmq_consumer = None - self.log.info("Exiting gracefully.") - exit(0) diff --git a/src/broker/operandi_broker/job_worker_base.py 
b/src/broker/operandi_broker/job_worker_base.py new file mode 100644 index 00000000..7e4a7235 --- /dev/null +++ b/src/broker/operandi_broker/job_worker_base.py @@ -0,0 +1,93 @@ +from logging import getLogger +import signal +from os import getpid, getppid, setsid +from sys import exit + +from operandi_utils import reconfigure_all_loggers, get_log_file_path_prefix +from operandi_utils.constants import LOG_LEVEL_WORKER +from operandi_utils.database import sync_db_initiate_database +from operandi_utils.hpc import NHRExecutor, NHRTransfer +from operandi_utils.rabbitmq import get_connection_consumer, get_connection_publisher + +NOT_IMPLEMENTED_ERROR: str = "The method was not implemented in the extending class" + +# Each worker class listens to a specific queue, consumes messages, and processes messages. +class JobWorkerBase: + def __init__(self, db_url, rabbitmq_url, queue_name): + self.db_url = db_url + self.rmq_url = rabbitmq_url + self.queue_name = queue_name + + self.log = getLogger(f"operandi_broker.worker.{self.queue_name}_{getpid()}") + self.log_file_path = f"{get_log_file_path_prefix(module_type='worker')}_{self.queue_name}_{getpid()}.log" + + self.rmq_consumer = None + self.rmq_publisher = None + self.hpc_executor = None + self.hpc_io_transfer = None + + self.has_consumed_message = False + self.current_message_delivery_tag = None + + def disconnect_rmq_connections(self): + self.log.info("Disconnecting existing RabbitMQ connections.") + if self.rmq_consumer: + self.log.info("Disconnecting the RMQ consumer") + self.rmq_consumer.disconnect() + if self.rmq_publisher: + self.log.info("Disconnecting the RMQ publisher") + self.rmq_publisher.disconnect() + + def __del__(self): + self.disconnect_rmq_connections() + + def _consumed_msg_callback(self, ch, method, properties, body): + raise NotImplementedError(NOT_IMPLEMENTED_ERROR) + + def _handle_msg_failure(self, interruption: bool): + raise NotImplementedError(NOT_IMPLEMENTED_ERROR) + + def run(self, hpc_executor: bool, hpc_io_transfer: bool, publisher: bool): + try: + # Source: https://unix.stackexchange.com/questions/18166/what-are-session-leaders-in-ps + # Make the current process session leader + setsid() + # Reconfigure all loggers to the same format + reconfigure_all_loggers(log_level=LOG_LEVEL_WORKER, log_file_path=self.log_file_path) + self.log.info(f"Activating signal handler for SIGINT, SIGTERM") + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + + sync_db_initiate_database(self.db_url) + self.log.info("MongoDB connection successful.") + if hpc_executor: + self.hpc_executor = NHRExecutor() + self.log.info("HPC executor connection successful.") + if hpc_io_transfer: + self.hpc_io_transfer = NHRTransfer() + self.log.info("HPC transfer connection successful.") + if publisher: + self.rmq_publisher = get_connection_publisher(rabbitmq_url=self.rmq_url, enable_acks=True) + self.log.info(f"RMQPublisher connected") + + self.rmq_consumer = get_connection_consumer(rabbitmq_url=self.rmq_url) + self.log.info(f"RMQConsumer connected") + self.rmq_consumer.configure_consuming( + queue_name=self.queue_name, callback_method=self._consumed_msg_callback) + self.log.info(f"Starting consuming from queue: {self.queue_name}") + self.rmq_consumer.start_consuming() + except Exception as e: + self.log.error(f"The worker failed, reason: {e}") + raise Exception(f"The worker failed, reason: {e}") + + # The arguments to this method are passed by the caller from the OS + def signal_handler(self, sig, frame): + 
signal_name = signal.Signals(sig).name + self.log.info(f"{signal_name} received from parent process `{getppid()}`.") + if self.has_consumed_message: + self.log.info(f"Handling the consumed message failure due to an interruption: {signal_name}") + self._handle_msg_failure(interruption=True) + # TODO: Verify if this call here is necessary + self.disconnect_rmq_connections() + self.log.info("Exiting gracefully.") + exit(0) diff --git a/src/broker/operandi_broker/job_worker_download.py b/src/broker/operandi_broker/job_worker_download.py new file mode 100644 index 00000000..f7878c81 --- /dev/null +++ b/src/broker/operandi_broker/job_worker_download.py @@ -0,0 +1,127 @@ +from json import loads +from pathlib import Path +from typing import List +from typing_extensions import override + +from ocrd import Resolver +from operandi_broker.job_worker_base import JobWorkerBase +from operandi_utils.constants import StateJob, StateWorkspace +from operandi_utils.database import ( + DBWorkspace, sync_db_increase_processing_stats, + sync_db_get_hpc_slurm_job, sync_db_get_workflow_job, sync_db_get_workspace, + sync_db_update_workflow_job, sync_db_update_workspace) + + +class JobWorkerDownload(JobWorkerBase): + def __init__(self, db_url, rabbitmq_url, queue_name): + super().__init__(db_url, rabbitmq_url, queue_name) + self.current_message_job_id = None + + @override + def _consumed_msg_callback(self, ch, method, properties, body): + self.log.debug(f"ch: {ch}, method: {method}, properties: {properties}, body: {body}") + self.log.debug(f"Consumed message: {body}") + self.current_message_delivery_tag = method.delivery_tag + self.has_consumed_message = True + + # Since the workflow_message is constructed by the Operandi Server, + # it should not fail here when parsing under normal circumstances. + try: + consumed_message = loads(body) + self.log.info(f"Consumed message: {consumed_message}") + self.current_message_job_id = consumed_message["job_id"] + previous_job_state = consumed_message["previous_job_state"] + except Exception as error: + self.log.warning(f"Parsing the consumed message has failed: {error}") + self._handle_msg_failure(interruption=False) + return + + try: + db_hpc_slurm_job = sync_db_get_hpc_slurm_job(self.current_message_job_id) + slurm_job_id = db_hpc_slurm_job.hpc_slurm_job_id + + db_workflow_job = sync_db_get_workflow_job(self.current_message_job_id) + workspace_id = db_workflow_job.workspace_id + job_dir = db_workflow_job.job_dir + + db_workspace = sync_db_get_workspace(workspace_id) + ws_dir = db_workspace.workspace_dir + user_id = db_workspace.user_id + except RuntimeError as error: + self.log.warning(f"Database run-time error has occurred: {error}") + self._handle_msg_failure(interruption=False) + return + except Exception as error: + self.log.warning(f"Database related error has occurred: {error}") + self._handle_msg_failure(interruption=False) + return + + try: + # TODO: Refactor this block of code since nothing is downloaded from the HPC when job fails. 
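+ # The slurm job log file is fetched for both successful and failed jobs; the packed workspace results are transferred back from the HPC and unpacked only when the previous job state is SUCCESS.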
+ self.hpc_io_transfer.download_slurm_job_log_file(slurm_job_id, job_dir) + if previous_job_state == StateJob.SUCCESS: + self.__download_results_from_hpc(job_dir=job_dir, workspace_dir=ws_dir) + self.log.info(f"Setting new workspace state `{StateWorkspace.READY}` of workspace_id: {workspace_id}") + updated_file_groups = self.__extract_updated_file_groups(db_workspace=db_workspace) + db_workspace = sync_db_update_workspace( + find_workspace_id=workspace_id, state=StateWorkspace.READY, file_groups=updated_file_groups) + pages_amount = db_workspace.pages_amount + self.log.info(f"Increasing `pages_succeed` stat by {pages_amount}") + db_stats = sync_db_increase_processing_stats(find_user_id=user_id, pages_succeed=pages_amount) + self.log.info(f"Total amount of `pages_succeed` stat: {db_stats.pages_succeed}") + if previous_job_state == StateJob.FAILED: + self.log.info(f"Setting new workspace state `{StateWorkspace.READY}` of workspace_id: {workspace_id}") + db_workspace = sync_db_update_workspace(find_workspace_id=workspace_id, state=StateWorkspace.READY) + pages_amount = db_workspace.pages_amount + self.log.error(f"Increasing `pages_failed` stat by {pages_amount}") + db_stats = sync_db_increase_processing_stats(find_user_id=user_id, pages_failed=pages_amount) + self.log.error(f"Total amount of `pages_failed` stat: {db_stats.pages_failed}") + self.log.info(f"Setting new workflow job state `{previous_job_state}`" + f" of job_id: {self.current_message_job_id}") + sync_db_update_workflow_job(find_job_id=self.current_message_job_id, job_state=previous_job_state) + except ValueError as error: + self.log.warning(f"{error}") + self._handle_msg_failure(interruption=False) + return + + self.has_consumed_message = False + self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") + ch.basic_ack(delivery_tag=method.delivery_tag) + + @override + def _handle_msg_failure(self, interruption: bool): + self.has_consumed_message = False + + if interruption: + # self.log.info(f"Nacking delivery tag: {self.current_message_delivery_tag}") + # self.rmq_consumer._channel.basic_nack(delivery_tag=self.current_message_delivery_tag) + # TODO: Sending ACK for now because it is hard to clean up without a mets workspace backup mechanism + self.log.info(f"Interruption ack delivery tag: {self.current_message_delivery_tag}") + self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) + return + + self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") + self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) + + # Reset the current message related parameters + self.current_message_delivery_tag = None + self.current_message_job_id = None + + def __download_results_from_hpc(self, job_dir: str, workspace_dir: str) -> None: + self.hpc_io_transfer.get_and_unpack_slurm_workspace( + ocrd_workspace_dir=Path(workspace_dir), workflow_job_dir=Path(job_dir)) + self.log.info(f"Transferred slurm workspace from hpc path") + # Delete the result dir from the HPC home folder + job_id = Path(job_dir).name + self.hpc_executor.remove_workflow_job_dir(workflow_job_id=job_id) + self.log.info(f"Removed slurm workspace from HPC for job: {job_id}") + + def __extract_updated_file_groups(self, db_workspace: DBWorkspace) -> List[str]: + try: + workspace = Resolver().workspace_from_url( + mets_url=db_workspace.workspace_mets_path, clobber_mets=False, + mets_basename=db_workspace.mets_basename, download=False) + return workspace.mets.file_groups + except Exception as error: + 
self.log.error(f"Failed to extract the processed file groups: {error}") + return ["CORRUPTED FILE GROUPS"] diff --git a/src/broker/operandi_broker/job_worker_status.py b/src/broker/operandi_broker/job_worker_status.py new file mode 100644 index 00000000..1756ce40 --- /dev/null +++ b/src/broker/operandi_broker/job_worker_status.py @@ -0,0 +1,121 @@ +from json import dumps, loads +from typing_extensions import override + +from operandi_broker.job_worker_base import JobWorkerBase +from operandi_utils.constants import StateJob, StateWorkspace +from operandi_utils.database import ( + DBHPCSlurmJob, DBWorkflowJob, + sync_db_get_hpc_slurm_job, sync_db_get_workflow_job, + sync_db_update_hpc_slurm_job, sync_db_update_workflow_job, sync_db_update_workspace) +from operandi_utils.rabbitmq import RABBITMQ_QUEUE_HPC_DOWNLOADS + + +class JobWorkerStatus(JobWorkerBase): + def __init__(self, db_url, rabbitmq_url, queue_name): + super().__init__(db_url, rabbitmq_url, queue_name) + self.current_message_job_id = None + + @override + def _consumed_msg_callback(self, ch, method, properties, body): + self.log.debug(f"ch: {ch}, method: {method}, properties: {properties}, body: {body}") + self.log.debug(f"Consumed message: {body}") + self.current_message_delivery_tag = method.delivery_tag + self.has_consumed_message = True + + # Since the workflow_message is constructed by the Operandi Server, + # it should not fail here when parsing under normal circumstances. + try: + consumed_message = loads(body) + self.log.info(f"Consumed message: {consumed_message}") + self.current_message_job_id = consumed_message["job_id"] + except Exception as error: + self.log.warning(f"Parsing the consumed message has failed: {error}") + self._handle_msg_failure(interruption=False) + return + + try: + db_hpc_slurm_job: DBHPCSlurmJob = sync_db_get_hpc_slurm_job(self.current_message_job_id) + + db_workflow_job: DBWorkflowJob = sync_db_get_workflow_job(self.current_message_job_id) + workspace_id = db_workflow_job.workspace_id + except RuntimeError as error: + self.log.warning(f"Database run-time error has occurred: {error}") + self._handle_msg_failure(interruption=False) + return + except Exception as error: + self.log.warning(f"Database related error has occurred: {error}") + self._handle_msg_failure(interruption=False) + return + + try: + self.__handle_hpc_and_workflow_states( + db_hpc_slurm_job=db_hpc_slurm_job, db_workflow_job=db_workflow_job, workspace_id=workspace_id) + except ValueError as error: + self.log.warning(f"{error}") + self._handle_msg_failure(interruption=False) + return + + self.has_consumed_message = False + self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") + ch.basic_ack(delivery_tag=method.delivery_tag) + + @override + def _handle_msg_failure(self, interruption: bool): + self.has_consumed_message = False + + if interruption: + # self.log.info(f"Nacking delivery tag: {self.current_message_delivery_tag}") + # self.rmq_consumer._channel.basic_nack(delivery_tag=self.current_message_delivery_tag) + # TODO: Sending ACK for now because it is hard to clean up without a mets workspace backup mechanism + self.log.info(f"Interruption ack delivery tag: {self.current_message_delivery_tag}") + self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) + return + + self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") + self.rmq_consumer.ack_message(delivery_tag=self.current_message_delivery_tag) + + # Reset the current message related parameters + 
self.current_message_delivery_tag = None + self.current_message_job_id = None + + def __handle_hpc_and_workflow_states( + self, db_hpc_slurm_job: DBHPCSlurmJob, db_workflow_job: DBWorkflowJob, workspace_id: str + ): + old_slurm_job_state = db_hpc_slurm_job.hpc_slurm_job_state + new_slurm_job_state = self.hpc_executor.check_slurm_job_state(slurm_job_id=db_hpc_slurm_job.hpc_slurm_job_id) + # TODO: Reconsider this + # if not new_slurm_job_state: + # return + + job_id = db_workflow_job.job_id + old_job_state = db_workflow_job.job_state + + # If there has been a change of slurm job state, update it + if old_slurm_job_state != new_slurm_job_state: + self.log.info( + f"Slurm job: {db_hpc_slurm_job.hpc_slurm_job_id}, " + f"old state: {old_slurm_job_state}, " + f"new state: {new_slurm_job_state}") + sync_db_update_hpc_slurm_job(find_workflow_job_id=job_id, hpc_slurm_job_state=new_slurm_job_state) + + # Convert the slurm job state to operandi workflow job state + new_job_state = StateJob.convert_from_slurm_job(slurm_job_state=new_slurm_job_state) + + # If there has been a change of operandi workflow state, update it + if old_job_state != new_job_state: + self.log.info(f"Workflow job id: {job_id}, old state: {old_job_state}, new state: {new_job_state}") + if new_job_state == StateJob.SUCCESS or new_job_state == StateJob.FAILED: + sync_db_update_workspace(find_workspace_id=workspace_id, state=StateWorkspace.TRANSFERRING_FROM_HPC) + sync_db_update_workflow_job(find_job_id=job_id, job_state=StateJob.TRANSFERRING_FROM_HPC) + + result_download_message = { + "job_id": f"{job_id}", + "previous_job_state": f"{new_job_state}" + } + self.log.info(f"Encoding the result download RabbitMQ message: {result_download_message}") + encoded_result_download_message = dumps(result_download_message).encode(encoding="utf-8") + self.rmq_publisher.publish_to_queue( + queue_name=RABBITMQ_QUEUE_HPC_DOWNLOADS, message=encoded_result_download_message) + + self.log.info(f"Latest slurm job state: {new_slurm_job_state}") + self.log.info(f"Latest workflow job state: {new_job_state}") diff --git a/src/broker/operandi_broker/worker.py b/src/broker/operandi_broker/job_worker_submit.py similarity index 62% rename from src/broker/operandi_broker/worker.py rename to src/broker/operandi_broker/job_worker_submit.py index 5c709628..c2b42dc2 100644 --- a/src/broker/operandi_broker/worker.py +++ b/src/broker/operandi_broker/job_worker_submit.py @@ -1,85 +1,34 @@ from json import loads -from logging import getLogger -import signal -from os import getpid, getppid, setsid from os.path import join from pathlib import Path -from sys import exit from typing import List +from typing_extensions import override -from operandi_utils import reconfigure_all_loggers, get_log_file_path_prefix -from operandi_utils.constants import LOG_LEVEL_WORKER, StateJob, StateWorkspace +from operandi_broker.job_worker_base import JobWorkerBase + +from operandi_utils.constants import StateJob, StateWorkspace from operandi_utils.database import ( - sync_db_increase_processing_stats, sync_db_initiate_database, sync_db_get_workflow, sync_db_get_workspace, + DBWorkflow, DBWorkflowJob, DBWorkspace, + sync_db_increase_processing_stats, sync_db_get_workflow, sync_db_get_workflow_job, sync_db_get_workspace, sync_db_create_hpc_slurm_job, sync_db_update_workflow_job, sync_db_update_workspace) -from operandi_utils.hpc import NHRExecutor, NHRTransfer from operandi_utils.hpc.constants import ( HPC_BATCH_SUBMIT_WORKFLOW_JOB, HPC_JOB_DEADLINE_TIME_REGULAR, HPC_JOB_DEADLINE_TIME_TEST, 
HPC_JOB_QOS_SHORT, HPC_JOB_QOS_DEFAULT) -from operandi_utils.rabbitmq import get_connection_consumer -# Each worker class listens to a specific queue, -# consume messages, and process messages. -class Worker: - def __init__(self, db_url, rabbitmq_url, queue_name, tunnel_port_executor, tunnel_port_transfer, test_sbatch=False): - self.log = getLogger(f"operandi_broker.worker[{getpid()}].{queue_name}") - self.queue_name = queue_name - self.log_file_path = f"{get_log_file_path_prefix(module_type='worker')}_{queue_name}.log" +class JobWorkerSubmit(JobWorkerBase): + def __init__(self, db_url, rabbitmq_url, queue_name, test_sbatch=False): + super().__init__(db_url, rabbitmq_url, queue_name) self.test_sbatch = test_sbatch - - self.db_url = db_url - self.rmq_url = rabbitmq_url - self.rmq_consumer = None - self.hpc_executor = None - self.hpc_io_transfer = None - - # Currently consumed message related parameters - self.current_message_delivery_tag = None + self.current_message_job_id = None self.current_message_user_id = None - self.current_message_ws_id = None self.current_message_wf_id = None - self.current_message_job_id = None - self.has_consumed_message = False - - self.tunnel_port_executor = tunnel_port_executor - self.tunnel_port_transfer = tunnel_port_transfer - - def __del__(self): - if self.rmq_consumer: - self.rmq_consumer.disconnect() - - def run(self): - try: - # Source: https://unix.stackexchange.com/questions/18166/what-are-session-leaders-in-ps - # Make the current process session leader - setsid() - # Reconfigure all loggers to the same format - reconfigure_all_loggers(log_level=LOG_LEVEL_WORKER, log_file_path=self.log_file_path) - self.log.info(f"Activating signal handler for SIGINT, SIGTERM") - signal.signal(signal.SIGINT, self.signal_handler) - signal.signal(signal.SIGTERM, self.signal_handler) - - sync_db_initiate_database(self.db_url) - self.hpc_executor = NHRExecutor() - self.log.info("HPC executor connection successful.") - self.hpc_io_transfer = NHRTransfer() - self.log.info("HPC transfer connection successful.") - - self.rmq_consumer = get_connection_consumer(rabbitmq_url=self.rmq_url) - self.log.info(f"RMQConsumer connected") - self.rmq_consumer.configure_consuming(queue_name=self.queue_name, callback_method=self.__callback) - self.log.info(f"Configured consuming from queue: {self.queue_name}") - self.log.info(f"Starting consuming from queue: {self.queue_name}") - self.rmq_consumer.start_consuming() - except Exception as e: - self.log.error(f"The worker failed, reason: {e}") - raise Exception(f"The worker failed, reason: {e}") + self.current_message_ws_id = None - def __callback(self, ch, method, properties, body): + @override + def _consumed_msg_callback(self, ch, method, properties, body): self.log.debug(f"ch: {ch}, method: {method}, properties: {properties}, body: {body}") self.log.debug(f"Consumed message: {body}") - self.current_message_delivery_tag = method.delivery_tag self.has_consumed_message = True @@ -88,9 +37,6 @@ def __callback(self, ch, method, properties, body): try: consumed_message = loads(body) self.log.info(f"Consumed message: {consumed_message}") - self.current_message_user_id = consumed_message["user_id"] - self.current_message_ws_id = consumed_message["workspace_id"] - self.current_message_wf_id = consumed_message["workflow_id"] self.current_message_job_id = consumed_message["job_id"] input_file_grp = consumed_message["input_file_grp"] remove_file_grps = consumed_message["remove_file_grps"] @@ -102,29 +48,33 @@ def __callback(self, ch, method, 
properties, body): nf_process_forks = slurm_job_cpus except Exception as error: self.log.error(f"Parsing the consumed message has failed: {error}") - self.__handle_message_failure(interruption=False) + self._handle_msg_failure(interruption=False) return - # Handle database related reads and set the workflow job status to RUNNING try: - workflow_db = sync_db_get_workflow(self.current_message_wf_id) - workspace_db = sync_db_get_workspace(self.current_message_ws_id) - - workflow_script_path = Path(workflow_db.workflow_script_path) - nf_uses_mets_server = workflow_db.uses_mets_server - nf_executable_steps = workflow_db.executable_steps - workspace_dir = Path(workspace_db.workspace_dir) - mets_basename = workspace_db.mets_basename - ws_pages_amount = workspace_db.pages_amount + db_workflow_job: DBWorkflowJob = sync_db_get_workflow_job(self.current_message_job_id) + self.current_message_user_id = db_workflow_job.user_id + self.current_message_wf_id = db_workflow_job.workflow_id + self.current_message_ws_id = db_workflow_job.workspace_id + + db_workflow: DBWorkflow = sync_db_get_workflow(self.current_message_wf_id) + workflow_script_path = Path(db_workflow.workflow_script_path) + nf_uses_mets_server = db_workflow.uses_mets_server + nf_executable_steps = db_workflow.executable_steps + + db_workspace: DBWorkspace = sync_db_get_workspace(self.current_message_ws_id) + workspace_dir = Path(db_workspace.workspace_dir) + mets_basename = db_workspace.mets_basename + ws_pages_amount = db_workspace.pages_amount if not mets_basename: mets_basename = "mets.xml" except RuntimeError as error: self.log.error(f"Database run-time error has occurred: {error}") - self.__handle_message_failure(interruption=False, set_ws_ready=True) + self._handle_msg_failure(interruption=False) return except Exception as error: self.log.error(f"Database related error has occurred: {error}") - self.__handle_message_failure(interruption=False, set_ws_ready=True) + self._handle_msg_failure(interruption=False) return # Trigger a slurm job in the HPC @@ -140,7 +90,7 @@ def __callback(self, ch, method, properties, body): self.log.info(f"The HPC slurm job was successfully submitted") except Exception as error: self.log.error(f"Triggering a slurm job in the HPC has failed: {error}") - self.__handle_message_failure(interruption=False, set_ws_ready=True) + self._handle_msg_failure(interruption=False) return job_state = StateJob.PENDING @@ -155,16 +105,16 @@ def __callback(self, ch, method, properties, body): self.log.debug(f"Ack delivery tag: {self.current_message_delivery_tag}") ch.basic_ack(delivery_tag=method.delivery_tag) - def __handle_message_failure(self, interruption: bool = False, set_ws_ready: bool = False): + @override + def _handle_msg_failure(self, interruption: bool): job_state = StateJob.FAILED self.log.info(f"Setting new state `{job_state}` of job_id: {self.current_message_job_id}") sync_db_update_workflow_job(find_job_id=self.current_message_job_id, job_state=job_state) self.has_consumed_message = False - if set_ws_ready: - ws_state = StateWorkspace.READY - self.log.info(f"Setting new workspace state `{ws_state}` of workspace_id: {self.current_message_ws_id}") - sync_db_update_workspace(find_workspace_id=self.current_message_ws_id, state=ws_state) + ws_state = StateWorkspace.READY + self.log.info(f"Setting new workspace state `{ws_state}` of workspace_id: {self.current_message_ws_id}") + sync_db_update_workspace(find_workspace_id=self.current_message_ws_id, state=ws_state) if interruption: # self.log.info(f"Nacking delivery 
tag: {self.current_message_delivery_tag}") @@ -179,24 +129,10 @@ def __handle_message_failure(self, interruption: bool = False, set_ws_ready: boo # Reset the current message related parameters self.current_message_delivery_tag = None - self.current_message_ws_id = None - self.current_message_wf_id = None self.current_message_job_id = None - - # TODO: Ideally this method should be wrapped to be able - # to pass internal data from the Worker class required for the cleaning - # The arguments to this method are passed by the caller from the OS - def signal_handler(self, sig, frame): - signal_name = signal.Signals(sig).name - self.log.info(f"{signal_name} received from parent process `{getppid()}`.") - if self.has_consumed_message: - self.log.info(f"Handling the message failure due to interruption: {signal_name}") - self.__handle_message_failure(interruption=True) - - self.rmq_consumer.disconnect() - self.rmq_consumer = None - self.log.info("Exiting gracefully.") - exit(0) + self.current_message_user_id = None + self.current_message_wf_id = None + self.current_message_ws_id = None # TODO: This should be further refined, currently it's just everything in one place def prepare_and_trigger_slurm_job( diff --git a/src/rabbitmq_definitions.json b/src/rabbitmq_definitions.json index 6a89b89d..823d2c6d 100755 --- a/src/rabbitmq_definitions.json +++ b/src/rabbitmq_definitions.json @@ -20,7 +20,9 @@ {"name": "operandi_queue_harvester", "vhost": "/", "durable": false, "auto_delete": false}, {"name": "operandi_queue_harvester", "vhost": "test", "durable": false, "auto_delete": false}, {"name": "operandi_queue_job_statuses", "vhost": "/", "durable": false, "auto_delete": true}, - {"name": "operandi_queue_job_statuses", "vhost": "test", "durable": false, "auto_delete": true} + {"name": "operandi_queue_job_statuses", "vhost": "test", "durable": false, "auto_delete": true}, + {"name": "operandi_queue_hpc_downloads", "vhost": "/", "durable": false, "auto_delete": false}, + {"name": "operandi_queue_hpc_downloads", "vhost": "test", "durable": false, "auto_delete": false} ], "exchanges": [], "bindings": [] diff --git a/src/server/operandi_server/models/base.py b/src/server/operandi_server/models/base.py index 19a37769..b26abd75 100644 --- a/src/server/operandi_server/models/base.py +++ b/src/server/operandi_server/models/base.py @@ -10,6 +10,7 @@ class Resource(BaseModel): resource_url: str = Field(..., description="The unique URL of the resource") description: str = Field(..., description="The description of the resource") datetime: datetime + deleted: bool class Config: allow_population_by_field_name = True diff --git a/src/server/operandi_server/models/workflow.py b/src/server/operandi_server/models/workflow.py index 6c319ecf..ff0c1f81 100644 --- a/src/server/operandi_server/models/workflow.py +++ b/src/server/operandi_server/models/workflow.py @@ -13,6 +13,7 @@ class WorkflowRsrc(Resource): # description: (str) - inherited from Resource # created_by_user: (str) - inherited from Resource # datetime: (datetime) - inherited from Resource + # deleted: bool - inherited from Resource uses_mets_server: bool executable_steps: List[str] producible_file_groups: List[str] @@ -31,6 +32,7 @@ def from_db_workflow(db_workflow: DBWorkflow): executable_steps=db_workflow.executable_steps, producible_file_groups=db_workflow.producible_file_groups, datetime=db_workflow.datetime, + deleted=db_workflow.deleted ) class WorkflowJobRsrc(Resource): @@ -41,6 +43,7 @@ class WorkflowJobRsrc(Resource): # description: (str) - inherited 
from Resource # created_by_user: (str) - inherited from Resource # datetime: (datetime) - inherited from Resource + # deleted: bool - inherited from Resource job_state: Optional[StateJob] = StateJob.UNSET workflow_rsrc: Optional[WorkflowRsrc] workspace_rsrc: Optional[WorkspaceRsrc] @@ -58,5 +61,6 @@ def from_db_workflow_job(db_workflow_job: DBWorkflowJob, db_workflow: DBWorkflow job_state=db_workflow_job.job_state, workflow_rsrc=WorkflowRsrc.from_db_workflow(db_workflow), workspace_rsrc=WorkspaceRsrc.from_db_workspace(db_workspace), - datetime=db_workflow_job.datetime + datetime=db_workflow_job.datetime, + deleted=db_workflow.deleted ) diff --git a/src/server/operandi_server/models/workspace.py b/src/server/operandi_server/models/workspace.py index 1bff0a6b..f2f07d32 100644 --- a/src/server/operandi_server/models/workspace.py +++ b/src/server/operandi_server/models/workspace.py @@ -12,6 +12,7 @@ class WorkspaceRsrc(Resource): # description: (str) - inherited from Resource # created_by_user: (str) - inherited from Resource # datetime: (datetime) - inherited from Resource + # deleted: bool - inherited from Resource pages_amount: int file_groups: List[str] state: StateWorkspace = StateWorkspace.UNSET @@ -39,5 +40,6 @@ def from_db_workspace(db_workspace: DBWorkspace): bagit_profile_identifier=db_workspace.bagit_profile_identifier, ocrd_base_version_checksum=db_workspace.ocrd_base_version_checksum, mets_basename=db_workspace.mets_basename, - bag_info_add=db_workspace.bag_info_adds + bag_info_add=db_workspace.bag_info_adds, + deleted=db_workspace.deleted ) diff --git a/src/server/operandi_server/routers/admin_panel.py b/src/server/operandi_server/routers/admin_panel.py index ae07de72..16541652 100644 --- a/src/server/operandi_server/routers/admin_panel.py +++ b/src/server/operandi_server/routers/admin_panel.py @@ -99,31 +99,31 @@ async def user_processing_stats(self, user_id: str, auth: HTTPBasicCredentials = async def user_workflow_jobs( self, user_id: str, auth: HTTPBasicCredentials = Depends(HTTPBasic()), - start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[WorkflowJobRsrc]: """ The expected datetime format: YYYY-MM-DDTHH:MM:SS, for example, 2024-12-01T18:17:15 """ await self.auth_admin_with_handling(auth) return await get_user_workflow_jobs( - self.logger, self.rmq_publisher, user_id, start_date, end_date) + self.logger, self.rmq_publisher, user_id, start_date, end_date, hide_deleted) async def user_workspaces( self, user_id: str, auth: HTTPBasicCredentials = Depends(HTTPBasic()), - start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[WorkspaceRsrc]: """ The expected datetime format: YYYY-MM-DDTHH:MM:SS, for example, 2024-12-01T18:17:15 """ await self.auth_admin_with_handling(auth) - return await get_user_workspaces(user_id=user_id, start_date=start_date, end_date=end_date) + return await get_user_workspaces(user_id, start_date, end_date, hide_deleted) async def user_workflows( self, user_id: str, auth: HTTPBasicCredentials = Depends(HTTPBasic()), - start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[WorkflowRsrc]: """ The expected datetime format: YYYY-MM-DDTHH:MM:SS, for example, 
2024-12-01T18:17:15 """ await self.auth_admin_with_handling(auth) - return await get_user_workflows(user_id=user_id, start_date=start_date, end_date=end_date) + return await get_user_workflows(user_id, start_date, end_date, hide_deleted) diff --git a/src/server/operandi_server/routers/user.py b/src/server/operandi_server/routers/user.py index 12a518ee..6e56eb5f 100644 --- a/src/server/operandi_server/routers/user.py +++ b/src/server/operandi_server/routers/user.py @@ -108,7 +108,7 @@ async def user_workflow_jobs( """ py_user_action = await user_auth_with_handling(self.logger, auth) return await get_user_workflow_jobs( - self.logger, self.rmq_publisher, py_user_action.user_id, start_date, end_date) + self.logger, self.rmq_publisher, py_user_action.user_id, start_date, end_date, True) async def user_workspaces( self, auth: HTTPBasicCredentials = Depends(HTTPBasic()), @@ -118,7 +118,7 @@ async def user_workspaces( The expected datetime format: YYYY-MM-DDTHH:MM:SS, for example, 2024-12-01T18:17:15 """ py_user_action = await user_auth_with_handling(self.logger, auth) - return await get_user_workspaces(user_id=py_user_action.user_id, start_date=start_date, end_date=end_date) + return await get_user_workspaces(py_user_action.user_id, start_date, end_date, True) async def user_workflows( self, auth: HTTPBasicCredentials = Depends(HTTPBasic()), @@ -128,4 +128,4 @@ async def user_workflows( The expected datetime format: YYYY-MM-DDTHH:MM:SS, for example, 2024-12-01T18:17:15 """ py_user_action = await user_auth_with_handling(self.logger, auth) - return await get_user_workflows(user_id=py_user_action.user_id, start_date=start_date, end_date=end_date) + return await get_user_workflows(py_user_action.user_id, start_date, end_date, True) diff --git a/src/server/operandi_server/routers/workflow.py b/src/server/operandi_server/routers/workflow.py index 19129fc9..4df10b86 100644 --- a/src/server/operandi_server/routers/workflow.py +++ b/src/server/operandi_server/routers/workflow.py @@ -402,9 +402,8 @@ async def submit_to_rabbitmq_queue( workspace_id=workspace_id, workflow_id=workflow_id, details=details) self._push_job_to_rabbitmq( - user_id=py_user_action.user_id, user_type=user_account_type, workflow_id=workflow_id, - workspace_id=workspace_id, job_id=job_id, input_file_grp=input_file_grp, remove_file_grps=remove_file_grps, - partition=partition, cpus=cpus, ram=ram + user_type=user_account_type, job_id=job_id, input_file_grp=input_file_grp, + remove_file_grps=remove_file_grps, partition=partition, cpus=cpus, ram=ram ) await db_increase_processing_stats_with_handling( self.logger, find_user_id=py_user_action.user_id, pages_submitted=db_workspace.pages_amount) @@ -412,15 +411,12 @@ async def submit_to_rabbitmq_queue( db_workflow_job=db_wf_job, db_workflow=db_workflow, db_workspace=db_workspace) def _push_job_to_rabbitmq( - self, user_id: str, user_type: AccountType, workflow_id: str, workspace_id: str, job_id: str, - input_file_grp: str, remove_file_grps: str, partition: str, cpus: int, ram: int + self, user_type: AccountType, job_id: str, input_file_grp: str, remove_file_grps: str, partition: str, + cpus: int, ram: int ): # Create the message to be sent to the RabbitMQ queue self.logger.info("Creating a workflow job RabbitMQ message") workflow_processing_message = { - "user_id": f"{user_id}", - "workflow_id": f"{workflow_id}", - "workspace_id": f"{workspace_id}", "job_id": f"{job_id}", "input_file_grp": f"{input_file_grp}", "remove_file_grps": f"{remove_file_grps}", diff --git 
a/src/server/operandi_server/routers/workflow_utils.py b/src/server/operandi_server/routers/workflow_utils.py index 6306fd05..e207e068 100644 --- a/src/server/operandi_server/routers/workflow_utils.py +++ b/src/server/operandi_server/routers/workflow_utils.py @@ -106,9 +106,10 @@ async def convert_oton_with_handling( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) async def get_user_workflows( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[WorkflowRsrc]: - db_workflows = await db_get_all_workflows_by_user(user_id=user_id, start_date=start_date, end_date=end_date) + db_workflows = await db_get_all_workflows_by_user( + user_id=user_id, start_date=start_date, end_date=end_date, hide_deleted=hide_deleted) return [WorkflowRsrc.from_db_workflow(db_workflow) for db_workflow in db_workflows] async def push_status_request_to_rabbitmq(logger, rmq_publisher, job_id: str): @@ -125,9 +126,11 @@ async def push_status_request_to_rabbitmq(logger, rmq_publisher, job_id: str): raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=message) async def get_user_workflow_jobs( - logger, rmq_publisher, user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + logger, rmq_publisher, user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, + hide_deleted: bool = True ) -> List[WorkflowJobRsrc]: - db_workflow_jobs = await db_get_all_workflow_jobs_by_user(user_id=user_id, start_date=start_date, end_date=end_date) + db_workflow_jobs = await db_get_all_workflow_jobs_by_user( + user_id=user_id, start_date=start_date, end_date=end_date, hide_deleted=hide_deleted) response = [] for db_workflow_job in db_workflow_jobs: job_state = db_workflow_job.job_state diff --git a/src/server/operandi_server/routers/workspace_utils.py b/src/server/operandi_server/routers/workspace_utils.py index c068970d..f547de57 100644 --- a/src/server/operandi_server/routers/workspace_utils.py +++ b/src/server/operandi_server/routers/workspace_utils.py @@ -230,8 +230,9 @@ def find_file_groups_to_remove_with_handling(logger, db_workspace, preserve_file return remove_groups async def get_user_workspaces( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[WorkspaceRsrc]: - db_workspaces = await db_get_all_workspaces_by_user(user_id=user_id, start_date=start_date, end_date=end_date) + db_workspaces = await db_get_all_workspaces_by_user( + user_id=user_id, start_date=start_date, end_date=end_date, hide_deleted=hide_deleted) return [WorkspaceRsrc.from_db_workspace(db_workspace) for db_workspace in db_workspaces] diff --git a/src/utils/operandi_utils/constants.py b/src/utils/operandi_utils/constants.py index c38f33cb..f5144ed6 100644 --- a/src/utils/operandi_utils/constants.py +++ b/src/utils/operandi_utils/constants.py @@ -169,6 +169,7 @@ class StateWorkspace(str, Enum): # TODO: Find a more optimal way of achieving this dynamically OCRD_PROCESSOR_EXECUTABLE_TO_IMAGE = { + "ocrd_all": "ocrd_all_maximum_image.sif", "ocrd": "ocrd_core.sif", "ocrd-tesserocr-crop": "ocrd_tesserocr.sif", "ocrd-tesserocr-deskew": "ocrd_tesserocr.sif", diff --git a/src/utils/operandi_utils/database/db_workflow.py 
b/src/utils/operandi_utils/database/db_workflow.py index c2707424..5f7f2ff7 100644 --- a/src/utils/operandi_utils/database/db_workflow.py +++ b/src/utils/operandi_utils/database/db_workflow.py @@ -55,7 +55,7 @@ async def db_get_workflow(workflow_id: str) -> DBWorkflow: return db_workflow async def db_get_all_workflows_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[DBWorkflow]: query = {"user_id": user_id} if start_date or end_date: @@ -64,6 +64,8 @@ async def db_get_all_workflows_by_user( query["datetime"]["$gte"] = start_date if end_date: query["datetime"]["$lte"] = end_date + if hide_deleted: + query["deleted"] = False db_workflows = await DBWorkflow.find_many(query).to_list() return db_workflows @@ -108,5 +110,6 @@ async def sync_db_update_workflow(find_workflow_id: str, **kwargs) -> DBWorkflow @call_sync async def sync_db_get_all_workflows_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None) -> List[DBWorkflow]: - return await db_get_all_workflows_by_user(user_id, start_date, end_date) + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True +) -> List[DBWorkflow]: + return await db_get_all_workflows_by_user(user_id, start_date, end_date, hide_deleted) diff --git a/src/utils/operandi_utils/database/db_workflow_job.py b/src/utils/operandi_utils/database/db_workflow_job.py index cce3fd28..2074651b 100644 --- a/src/utils/operandi_utils/database/db_workflow_job.py +++ b/src/utils/operandi_utils/database/db_workflow_job.py @@ -39,7 +39,7 @@ async def db_get_workflow_job(job_id: str) -> DBWorkflowJob: async def db_get_all_workflow_jobs_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[DBWorkflowJob]: query = {"user_id": user_id} if start_date or end_date: @@ -48,6 +48,8 @@ async def db_get_all_workflow_jobs_by_user( query["datetime"]["$gte"] = start_date if end_date: query["datetime"]["$lte"] = end_date + if hide_deleted: + query["deleted"] = False db_workflow_jobs = await DBWorkflowJob.find_many(query).to_list() return db_workflow_jobs @@ -95,5 +97,6 @@ async def sync_db_update_workflow_job(find_job_id: str, **kwargs) -> DBWorkflowJ @call_sync async def sync_db_get_all_workflow_jobs_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None) -> List[DBWorkflowJob]: - return await db_get_all_workflow_jobs_by_user(user_id, start_date, end_date) + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True +) -> List[DBWorkflowJob]: + return await db_get_all_workflow_jobs_by_user(user_id, start_date, end_date, hide_deleted) diff --git a/src/utils/operandi_utils/database/db_workspace.py b/src/utils/operandi_utils/database/db_workspace.py index c038b50b..72324552 100644 --- a/src/utils/operandi_utils/database/db_workspace.py +++ b/src/utils/operandi_utils/database/db_workspace.py @@ -76,7 +76,7 @@ async def db_get_workspace(workspace_id: str) -> DBWorkspace: async def db_get_all_workspaces_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + user_id: str, start_date: Optional[datetime] = None, 
end_date: Optional[datetime] = None, hide_deleted: bool = True ) -> List[DBWorkspace]: query = {"user_id": user_id} if start_date or end_date: @@ -85,6 +85,8 @@ async def db_get_all_workspaces_by_user( query["datetime"]["$gte"] = start_date if end_date: query["datetime"]["$lte"] = end_date + if hide_deleted: + query["deleted"] = False db_workspaces = await DBWorkspace.find_many(query).to_list() return db_workspaces @@ -137,5 +139,6 @@ async def sync_db_update_workspace(find_workspace_id: str, **kwargs) -> DBWorksp @call_sync async def sync_db_get_all_workspaces_by_user( - user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None) -> List[DBWorkspace]: - return await db_get_all_workspaces_by_user(user_id, start_date, end_date) + user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, hide_deleted: bool = True +) -> List[DBWorkspace]: + return await db_get_all_workspaces_by_user(user_id, start_date, end_date, hide_deleted) diff --git a/src/utils/operandi_utils/database/models.py b/src/utils/operandi_utils/database/models.py index fedf47d0..d5584bc4 100644 --- a/src/utils/operandi_utils/database/models.py +++ b/src/utils/operandi_utils/database/models.py @@ -31,7 +31,7 @@ class DBUserAccount(Document): account_type: AccountType = AccountType.UNSET approved_user: bool = False deleted: bool = False - datetime = datetime.now() + datetime: Optional[datetime] details: Optional[str] class Settings: @@ -85,7 +85,7 @@ class DBHPCSlurmJob(Document): hpc_batch_script_path: Optional[str] hpc_slurm_workspace_path: Optional[str] deleted: bool = False - datetime = datetime.now() + datetime: Optional[datetime] details: Optional[str] class Settings: @@ -117,7 +117,7 @@ class DBWorkflow(Document): executable_steps: List[str] producible_file_groups: List[str] deleted: bool = False - datetime = datetime.now() + datetime: Optional[datetime] details: Optional[str] class Settings: @@ -152,7 +152,7 @@ class DBWorkflowJob(Document): workspace_dir: Optional[str] hpc_slurm_job_id: Optional[str] deleted: bool = False - datetime = datetime.now() + datetime: Optional[datetime] details: Optional[str] class Settings: @@ -195,7 +195,7 @@ class DBWorkspace(Document): mets_basename: Optional[str] bag_info_adds: Optional[dict] deleted: bool = False - datetime = datetime.now() + datetime: Optional[datetime] details: Optional[str] class Settings: diff --git a/src/utils/operandi_utils/hpc/nhr_connector.py b/src/utils/operandi_utils/hpc/nhr_connector.py index 6ea693bf..a6965032 100644 --- a/src/utils/operandi_utils/hpc/nhr_connector.py +++ b/src/utils/operandi_utils/hpc/nhr_connector.py @@ -8,6 +8,8 @@ from .constants import HPC_NHR_CLUSTERS +SSH_RECONNECT_TRIES = 5 + class NHRConnector: def __init__( self, @@ -30,7 +32,7 @@ def __init__( self.check_keyfile_existence(key_path=self.key_path) self.logger.debug(f"Retrieving hpc frontend server private key file from path: {self.key_path}") self._ssh_client = None - self._ssh_reconnect_tries = 5 + self._ssh_reconnect_tries = SSH_RECONNECT_TRIES self._ssh_reconnect_tries_remaining = self._ssh_reconnect_tries # TODO: Make the sub cluster options selectable self.project_root_dir: str = HPC_NHR_CLUSTERS["EmmyPhase2"]["scratch-emmy-hdd"] @@ -44,27 +46,6 @@ def ssh_client(self): self._ssh_client.close() self._ssh_client = None self._ssh_client = self.connect_to_hpc_nhr_frontend_server(host=HPC_NHR_CLUSTERS["EmmyPhase2"]["host"]) - # self._ssh_client.get_transport().set_keepalive(30) - - """ - try: - # Note: This extra check is 
required against aggressive - # Firewalls that ignore the keepalive option! - self._ssh_client.get_transport().send_ignore() - self._ssh_reconnect_tries_remaining = self._ssh_reconnect_tries - except Exception as error: - self.logger.warning(f"SSH client failed to send ignore, connection is broken: {error}") - if self._ssh_client: - self._ssh_client.close() - self._ssh_client = None - if self._ssh_reconnect_tries_remaining < 0: - raise Exception(f"Failed to reconnect {self._ssh_reconnect_tries} times: {error}") - self.logger.info(f"Reconnecting the SSH client, try times: {self._ssh_reconnect_tries_remaining}") - self._ssh_reconnect_tries_remaining -= 1 - return self.ssh_client # recursive call to itself to try again - return self._ssh_client - """ - return self._ssh_client @staticmethod diff --git a/src/utils/operandi_utils/hpc/nhr_executor.py b/src/utils/operandi_utils/hpc/nhr_executor.py index c1bd4aa0..31854d4e 100644 --- a/src/utils/operandi_utils/hpc/nhr_executor.py +++ b/src/utils/operandi_utils/hpc/nhr_executor.py @@ -18,6 +18,11 @@ PH_NODE_DIR_PROCESSOR_SIFS = "PH_NODE_DIR_PROCESSOR_SIFS" PH_CMD_WRAPPER = "PH_CMD_WRAPPER" +CHECK_SLURM_JOB_TRY_TIMES = 10 +CHECK_SLURM_JOB_WAIT_TIME = 3 +POLL_SLURM_JOB_TIMEOUT = 300 +POLL_SLURM_JOB_CHECK_INTERVAL = 5 + class NHRExecutor(NHRConnector): def __init__(self) -> None: logger = getLogger(name=self.__class__.__name__) @@ -40,6 +45,15 @@ def execute_blocking(self, command, timeout=None, environment=None): return_code = stdout.channel.recv_exit_status() return output, err, return_code + def remove_workflow_job_dir(self, workflow_job_id: str): + hpc_slurm_job_dir = f"{self.slurm_workspaces_dir}/{workflow_job_id}" + command = f"bash -lc 'rm -rf {hpc_slurm_job_dir}'" + self.logger.info(f"About to execute a force command: {command}") + output, err, return_code = self.execute_blocking(command) + self.logger.info(f"Command output: {output}") + self.logger.info(f"Command err: {err}") + self.logger.info(f"Command return code: {return_code}") + def trigger_slurm_job( self, workflow_job_id: str, nextflow_script_path: Path, input_file_grp: str, workspace_id: str, mets_basename: str, nf_process_forks: int, ws_pages_amount: int, use_mets_server: bool, @@ -72,7 +86,7 @@ def trigger_slurm_job( hpc_nf_script_path = join(self.slurm_workspaces_dir, workflow_job_id, nextflow_script_id) hpc_workspace_dir = join(self.slurm_workspaces_dir, workflow_job_id, workspace_id) - sif_ocrd_all = "ocrd_all_maximum_image.sif" + sif_ocrd_all = OCRD_PROCESSOR_EXECUTABLE_TO_IMAGE["ocrd_all"] sif_ocrd_core = OCRD_PROCESSOR_EXECUTABLE_TO_IMAGE["ocrd"] if HPC_USE_SLIM_IMAGES: @@ -122,7 +136,9 @@ def trigger_slurm_job( assert int(slurm_job_id) return slurm_job_id - def check_slurm_job_state(self, slurm_job_id: str, tries: int = 10, wait_time: int = 2) -> str: + def check_slurm_job_state( + self, slurm_job_id: str, tries: int = CHECK_SLURM_JOB_TRY_TIMES, wait_time: int = CHECK_SLURM_JOB_WAIT_TIME + ) -> str: command = f"{HPC_WRAPPER_CHECK_WORKFLOW_JOB_STATUS} {slurm_job_id}" slurm_job_state = None @@ -154,12 +170,14 @@ def check_slurm_job_state(self, slurm_job_id: str, tries: int = 10, wait_time: i self.logger.info(f"Slurm job state of {slurm_job_id}: {slurm_job_state}") return slurm_job_state - def poll_till_end_slurm_job_state(self, slurm_job_id: str, interval: int = 5, timeout: int = 300) -> bool: + def poll_till_end_slurm_job_state( + self, slurm_job_id: str, interval: int = POLL_SLURM_JOB_CHECK_INTERVAL, timeout: int = POLL_SLURM_JOB_TIMEOUT + ) -> bool: 
self.logger.info(f"Polling slurm job status till end") tries_left = timeout / interval self.logger.info(f"Tries to be performed: {tries_left}") while tries_left: - self.logger.info(f"Sleeping for {interval} secs") + self.logger.info(f"Sleeping for {interval} seconds, before trying again") sleep(interval) tries_left -= 1 self.logger.info(f"Tries left: {tries_left}") @@ -185,7 +203,7 @@ def poll_till_end_slurm_job_state(self, slurm_job_id: str, interval: int = 5, ti self.logger.warning(f"Invalid SLURM job state: {slurm_job_state}") # Timeout reached - self.logger.info("Polling slurm job status timeout reached") + self.logger.warning("Polling slurm job status timeout reached") return False @staticmethod diff --git a/src/utils/operandi_utils/hpc/nhr_transfer.py b/src/utils/operandi_utils/hpc/nhr_transfer.py index 4e49979f..4a0d5b44 100644 --- a/src/utils/operandi_utils/hpc/nhr_transfer.py +++ b/src/utils/operandi_utils/hpc/nhr_transfer.py @@ -11,13 +11,17 @@ from operandi_utils import make_zip_archive, unpack_zip_archive from .nhr_connector import NHRConnector +SFTP_RECONNECT_TRIES = 5 +DOWNLOAD_FILE_TRY_TIMES = 100 +DOWNLOAD_FILE_SLEEP_TIME = 3 + class NHRTransfer(NHRConnector): def __init__(self) -> None: logger = getLogger(name=self.__class__.__name__) super().__init__(logger) self._operandi_data_root = "" self._sftp_client = None - self._sftp_reconnect_tries = 5 + self._sftp_reconnect_tries = SFTP_RECONNECT_TRIES self._sftp_reconnect_tries_remaining = self._sftp_reconnect_tries _ = self.sftp_client # forces a connection @@ -27,26 +31,6 @@ def sftp_client(self): self._sftp_client.close() self._ssh_client = None self._sftp_client = self.ssh_client.open_sftp() - # self._sftp_client.get_channel().get_transport().set_keepalive(30) - - """ - try: - # Note: This extra check is required against aggressive - # Firewalls that ignore the keepalive option! 
- self._sftp_client.get_channel().get_transport().send_ignore() - self._sftp_reconnect_tries_remaining = self._sftp_reconnect_tries - except Exception as error: - self.logger.warning(f"SFTP client failed to send ignore, connection is broken: {error}") - if self._sftp_client: - self._sftp_client.close() - self._sftp_client = None - if self._sftp_reconnect_tries_remaining < 0: - raise Exception(f"Failed to reconnect {self._sftp_reconnect_tries} times: {error}") - self.logger.info(f"Reconnecting the SFTP client, try times: {self._sftp_reconnect_tries_remaining}") - self._sftp_reconnect_tries_remaining -= 1 - return self.sftp_client # recursive call to itself to try again - return self._sftp_client - """ return self._sftp_client def create_slurm_workspace_zip( @@ -115,9 +99,14 @@ def pack_and_put_slurm_workspace( Path(local_src_slurm_zip).unlink(missing_ok=True) return local_src_slurm_zip, hpc_dst - def _download_file_with_retries(self, remote_src, local_dst, try_times: int = 100, sleep_time: int = 3): + def _download_file_with_retries( + self, remote_src, local_dst, try_times: int = DOWNLOAD_FILE_TRY_TIMES, + sleep_time: int = DOWNLOAD_FILE_SLEEP_TIME + ): if try_times < 0 or sleep_time < 0: - raise ValueError("Negative value passed as a parameter for time") + self.logger.warning("Negative value passed as a parameter to any of the time options, using defaults.") + try_times = DOWNLOAD_FILE_TRY_TIMES + sleep_time = DOWNLOAD_FILE_SLEEP_TIME tries = try_times while tries > 0: try: @@ -158,6 +147,9 @@ def _unzip_workflow_job_dir(self, local_wf_job_zip: Path, local_wf_job_dir: Path try: unpack_zip_archive(source=unpack_src, destination=unpack_dst) except Exception as error: + if remove_zip: + Path(unpack_src).unlink(missing_ok=True) + self.logger.info(f"Removed the temp workflow job zip: {unpack_src}") raise Exception( f"Error when unpacking workflow job zip: {error}, unpack_src: {unpack_src}, unpack_dst: {unpack_dst}") self.logger.info(f"Unpacked workflow job zip from src: {unpack_src}, to dst: {unpack_dst}") @@ -183,6 +175,9 @@ def _unzip_workspace_dir(self, local_ws_dir_zip: Path, local_ocrd_ws_dir: Path, try: unpack_zip_archive(source=unpack_src, destination=unpack_dst) except Exception as error: + if remove_zip: + Path(unpack_src).unlink(missing_ok=True) + self.logger.info(f"Removed the temp workspace zip: {unpack_src}") raise Exception( f"Error when unpacking workspace zip: {error}, unpack_src: {unpack_src}, unpack_dst: {unpack_dst}") self.logger.info(f"Unpacked workspace zip from src: {unpack_src}, to dst: {unpack_dst}") @@ -191,7 +186,7 @@ def _unzip_workspace_dir(self, local_ws_dir_zip: Path, local_ocrd_ws_dir: Path, Path(unpack_src).unlink(missing_ok=True) self.logger.info(f"Removed the temp workspace zip: {unpack_src}") - def get_and_unpack_slurm_workspace(self, ocrd_workspace_dir: Path, workflow_job_dir: Path, slurm_job_id: str): + def get_and_unpack_slurm_workspace(self, ocrd_workspace_dir: Path, workflow_job_dir: Path): _ = self.sftp_client # Force reconnect of the SFTP Client wf_job_zip_path = self._download_workflow_job_zip(local_wf_job_dir=workflow_job_dir) self._unzip_workflow_job_dir(wf_job_zip_path, workflow_job_dir, True) diff --git a/src/utils/operandi_utils/logging.py b/src/utils/operandi_utils/logging.py index 3e406b30..885a4fd8 100644 --- a/src/utils/operandi_utils/logging.py +++ b/src/utils/operandi_utils/logging.py @@ -43,8 +43,12 @@ def reconfigure_all_loggers(log_level: str, log_file_path: str): # Remove other loggers' handlers and propagate logs to root logger for 
name in logging.root.manager.loggerDict.keys(): print(f"Resetting handlers, propagation True, reconfiguring the logger: {name}") - logging.getLogger(name).handlers = [] - logging.getLogger(name).propagate = True + current_logger = logging.getLogger(name) + if "pika" in current_logger.name or "paramiko" in current_logger.name or "ocrd" in current_logger.name: + print(f"Setting log level to WARNING of: {name}") + current_logger.setLevel(level="WARNING") + current_logger.handlers = [] + current_logger.propagate = True handlers = [ {"sink": sys.stdout}, {"sink": log_file_path, "serialize": False} diff --git a/src/utils/operandi_utils/oton/ocrd_parser.py b/src/utils/operandi_utils/oton/ocrd_parser.py index cc950bd8..971860b8 100644 --- a/src/utils/operandi_utils/oton/ocrd_parser.py +++ b/src/utils/operandi_utils/oton/ocrd_parser.py @@ -52,7 +52,7 @@ def parse_arguments(self, processor_arguments) -> ProcessorCallArguments: self.logger.error(message) raise ValueError(message) processor_call_arguments.self_validate() - self.logger.info(f"Successfully validated parameters of processor: {processor_call_arguments.executable}") + self.logger.debug(f"Successfully validated parameters of processor: {processor_call_arguments.executable}") return processor_call_arguments def purify_line(self, line: str) -> str: @@ -82,10 +82,10 @@ def read_from_file(self, input_file: str) -> Tuple[str, List[str]]: for line in ocrd_file: purified_line = self.purify_line(line) if len(purified_line) > 0: - self.logger.info(f"Appending purified line {line_counter}: {purified_line}") + self.logger.debug(f"Appending purified line {line_counter}: {purified_line}") file_lines.append(purified_line) else: - self.logger.info(f"0 sized line {line_counter} spotted, skipping") + self.logger.debug(f"0 sized line {line_counter} spotted, skipping") line_counter += 1 ocrd_process_command = file_lines[0] processor_tasks = file_lines[1:] diff --git a/src/utils/operandi_utils/oton/ocrd_validator.py b/src/utils/operandi_utils/oton/ocrd_validator.py index f50153e9..d66a1f09 100644 --- a/src/utils/operandi_utils/oton/ocrd_validator.py +++ b/src/utils/operandi_utils/oton/ocrd_validator.py @@ -38,12 +38,12 @@ def validate_file_path(self, filepath: str): def validate_all_processors(self, processors: List[ProcessorCallArguments]): prev_output_file_grps = [] first_processor = processors[0] - self.logger.info(f"Validating parameters against json schema of processor: {first_processor.executable}") + self.logger.debug(f"Validating parameters against json schema of processor: {first_processor.executable}") self.validate_processor_params(first_processor, overwrite_with_defaults=False) prev_output_file_grps += first_processor.output_file_grps.split(',') for processor in processors[1:]: - self.logger.info(f"Validating parameters against json schema of processor: {first_processor.executable}") + self.logger.debug(f"Validating parameters against json schema of processor: {processor.executable}") self.validate_processor_params(processor, overwrite_with_defaults=False) for input_file_grp in processor.input_file_grps.split(','): if input_file_grp not in prev_output_file_grps: diff --git a/src/utils/operandi_utils/rabbitmq/__init__.py b/src/utils/operandi_utils/rabbitmq/__init__.py index 7948b22c..8255870e 100644 --- a/src/utils/operandi_utils/rabbitmq/__init__.py +++ b/src/utils/operandi_utils/rabbitmq/__init__.py @@ -6,6 +6,7 @@ "RABBITMQ_QUEUE_DEFAULT", "RABBITMQ_QUEUE_JOB_STATUSES", "RABBITMQ_QUEUE_HARVESTER", + "RABBITMQ_QUEUE_HPC_DOWNLOADS", 
"RABBITMQ_QUEUE_USERS", "RMQConnector" ] @@ -17,6 +18,7 @@ RABBITMQ_QUEUE_DEFAULT, RABBITMQ_QUEUE_JOB_STATUSES, RABBITMQ_QUEUE_HARVESTER, + RABBITMQ_QUEUE_HPC_DOWNLOADS, RABBITMQ_QUEUE_USERS ) from .wrappers import get_connection_consumer, get_connection_publisher diff --git a/src/utils/operandi_utils/rabbitmq/constants.py b/src/utils/operandi_utils/rabbitmq/constants.py index 7261fcab..37084d1d 100644 --- a/src/utils/operandi_utils/rabbitmq/constants.py +++ b/src/utils/operandi_utils/rabbitmq/constants.py @@ -1,6 +1,7 @@ DEFAULT_EXCHANGER_NAME: str = "operandi_default_exchange" DEFAULT_EXCHANGER_TYPE: str = "direct" RABBITMQ_QUEUE_DEFAULT: str = "operandi_queue_default" +RABBITMQ_QUEUE_HPC_DOWNLOADS: str = "operandi_queue_hpc_downloads" RABBITMQ_QUEUE_HARVESTER: str = "operandi_queue_harvester" RABBITMQ_QUEUE_JOB_STATUSES: str = "operandi_queue_job_statuses" RABBITMQ_QUEUE_USERS: str = "operandi_queue_users" diff --git a/src/utils/operandi_utils/rabbitmq/consumer.py b/src/utils/operandi_utils/rabbitmq/consumer.py index 7eb70758..cc559a4e 100644 --- a/src/utils/operandi_utils/rabbitmq/consumer.py +++ b/src/utils/operandi_utils/rabbitmq/consumer.py @@ -7,7 +7,7 @@ from .connector import RMQConnector from .constants import ( DEFAULT_EXCHANGER_NAME, DEFAULT_EXCHANGER_TYPE, - RABBITMQ_QUEUE_JOB_STATUSES, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS + RABBITMQ_QUEUE_HPC_DOWNLOADS, RABBITMQ_QUEUE_JOB_STATUSES, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS ) @@ -34,6 +34,7 @@ def setup_defaults(self) -> None: RMQConnector.declare_and_bind_defaults(self._connection, self._channel) self.create_queue(queue_name=RABBITMQ_QUEUE_HARVESTER) self.create_queue(queue_name=RABBITMQ_QUEUE_USERS) + self.create_queue(queue_name=RABBITMQ_QUEUE_HPC_DOWNLOADS) self.create_queue(queue_name=RABBITMQ_QUEUE_JOB_STATUSES, auto_delete=True) def create_queue( diff --git a/src/utils/operandi_utils/rabbitmq/publisher.py b/src/utils/operandi_utils/rabbitmq/publisher.py index 492ea153..4489519d 100644 --- a/src/utils/operandi_utils/rabbitmq/publisher.py +++ b/src/utils/operandi_utils/rabbitmq/publisher.py @@ -7,7 +7,7 @@ from .connector import RMQConnector from .constants import ( DEFAULT_EXCHANGER_NAME, DEFAULT_EXCHANGER_TYPE, - RABBITMQ_QUEUE_JOB_STATUSES, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS + RABBITMQ_QUEUE_HPC_DOWNLOADS, RABBITMQ_QUEUE_JOB_STATUSES, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_USERS ) @@ -33,6 +33,7 @@ def setup_defaults(self) -> None: RMQConnector.declare_and_bind_defaults(self._connection, self._channel) self.create_queue(queue_name=RABBITMQ_QUEUE_HARVESTER) self.create_queue(queue_name=RABBITMQ_QUEUE_USERS) + self.create_queue(queue_name=RABBITMQ_QUEUE_HPC_DOWNLOADS) self.create_queue(queue_name=RABBITMQ_QUEUE_JOB_STATUSES, auto_delete=True) def create_queue( diff --git a/src/utils/setup.py b/src/utils/setup.py index 24f389e4..c913a58d 100644 --- a/src/utils/setup.py +++ b/src/utils/setup.py @@ -5,7 +5,7 @@ setup( name='operandi_utils', - version='2.18.6', + version='2.19.0', description='OPERANDI - Utils', long_description=open('README.md').read(), long_description_content_type='text/markdown', diff --git a/tests/integration_tests/test_full_cycle.py b/tests/integration_tests/test_full_cycle.py index b923234e..2a70ec5c 100644 --- a/tests/integration_tests/test_full_cycle.py +++ b/tests/integration_tests/test_full_cycle.py @@ -5,14 +5,14 @@ from operandi_server.constants import ( DEFAULT_METS_BASENAME, DEFAULT_FILE_GRP, SERVER_WORKFLOW_JOBS_ROUTER, SERVER_WORKSPACES_ROUTER) from 
operandi_utils.constants import StateJob -from operandi_utils.rabbitmq import RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_JOB_STATUSES +from operandi_utils.rabbitmq import RABBITMQ_QUEUE_HPC_DOWNLOADS, RABBITMQ_QUEUE_HARVESTER, RABBITMQ_QUEUE_JOB_STATUSES from operandi_utils.hpc.constants import HPC_NHR_JOB_TEST_PARTITION from tests.tests_server.helpers_asserts import assert_response_status_code OPERANDI_SERVER_BASE_DIR = environ.get("OPERANDI_SERVER_BASE_DIR") def check_job_till_finish(auth_harvester, operandi, workflow_id: str, workflow_job_id: str): - tries = 70 + tries = 60 job_status = None check_job_status_url = f"/workflow/{workflow_id}/{workflow_job_id}" while tries > 0: @@ -32,15 +32,23 @@ def check_job_till_finish(auth_harvester, operandi, workflow_id: str, workflow_j assert job_status == StateJob.SUCCESS -def download_workflow_job_logs(auth_harvester, operandi, workflow_id: str, workflow_job_id: str) -> Path: - get_log_zip_url = f"/workflow/{workflow_id}/{workflow_job_id}/log" - response = operandi.get(url=get_log_zip_url, auth=auth_harvester) - zip_local_path = Path(environ.get("OPERANDI_SERVER_BASE_DIR"), f"{workflow_job_id}.zip") - with open(zip_local_path, "wb") as filePtr: - for chunk in response.iter_bytes(chunk_size=1024): - if chunk: - filePtr.write(chunk) - return zip_local_path +def download_workflow_job_logs(auth_harvester, operandi, workflow_id: str, workflow_job_id: str): + tries = 60 + get_log_zip_url = f"/workflow/{workflow_id}/{workflow_job_id}/logs" + while tries > 0: + tries -= 1 + sleep(30) + response = operandi.get(url=get_log_zip_url, auth=auth_harvester) + if response.status_code != 200: + continue + assert_response_status_code(response.status_code, expected_floor=2) + zip_local_path = Path(environ.get("OPERANDI_SERVER_BASE_DIR"), f"{workflow_job_id}.zip") + with open(zip_local_path, "wb") as filePtr: + for chunk in response.iter_bytes(chunk_size=1024): + if chunk: + filePtr.write(chunk) + assert zip_local_path.exists() + break def test_full_cycle(auth_harvester, operandi, service_broker, bytes_small_workspace): @@ -48,11 +56,11 @@ def test_full_cycle(auth_harvester, operandi, service_broker, bytes_small_worksp assert response.json()["message"] == "The home page of the OPERANDI Server" # Create a background worker for the harvester queue - service_broker.create_worker_process( - queue_name=RABBITMQ_QUEUE_HARVESTER, status_checker=False, tunnel_port_executor=22, tunnel_port_transfer=22) + service_broker.create_worker_process(RABBITMQ_QUEUE_HARVESTER, "submit_worker") # Create a background worker for the job statuses queue - service_broker.create_worker_process( - queue_name=RABBITMQ_QUEUE_JOB_STATUSES, status_checker=True, tunnel_port_executor=22, tunnel_port_transfer=22) + service_broker.create_worker_process(RABBITMQ_QUEUE_JOB_STATUSES, "status_worker") + # Create a background worker for the hpc download queue + service_broker.create_worker_process(RABBITMQ_QUEUE_HPC_DOWNLOADS, "download_worker") # Post a workspace zip response = operandi.post(url="/workspace", files={"workspace": bytes_small_workspace}, auth=auth_harvester) @@ -92,11 +100,7 @@ def test_full_cycle(auth_harvester, operandi, service_broker, bytes_small_worksp workflow_job_id = response.json()["resource_id"] check_job_till_finish(auth_harvester, operandi, workflow_id, workflow_job_id) - - # TODO: Fix this, wait for a few secs till the data is transferred from HPC to Operandi Server - sleep(45) - zip_local_path = download_workflow_job_logs(auth_harvester, operandi, workflow_id, workflow_job_id) - 
assert zip_local_path.exists() + download_workflow_job_logs(auth_harvester, operandi, workflow_id, workflow_job_id) ws_dir = Path(OPERANDI_SERVER_BASE_DIR, SERVER_WORKSPACES_ROUTER, workspace_id) assert ws_dir.exists() diff --git a/tests/tests_utils/test_3_hpc/test_3_nhr_combined.py b/tests/tests_utils/test_3_hpc/test_3_nhr_combined.py index 13c772e1..c15c2868 100644 --- a/tests/tests_utils/test_3_hpc/test_3_nhr_combined.py +++ b/tests/tests_utils/test_3_hpc/test_3_nhr_combined.py @@ -66,13 +66,12 @@ def test_hpc_connector_run_batch_script( file_groups_to_remove="", cpus=2, ram=16, job_deadline_time=HPC_JOB_DEADLINE_TIME_TEST, partition=HPC_NHR_JOB_TEST_PARTITION, qos=HPC_JOB_QOS_SHORT) finished_successfully = hpc_nhr_command_executor.poll_till_end_slurm_job_state( - slurm_job_id=slurm_job_id, interval=5, timeout=300) + slurm_job_id=slurm_job_id, interval=10, timeout=300) assert finished_successfully ws_dir = Path(OPERANDI_SERVER_BASE_DIR, SERVER_WORKSPACES_ROUTER, ID_WORKSPACE) wf_job_dir = Path(OPERANDI_SERVER_BASE_DIR, SERVER_WORKFLOW_JOBS_ROUTER, ID_WORKFLOW_JOB) - hpc_nhr_data_transfer.get_and_unpack_slurm_workspace( - ocrd_workspace_dir=ws_dir, workflow_job_dir=wf_job_dir, slurm_job_id=slurm_job_id) + hpc_nhr_data_transfer.get_and_unpack_slurm_workspace(ocrd_workspace_dir=ws_dir, workflow_job_dir=wf_job_dir) assert Path(ws_dir, "OCR-D-BIN").exists() assert wf_job_dir.exists() assert Path(wf_job_dir, "work").exists() @@ -86,16 +85,15 @@ def test_hpc_connector_run_batch_script_with_ms( input_file_grp=DEFAULT_FILE_GRP, workspace_id=ID_WORKSPACE_WITH_MS, mets_basename=DEFAULT_METS_BASENAME, nf_process_forks=2, ws_pages_amount=8, use_mets_server=True, nf_executable_steps=["ocrd-cis-ocropy-binarize"], - file_groups_to_remove="", cpus=3, ram=16, job_deadline_time=HPC_JOB_DEADLINE_TIME_TEST, + file_groups_to_remove="", cpus=4, ram=16, job_deadline_time=HPC_JOB_DEADLINE_TIME_TEST, partition=HPC_NHR_JOB_TEST_PARTITION, qos=HPC_JOB_QOS_SHORT) finished_successfully = hpc_nhr_command_executor.poll_till_end_slurm_job_state( - slurm_job_id=slurm_job_id, interval=5, timeout=300) + slurm_job_id=slurm_job_id, interval=10, timeout=300) assert finished_successfully ws_dir = Path(OPERANDI_SERVER_BASE_DIR, SERVER_WORKSPACES_ROUTER, ID_WORKSPACE_WITH_MS) wf_job_dir = Path(OPERANDI_SERVER_BASE_DIR, SERVER_WORKFLOW_JOBS_ROUTER, ID_WORKFLOW_JOB_WITH_MS) - hpc_nhr_data_transfer.get_and_unpack_slurm_workspace( - ocrd_workspace_dir=ws_dir, workflow_job_dir=wf_job_dir, slurm_job_id=slurm_job_id) + hpc_nhr_data_transfer.get_and_unpack_slurm_workspace(ocrd_workspace_dir=ws_dir, workflow_job_dir=wf_job_dir) assert Path(ws_dir, "OCR-D-BIN").exists() assert wf_job_dir.exists() assert Path(wf_job_dir, "work").exists()
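
Two patterns recur throughout the hunks above and are easiest to see in isolation. First, the three db_get_all_*_by_user helpers (db_workflow.py, db_workflow_job.py, db_workspace.py) now build the same Mongo filter: user scope, optional datetime range, and a deleted == False clause unless hide_deleted is disabled. A minimal standalone sketch of that shared shape; build_user_scoped_query is a hypothetical name, the real getters inline this logic before calling find_many(query).to_list():

    from datetime import datetime
    from typing import Optional

    def build_user_scoped_query(
        user_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        hide_deleted: bool = True,
    ) -> dict:
        # Filter resources by owning user
        query: dict = {"user_id": user_id}
        # Optional datetime range, matching the existing getters
        if start_date or end_date:
            query["datetime"] = {}
            if start_date:
                query["datetime"]["$gte"] = start_date
            if end_date:
                query["datetime"]["$lte"] = end_date
        # New in this change set: soft-deleted resources are hidden by default;
        # callers (e.g. admin endpoints) can pass hide_deleted=False to see everything
        if hide_deleted:
            query["deleted"] = False
        return query

    if __name__ == "__main__":
        print(build_user_scoped_query("user-123"))
        # {'user_id': 'user-123', 'deleted': False}

Second, the HPC helpers replace magic numbers with module-level constants (SFTP_RECONNECT_TRIES, DOWNLOAD_FILE_TRY_TIMES, DOWNLOAD_FILE_SLEEP_TIME, and the CHECK/POLL constants in nhr_executor.py) and keep the same retry shape. A rough sketch of that retry loop, assuming a hypothetical zero-argument download_once callable in place of the real SFTP get; the behaviour after the final failed try is an assumption of the sketch, not taken from the patch:

    from time import sleep
    from typing import Callable, TypeVar

    T = TypeVar("T")

    DOWNLOAD_FILE_TRY_TIMES = 100
    DOWNLOAD_FILE_SLEEP_TIME = 3

    def download_with_retries(
        download_once: Callable[[], T],
        try_times: int = DOWNLOAD_FILE_TRY_TIMES,
        sleep_time: int = DOWNLOAD_FILE_SLEEP_TIME,
    ) -> T:
        # Negative values no longer raise; they fall back to the defaults,
        # mirroring the updated _download_file_with_retries behaviour
        if try_times < 0 or sleep_time < 0:
            try_times, sleep_time = DOWNLOAD_FILE_TRY_TIMES, DOWNLOAD_FILE_SLEEP_TIME
        last_error: Exception = Exception("no attempt made")
        for _ in range(try_times):
            try:
                # Stands in for SFTPClient.get(remote_src, local_dst)
                return download_once()
            except Exception as error:
                last_error = error
                sleep(sleep_time)
        # Assumption: surface the last error once all tries are exhausted
        raise last_error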