Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

env (interactive mode): timeout for blocking operations #440

Merged
155 changes: 145 additions & 10 deletions compiler_opt/rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import dataclasses
from enum import Enum

import logging
import math
import select
import subprocess
import abc
import contextlib
import io
import os
import threading
from typing import Callable, Generator, List, Optional, Tuple, Type

import numpy as np
Expand Down Expand Up @@ -219,6 +222,135 @@ def _reward_fn(a: float, b: float) -> float:
return {key: _reward_fn(score_a[key], score_b[key]) for key in score_a}


@contextlib.contextmanager
def open_write_pipe(filename: str, *, timeout: float):
"""Open the write pipe or timeout.

Assuming a fifo, the `open` will block until the other party (the process we
communicate to) also opens the pipe. If that doesn't happen, we time out.
Afterwards, `write` ops shouldn't block.
"""
opened = threading.Event()
timed_out = threading.Event()

# start a thread that waits for `open` to unblock. If it doesn't, we open the
# fifo ourselves just to unblock.
def _timeout_thread():
if opened.wait(timeout):
logging.debug('[timeout thread] writer opened successfully')
return
timed_out.set()
logging.debug('[timeout thread] writer failed to open')
with open(filename, 'rb'):
pass

waiter = threading.Thread(target=_timeout_thread)
waiter.start()
try:
with io.BufferedWriter(io.FileIO(filename, 'wb')) as writer_pipe:
if not timed_out.is_set():
opened.set()
yield writer_pipe
finally:
waiter.join()
if timed_out.is_set():
# it's possible that the timeout thread timed out but also the other
# process finally opened the pipe and thus the `writer_pipe` is
# functional, but at the end we still raise TimeoutError. We accept that
# right now.
raise TimeoutError('write pipe open')


@contextlib.contextmanager
def open_read_pipe(filename: str, *, timeout: float):
"""Open the read pipe, with a timeout governing the open and each read.

Just like in the writer case, assuming we're opening a fifo pipe, the open
operation will block until the other party opens the pipe. Then, because this
is a reader, each read operation (and variations - readline, etc) can block,
but no more than the provided timeout.
"""

# wrap the underlying io.RawIOBase such that we poll before attempting to read
def _wrap_raw_io(obj: io.RawIOBase):

def _get_polling_wrapper(wrapped_method):

def _replacement(*args, **kwargs):
name = wrapped_method.__name__
logging.debug('ReaderWithTimeout is asked to %s', name)
(r, _, _) = select.select([obj], [], [], timeout)
if r:
logging.debug('ReaderWithTimeout %s should be unblocked', name)
result = wrapped_method(*args, **kwargs)
logging.debug('ReaderWithTimeout %s completed', name)
return result
logging.info('ReaderWithTimeout timed out waiting to %s', name)
raise TimeoutError('timed out reading')

return _replacement

# pylint: disable=protected-access
obj._orig_read = obj.read
obj._orig_readline = obj.readline
obj._orig_readinto = obj.readinto
obj._orig_readall = obj.readall

obj.read = _get_polling_wrapper(obj._orig_read)
obj.readline = _get_polling_wrapper(obj._orig_readline)
obj.readinto = _get_polling_wrapper(obj._orig_readinto)
obj.readall = _get_polling_wrapper(obj._orig_readall)
# pylint: enable=protected-access
return obj

opened = threading.Event()
timed_out = threading.Event()

# same idea as in the writer case - unblock the `open`
def _timeout_thread():
if opened.wait(timeout):
logging.debug('[timeout thread] reader opened successfully')
return
timed_out.set()
logging.debug('[timeout thread] reader failed to open')
with open(filename, 'wb'):
pass
logging.debug('[timeout thread] force-opened the reader')

waiter = threading.Thread(target=_timeout_thread)
waiter.start()
try:
# we must wrap the *raw* stream! wrapping the buffered stream would be
# incorrect because calls to `read` APIs shouldn't poll (they may just
# return from the buffer).
with io.BufferedReader(_wrap_raw_io(io.FileIO(filename,
'rb'))) as reader_pipe:
if not timed_out.is_set():
opened.set()
yield reader_pipe
finally:
waiter.join()
if timed_out.is_set():
# same as in the writer case - we could successfully keep reading but
# still report a timeout at the end of this context.
raise TimeoutError('read pipe open')


@contextlib.contextmanager
def interactive_session(*, reader_name: str, writer_name: str, timeout: float):
"""Start an interactive session with the started process proc.

Blocking pipe operations - open and read - happen under a timeout.
"""

try:
with open_write_pipe(writer_name, timeout=timeout) as writer_pipe:
with open_read_pipe(reader_name, timeout=timeout) as reader_pipe:
yield (reader_pipe, writer_pipe)
finally:
pass


@contextlib.contextmanager
def clang_session(
clang_path: str,
Expand Down Expand Up @@ -269,16 +401,19 @@ def _get_scores() -> dict[str, float]:
cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
try:
if interactive:
with io.BufferedWriter(io.FileIO(writer_name, 'wb')) as writer_pipe:
with io.BufferedReader(io.FileIO(reader_name, 'rb')) as reader_pipe:
yield InteractiveClang(
proc,
_get_scores,
module.name,
task_working_dir,
reader_pipe,
writer_pipe,
)
with interactive_session(
writer_name=writer_name,
reader_name=reader_name,
timeout=compilation_runner.COMPILATION_TIMEOUT.value) as (
reader_pipe, writer_pipe):
yield InteractiveClang(
proc,
_get_scores,
module.name,
task_working_dir,
reader_pipe,
writer_pipe,
)
else:
yield ClangProcess(
proc,
Expand Down
Loading
Loading