From eee4ec991fd930ee6b89f90248b5989b97c17b09 Mon Sep 17 00:00:00 2001 From: Jason Eu Date: Tue, 12 Jan 2021 21:52:43 +0800 Subject: [PATCH] handle interrupt signal when running process --- aiida/engine/daemon/runner.py | 11 ++++----- aiida/engine/runners.py | 42 +++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/aiida/engine/daemon/runner.py b/aiida/engine/daemon/runner.py index 7085c15167..611a9f931a 100644 --- a/aiida/engine/daemon/runner.py +++ b/aiida/engine/daemon/runner.py @@ -19,16 +19,17 @@ LOGGER = logging.getLogger(__name__) -async def shutdown_runner(runner): +async def shutdown_runner(_signal, runner): """Cleanup tasks tied to the service's shutdown.""" - LOGGER.info('Received signal to shut down the daemon runner') + LOGGER.info(f'Received signal {_signal.name} to shut down the daemon runner') try: from asyncio import all_tasks from asyncio import current_task except ImportError: - # Necessary for Python 3.6 as `asyncio.all_tasks` and `asyncio.current_task` were introduced in Python 3.7. The - # Standalone functions should be used as the classmethods are removed as of Python 3.9. + # Necessary for Python 3.6 as `asyncio.all_tasks` and `asyncio.current_task` + # were introduced in Python 3.7. The Standalone functions + # should be used as the classmethods are removed as of Python 3.9. 
all_tasks = asyncio.Task.all_tasks current_task = asyncio.Task.current_task @@ -56,7 +57,7 @@ def start_daemon(): signals = (signal.SIGTERM, signal.SIGINT) for s in signals: # pylint: disable=invalid-name - runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_runner(runner))) + runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_runner(s, runner))) try: LOGGER.info('Starting a daemon runner') diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index be2a3d377b..ba1d90eeb0 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -216,19 +216,37 @@ def _run(self, process, *args, **inputs): result, node = process.run_get_node(*args, **inputs) return result, node - with utils.loop_scope(self.loop): - process = self.instantiate_process(process, *args, **inputs) - - def kill_process(_num, _frame): - """Send the kill signal to the process in the current scope.""" - LOGGER.critical('runner received interrupt, killing process %s', process.pid) - process.kill(msg='Process was killed because the runner received an interrupt') - - signal.signal(signal.SIGINT, kill_process) - signal.signal(signal.SIGTERM, kill_process) + process = self.instantiate_process(process, *args, **inputs) - process.execute() - return process.outputs, process.node + async def kill_process(_signal, process): + """Send the kill signal to the process in the current scope.""" + LOGGER.critical(f'runner received interrupt signal {_signal.name}, killing process {process.pid}') + try: + from asyncio import all_tasks + from asyncio import current_task + except ImportError: + # Necessary for Python 3.6 as `asyncio.all_tasks` and `asyncio.current_task` + # were introduced in Python 3.7. The Standalone functions + # should be used as the classmethods are removed as of Python 3.9. 
+ all_tasks = asyncio.Task.all_tasks + current_task = asyncio.Task.current_task + + tasks = [task for task in all_tasks() if task is not current_task()] + for task in tasks: + task.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) + + res = process.kill(msg='Process was killed because the runner received an interrupt') + if asyncio.isfuture(res): + await res + + signals = (signal.SIGTERM, signal.SIGINT) + for s in signals: # pylint: disable=invalid-name + self.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(kill_process(s, process))) + + process.execute() + return process.outputs, process.node def run(self, process, *args, **inputs): """