Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyright mpl backend #17

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ doc = [

dev = [
"genie_python[plot,doc]",
"mock",
"parameterized",
"pyhamcrest",
"pytest",
Expand Down
7 changes: 2 additions & 5 deletions src/genie_python/genie_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,7 @@ def end_run(
prepost: run pre and post commands [optional]
"""
if self.get_run_state() == "ENDING" and not immediate:
print(
"Please specify the 'immediate=True' flag to end a run " "while in the ENDING state"
)
print("Please specify the 'immediate=True' flag to end a run while in the ENDING state")
return

run_number = self.get_run_number()
Expand Down Expand Up @@ -601,8 +599,7 @@ def pause_run(self, immediate: bool = False, prepost: bool = True) -> None:
"""
if self.get_run_state() == "PAUSING" and not immediate:
print(
"Please specify the 'immediate=True' flag "
"to pause a run while in the PAUSING state"
"Please specify the 'immediate=True' flag to pause a run while in the PAUSING state"
)
return

Expand Down
141 changes: 86 additions & 55 deletions src/genie_python/matplotlib_backend/ibex_websocket_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
import threading
from functools import wraps
from time import sleep
from typing import Any, Callable, Mapping, ParamSpec, TypeVar, cast

import tornado
import tornado.websocket
from matplotlib._pylab_helpers import Gcf
from matplotlib.backend_bases import _Backend
from matplotlib.backends import backend_webagg
from matplotlib.backends import backend_webagg_core as core
from py4j.java_collections import ListConverter
from py4j.java_gateway import JavaGateway
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketClosedError

from genie_python.genie_logging import GenieLogger
Expand All @@ -37,20 +40,26 @@
_is_primary = True


def _ignore_if_websocket_closed(func):
T = TypeVar("T")
P = ParamSpec("P")


def _ignore_if_websocket_closed(func: Callable[P, T]) -> Callable[P, T | None]:
"""
Decorator which ignores exceptions that were caused by websockets being closed.
"""

@wraps(func)
def wrapper(*a, **kw):
def wrapper(*a: P.args, **kw: P.kwargs) -> T | None:
try:
return func(*a, **kw)
except WebSocketClosedError:
pass
except Exception as e:
# Plotting multiple graphs quickly can cause an error where pyplot tries to access a plot which
# has been removed. This error does not break anything, so log it and continue. It is better for the plot
# Plotting multiple graphs quickly can cause an error where
# pyplot tries to access a plot which
# has been removed. This error does not break anything, so log it
# and continue. It is better for the plot
# to fail to update than for the whole user script to crash.
try:
GenieLogger().log_info_msg(
Expand All @@ -64,24 +73,32 @@ def wrapper(*a, **kw):
return wrapper


def _asyncio_send_exceptions_to_logfile_only(loop, context):
def _asyncio_send_exceptions_to_logfile_only(
loop: asyncio.AbstractEventLoop, context: Mapping[str, Any]
) -> None:
exception = context.get("exception")
try:
GenieLogger().log_info_msg(
f"Caught (non-fatal) asyncio exception: " f"{exception.__class__.__name__}: {exception}"
f"Caught (non-fatal) asyncio exception: {exception.__class__.__name__}: {exception}"
)
except Exception:
# Exception while logging, ignore...
pass


def set_up_plot_default(is_primary=True, should_open_ibex_window_on_show=True, max_figures=None):
def set_up_plot_default(
is_primary: bool = True,
should_open_ibex_window_on_show: bool = True,
max_figures: int | None = None,
) -> None:
"""
Set the plot defaults for when show is called

Args:
is_primary: True display plot on primary web port; False display plot on secondary web port
should_open_ibex_window_on_show: Does nothing; provided for backwards-compatibility with older backend
is_primary: True display plot on primary web port; False display
plot on secondary web port
should_open_ibex_window_on_show: Does nothing; provided for
backwards-compatibility with older backend
max_figures: Maximum number of figures to plot simultaneously (int)
"""
global _web_backend_port
Expand All @@ -99,79 +116,84 @@ def set_up_plot_default(is_primary=True, should_open_ibex_window_on_show=True, m


class WebAggApplication(backend_webagg.WebAggApplication):
class WebSocket(tornado.websocket.WebSocketHandler):
class WebSocket(tornado.websocket.WebSocketHandler): # pyright: ignore
supports_binary = True

def write_message(self, *args, **kwargs):
def write_message(self, *args: Any, **kwargs: Any) -> asyncio.Future[None]:
f = super().write_message(*args, **kwargs)

@_ignore_if_websocket_closed
def _cb(*args, **kwargs):
def _cb(*args: Any, **kwargs: Any) -> None:
return f.result()

f.add_done_callback(_cb)
return f

@_ignore_if_websocket_closed
def open(self, fignum):
def open(self, fignum: int, *args: Any, **kwargs: Any) -> None:
self.fignum = int(fignum)
self.manager = Gcf.figs.get(self.fignum, None)
self.manager = cast(_FigureManager | None, Gcf.figs.get(self.fignum, None))
if self.manager is not None:
self.manager.add_web_socket(self)
if hasattr(self, "set_nodelay"):
self.set_nodelay(True)

@_ignore_if_websocket_closed
def on_close(self):
self.manager.remove_web_socket(self)
def on_close(self) -> None:
if self.manager is not None:
self.manager.remove_web_socket(self)

@_ignore_if_websocket_closed
def on_message(self, message):
message = json.loads(message)
def on_message(self, message: str | bytes) -> None:
parsed_message: dict[str, Any] = json.loads(message)
# The 'supports_binary' message is on a client-by-client
# basis. The others affect the (shared) canvas as a
# whole.
if message["type"] == "supports_binary":
self.supports_binary = message["value"]
if parsed_message["type"] == "supports_binary":
self.supports_binary = parsed_message["value"]
else:
manager = Gcf.figs.get(self.fignum, None)
manager = cast(_FigureManager | None, Gcf.figs.get(self.fignum, None))
# It is possible for a figure to be closed,
# but a stale figure UI is still sending messages
# from the browser.
if manager is not None:
manager.handle_json(message)
manager.handle_json(parsed_message)

@_ignore_if_websocket_closed
def send_json(self, content):
def send_json(self, content: dict[str, str]) -> None:
self.write_message(json.dumps(content))

@_ignore_if_websocket_closed
def send_binary(self, blob):
def send_binary(self, blob: str) -> None:
if self.supports_binary:
self.write_message(blob, binary=True)
else:
blob_code = blob.encode("base64").replace("\n", "")
blob_code = blob.encode("base64").replace(b"\n", b"")
data_uri = f"data:image/png;base64,{blob_code}"
self.write_message(data_uri)

ioloop = None
ioloop: IOLoop | None = None
asyncio_loop = None
started = False
app = None

@classmethod
def initialize(cls, url_prefix="", port=None, address=None):
def initialize(
cls, url_prefix: str = "", port: int | None = None, address: str | None = None
) -> None:
"""
Create the class instance

We use a constant, hard-coded port as we will only ever have one plot going at the same time.
We use a constant, hard-coded port as we will only
ever have one plot going at the same time.
"""
cls.app = cls(url_prefix=url_prefix)
cls.url_prefix = url_prefix
cls.port = port
cls.address = address

@classmethod
def start(cls):
def start(cls) -> None:
"""
IOLoop.running() was removed as of Tornado 2.4; see for example
https://groups.google.com/forum/#!topic/python-tornado/QLMzkpQBGOY
Expand All @@ -191,6 +213,8 @@ def start(cls):
asyncio.set_event_loop(loop)
cls.asyncio_loop = loop
cls.ioloop = tornado.ioloop.IOLoop.current()
if cls.port is None or cls.address is None or cls.app is None:
raise RuntimeError("port, address and app must be set")
cls.app.listen(cls.port, cls.address)

# Set the flag to True *before* blocking on ioloop.start()
Expand All @@ -202,22 +226,26 @@ def start(cls):
traceback.print_exc()

@classmethod
def stop(cls):
def stop(cls) -> None:
try:

def _stop():
cls.ioloop.stop()
sys.stdout.flush()
cls.started = False
def _stop() -> None:
if cls.ioloop is not None:
cls.ioloop.stop()
sys.stdout.flush()
cls.started = False

cls.ioloop.add_callback(_stop)
if cls.ioloop is not None:
cls.ioloop.add_callback(_stop)
except Exception:
import traceback

traceback.print_exc()


def ibex_open_plot_window(figures, is_primary=True, host=None):
def ibex_open_plot_window(
figures: list[int], is_primary: bool = True, host: str | None = None
) -> None:
"""
Open the plot window in ibex gui through py4j. With sensible defaults
Args:
Expand All @@ -230,12 +258,14 @@ def ibex_open_plot_window(figures, is_primary=True, host=None):
url = f"{host}:{port}"
try:
gateway = JavaGateway()
figures = ListConverter().convert(figures, gateway._gateway_client)
gateway.entry_point.openMplRenderer(figures, url, is_primary)
converted_figures = ListConverter().convert(figures, gateway._gateway_client)
gateway.entry_point.openMplRenderer(converted_figures, url, is_primary) # pyright: ignore (rpc)
except Exception as e:
# We need this try-except to be very broad as various exceptions can, in principle,
# We need this try-except to be very broad as various
# exceptions can, in principle,
# be thrown while translating between python <-> java.
# If any exceptions occur, it is better to log and continue rather than crashing the entire script.
# If any exceptions occur, it is better to log and
# continue rather than crashing the entire script.
print(f"Failed to open plot in IBEX due to: {e}")


Expand All @@ -248,25 +278,25 @@ class _FigureManager(core.FigureManagerWebAgg):
_toolbar2_class = core.NavigationToolbar2WebAgg

@_ignore_if_websocket_closed
def _send_event(self, *args, **kwargs):
def _send_event(self, *args: Any, **kwargs: Any) -> None:
with _IBEX_FIGURE_MANAGER_LOCK:
super()._send_event(*args, **kwargs)

def remove_web_socket(self, *args, **kwargs):
def remove_web_socket(self, *args: Any, **kwargs: Any) -> None:
with _IBEX_FIGURE_MANAGER_LOCK:
super().remove_web_socket(*args, **kwargs)

def add_web_socket(self, *args, **kwargs):
def add_web_socket(self, *args: Any, **kwargs: Any) -> None:
with _IBEX_FIGURE_MANAGER_LOCK:
super().add_web_socket(*args, **kwargs)

@_ignore_if_websocket_closed
def refresh_all(self, *args, **kwargs):
def refresh_all(self) -> None:
with _IBEX_FIGURE_MANAGER_LOCK:
super().refresh_all(*args, **kwargs)
super().refresh_all()

@classmethod
def pyplot_show(cls, *args, **kwargs):
def pyplot_show(cls, *args: Any, **kwargs: Any) -> None:
"""
Show a plot.

Expand Down Expand Up @@ -305,18 +335,18 @@ def pyplot_show(cls, *args, **kwargs):
class _FigureCanvas(backend_webagg.FigureCanvasWebAgg):
manager_class = _FigureManager

def set_image_mode(self, mode):
def set_image_mode(self, mode: str) -> None:
"""
Always send full images to ibex.
"""
self._current_image_mode = "full"

def get_diff_image(self, *args, **kwargs):
def get_diff_image(self) -> bytes | None:
"""
Always send full images to ibex.
"""
self._force_full = True
return super().get_diff_image(*args, **kwargs)
return super().get_diff_image()

def draw_idle(self) -> None:
"""
Expand All @@ -341,17 +371,17 @@ class _BackendIbexWebAgg(_Backend):
FigureManager = _FigureManager

@classmethod
def trigger_manager_draw(cls, manager):
def trigger_manager_draw(cls, manager: FigureManager) -> None:
with IBEX_BACKEND_LOCK:
manager.canvas.draw_idle()

@classmethod
def draw_if_interactive(cls):
def draw_if_interactive(cls) -> None:
with IBEX_BACKEND_LOCK:
super(_BackendIbexWebAgg, cls).draw_if_interactive()
super(_BackendIbexWebAgg, cls).draw_if_interactive() # pyright: ignore

@classmethod
def new_figure_manager(cls, num, *args, **kwargs):
def new_figure_manager(cls, num: int, *args: Any, **kwargs: Any) -> _FigureManager:
with IBEX_BACKEND_LOCK:
for x in list(figure_numbers):
if x not in Gcf.figs.keys():
Expand All @@ -360,7 +390,8 @@ def new_figure_manager(cls, num, *args, **kwargs):
if len(figure_numbers) > max_number_of_figures:
Gcf.destroy(figure_numbers[0])
print(
f"There are too many figures so deleted the oldest figure, which was {figure_numbers[0]}."
f"There are too many figures so deleted "
f"the oldest figure, which was {figure_numbers[0]}."
)
figure_numbers.pop(0)
return super(_BackendIbexWebAgg, cls).new_figure_manager(num, *args, **kwargs)
return super(_BackendIbexWebAgg, cls).new_figure_manager(num, *args, **kwargs) # pyright: ignore
2 changes: 1 addition & 1 deletion tests/py3_test_genie_experimental_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import datetime
import unittest
from unittest.mock import call, patch

from hamcrest import assert_that, calling, raises
from mock import call, patch
from parameterized import parameterized

from genie_python.genie_experimental_data import (
Expand Down
Loading
Loading