Skip to content

Commit

Permalink
Fix asyncio policy
Browse files Browse the repository at this point in the history
  • Loading branch information
Blanca Fuentes Monjas committed Dec 2, 2024
1 parent 47329b6 commit e2893f2
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 80 deletions.
17 changes: 14 additions & 3 deletions reframe/core/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,9 +964,10 @@ def adjust_verbosity(self, num_steps):


class logging_context:

def __init__(self, check=None, level=DEBUG):
try:
task = asyncio.current_task()
task = current_task()
except RuntimeError:
global _global_logger
task = None
Expand All @@ -990,7 +991,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
global _global_logger
try:
task = asyncio.current_task()
task = current_task()
except RuntimeError:
task = None

Expand Down Expand Up @@ -1034,7 +1035,7 @@ def save_log_files(dest):

def getlogger():
try:
task = asyncio.current_task()
task = current_task()
except RuntimeError:
task = None
if task:
Expand Down Expand Up @@ -1098,3 +1099,13 @@ def __exit__(self, exc_type, exc_value, traceback):
_logger = self._logger
_perf_logger = self._perf_logger
_global_logger = self._context_logger


def current_task():
"""Wrapper for asyncio.current_task() compatible with Python 3.6 and later."""
if sys.version_info >= (3, 7):
# Use asyncio.current_task() directly in Python 3.7+
return asyncio.current_task()
else:
# Fallback to asyncio.tasks.current_task() in Python 3.6
return asyncio.Task.current_task()
143 changes: 85 additions & 58 deletions reframe/frontend/executors/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import asyncio
from asyncio import (get_child_watcher,
set_child_watcher,
SafeChildWatcher)
import contextlib
import math
import sys
Expand Down Expand Up @@ -100,7 +103,7 @@ def __init__(self):
self._retired_tasks = []
self.task_listeners.append(self)

def _runcase(self, case):
async def _runcase(self, case):
super()._runcase(case)
check, partition, _ = case
task = RegressionTask(case, self.task_listeners)
Expand Down Expand Up @@ -133,46 +136,25 @@ def _runcase(self, case):
sched_flex_alloc_nodes=self.sched_flex_alloc_nodes,
sched_options=self.sched_options)

async def compile_async_serial():
await task.compile()
await task.compile_wait()

async def run_async_serial():
await task.run()
await task.compile()
await task.compile_wait()
await task.run()

if task.check.local:
sched = self.local_scheduler

self._pollctl.reset_snooze_time()
while True:
if not self.dry_run_mode:
await sched.poll(task.check.job)
if task.run_complete():
break

await self._pollctl.snooze()

await task.run_wait()

if task.check.is_local:
# TODO: ssh scheduler
asyncio.run(compile_async_serial())
asyncio.run(run_async_serial())
else:
asyncio.run(compile_async_serial())
asyncio.run(task.run())

sched = partition.scheduler

self._pollctl.reset_snooze_time()
while True:
if not self.dry_run_mode:
asyncio.run(sched.poll(task.check.job))
if task.run_complete():
break
self._pollctl.reset_snooze_time()
while True:
if not self.dry_run_mode:
await sched.poll(task.check.job)
if task.run_complete():
break

asyncio.run(self._pollctl.snooze())
await self._pollctl.snooze()

asyncio.run(task.run_wait())
await task.run_wait()

if not self.skip_sanity_check:
task.sanity()
Expand Down Expand Up @@ -268,6 +250,44 @@ def on_task_success(self, task):
if self.timeout_expired():
raise RunSessionTimeout('maximum session duration exceeded')

def execute(self, testcases):
'''Execute the policy for a given set of testcases.'''
# Moved here the execution
try:
loop = asyncio.get_event_loop()
for task in all_tasks(loop):
if isinstance(task, asyncio.tasks.Task):
task.cancel()
if loop.is_closed():
loop = asyncio.new_event_loop()
watcher = asyncio.get_child_watcher()
if isinstance(watcher, asyncio.SafeChildWatcher):
# Detach the watcher from the current loop to avoid issues
watcher.close()
watcher.attach_loop(None)
asyncio.set_event_loop(loop)
if isinstance(watcher, asyncio.SafeChildWatcher):
# Reattach the watcher to the new loop
watcher.attach_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for case in testcases:
try:
loop.run_until_complete(self._runcase(case))
except (Exception, KeyboardInterrupt) as e:
if type(e) in (ABORT_REASONS):
for task in all_tasks(loop):
if isinstance(task, asyncio.tasks.Task):
task.cancel()
loop.close()
raise e
else:
getlogger().info(f"Execution stopped due to an error: {e}")
break
loop.close()
self.exit()

def _exit(self):
# Clean up all remaining tasks
_cleanup_all(self._retired_tasks, not self.keep_stage_files)
Expand Down Expand Up @@ -298,13 +318,17 @@ def __init__(self):
self.task_listeners.append(self)

async def _runcase(self, case, task):
# I added the task here as an argument because, I wanted to initialize it
# outside, when I gather the tasks. If I gather the tasks and then I do asyncio
# manage them, if one of them fails the others are not iformed, I had to code that
# manually. There is a way to make everything stop if an exepction is raised but
# I didn't know how to treat that raise Exception nicelly because I wouldn't be able
# to abort the tasks which the execution has not yet started, I needed to do abortall
# on all the tests, not only the ones which were initiated by the execution. Exit gracefully
# I added the task here as an argument because,
# I wanted to initialize it
# outside, when I gather the tasks.
# If I gather the tasks and then I do asyncio
# manage them, if one of them fails the others are not iformed,
# I had to code that manually. There is a way to make everything
# stop if an exepction is raised but I didn't know how to treat
# that raise Exception nicelly because I wouldn't be able
# to abort the tasks which the execution has not yet started,
# I needed to do abortall on all the tests, not only the ones
# which were initiated by the execution. Exit gracefully
# the execuion loop aborting all the tasks
super()._runcase(case)
check, partition, _ = case
Expand Down Expand Up @@ -366,9 +390,11 @@ async def _runcase(self, case, task):
await task.compile()
self._partition_tasks[partname].add(task)
await task.compile_wait()
self._partition_tasks[partname].remove(task)
while len(self._partition_tasks[partname]) > max_jobs:
await asyncio.sleep(2)
await task.run()
self._partition_tasks[partname].add(task)

# Pick the right scheduler
if task.check.local:
Expand All @@ -387,6 +413,7 @@ async def _runcase(self, case, task):
await self._pollctl.snooze()

await task.run_wait()
self._partition_tasks[partname].remove(task)
if not self.skip_sanity_check:
task.sanity()

Expand Down Expand Up @@ -431,18 +458,6 @@ async def _runcase(self, case, task):
self._partition_tasks[partname].remove(task)
return

self._current_tasks.remove(task)
if task.check.current_partition:
partname = task.check.current_partition.fullname
else:
partname = None

# Remove tasks from the partition tasks if there
with contextlib.suppress(KeyError):
self._partition_tasks['_rfm_local'].remove(task)
if partname:
self._partition_tasks[partname].remove(task)

async def check_deps(self, task):
while not (self.deps_skipped(task) or self.deps_failed(task) or
self.deps_succeeded(task)):
Expand Down Expand Up @@ -578,11 +593,12 @@ def execute(self, testcases):
loop.run_until_complete(self._execute_until_failure(all_cases))
except (Exception, KeyboardInterrupt) as e:
if type(e) in (ABORT_REASONS):
loop.run_until_complete(self._cancel_gracefully(all_cases))
loop.run_until_complete(_cancel_gracefully(all_cases))
try:
raise AbortTaskError
except AbortTaskError as exc:
self._abortall(exc)
loop.close()
raise e
else:
getlogger().info(f"Execution stopped due to an error: {e}")
Expand All @@ -601,7 +617,18 @@ async def _execute_until_failure(self, all_cases):
if task.exception():
raise task.exception() # Exit if aborted

async def _cancel_gracefully(self, all_cases):
for case in all_cases:
case.cancel()
await asyncio.gather(*all_cases, return_exceptions=True)

async def _cancel_gracefully(all_cases):
for case in all_cases:
case.cancel()
await asyncio.gather(*all_cases, return_exceptions=True)


def all_tasks(loop):
"""Wrapper for asyncio.current_task() compatible with Python 3.6 and later."""
if sys.version_info >= (3, 7):
# Use asyncio.current_task() directly in Python 3.7+
return asyncio.all_tasks(loop)
else:
# Fallback to asyncio.tasks.current_task() in Python 3.6
return asyncio.Task.all_tasks(loop)
1 change: 1 addition & 0 deletions unittests/test_perflogging.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_perf_logging(make_runner, make_exec_ctx, perf_test,
)
)
)
rt.set_working_dir()
logging.configure_logging(rt.runtime().site_config)
runner = make_runner()
testcases = executors.generate_testcases([perf_test])
Expand Down
11 changes: 5 additions & 6 deletions unittests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#
# SPDX-License-Identifier: BSD-3-Clause

import asyncio
import os
import pytest
import re
Expand All @@ -27,7 +26,7 @@


def _run(test, partition, prgenv):
asyncio.run(_runasync(test, partition, prgenv))
test_util.asyncio_run(_runasync, test, partition, prgenv)


async def _runasync(test, partition, prgenv):
Expand Down Expand Up @@ -322,7 +321,7 @@ class MyTest(rfm.CompileOnlyRegressionTest):
test = MyTest()
test.setup(*local_exec_ctx)
with pytest.raises(BuildError):
asyncio.run(compile_wait(test))
test_util.asyncio_run(compile_wait, test)


def test_compile_only_warning(local_exec_ctx):
Expand Down Expand Up @@ -807,7 +806,7 @@ class MyTest(rfm.CompileOnlyRegressionTest):
test.setup(*local_exec_ctx)
test.sourcepath = '/usr/src'
with pytest.raises(PipelineError):
asyncio.run(test.compile())
test_util.asyncio_run(test.compile)


def test_sourcepath_upref(local_exec_ctx):
Expand All @@ -820,7 +819,7 @@ class MyTest(rfm.CompileOnlyRegressionTest):
test.setup(*local_exec_ctx)
test.sourcepath = '../hellosrc'
with pytest.raises(PipelineError):
asyncio.run(test.compile())
test_util.asyncio_run(test.compile)


def test_sourcepath_non_existent(local_exec_ctx):
Expand All @@ -833,7 +832,7 @@ class MyTest(rfm.CompileOnlyRegressionTest):
test.setup(*local_exec_ctx)
test.sourcepath = 'non_existent.c'
with pytest.raises(BuildError):
asyncio.run(compile_wait(test))
test_util.asyncio_run(compile_wait, test)


def test_extra_resources(HelloTest, testsys_exec_ctx):
Expand Down
3 changes: 3 additions & 0 deletions unittests/test_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@

_DEFAULT_BASE_COLS = DEFAULT_GROUP_BY + DEFAULT_EXTRA_COLS

rt.set_working_dir()

# NOTE: We could move this to utility


class _timer:
'''Context manager for timing'''

Expand Down
Loading

0 comments on commit e2893f2

Please sign in to comment.