From e4b9286cdd62ffac70f045379e710ddfc550a09d Mon Sep 17 00:00:00 2001 From: Cedric Hombourger Date: Fri, 17 Jan 2025 15:57:59 +0100 Subject: [PATCH 1/2] refactor(main): rename _session_check to session_ping Signed-off-by: Cedric Hombourger --- mtda/main.py | 98 ++++++++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/mtda/main.py b/mtda/main.py index f34a9328..71fa80e9 100644 --- a/mtda/main.py +++ b/mtda/main.py @@ -121,7 +121,7 @@ def agent_version(self, session=None): def command(self, args, session=None): self.mtda.debug(3, "main.command()") - self._session_check(session) + self.session_ping(session) result = False if self.power_locked(session) is False: result = self.power.command(args) @@ -168,7 +168,7 @@ def config_set_power_timeout(self, timeout, session=None): self._power_timeout = timeout if timeout == 0: self._power_expiry = None - self._session_check(session) + self.session_ping(session) self.mtda.debug(3, f"main.config_set_power_timeout(): {result}") return result @@ -230,7 +230,7 @@ def console_init(self): def console_clear(self, session=None): self.mtda.debug(3, "main.console_clear()") - self._session_check(session) + self.session_ping(session) if self.console_locked(session): self.mtda.debug(2, "console_clear(): console is locked") return None @@ -246,7 +246,7 @@ def console_clear(self, session=None): def console_dump(self, session=None): self.mtda.debug(3, "main.console_dump()") - self._session_check(session) + self.session_ping(session) if self.console_locked(session): self.mtda.debug(2, "console_dump(): console is locked") return None @@ -262,7 +262,7 @@ def console_dump(self, session=None): def console_flush(self, session=None): self.mtda.debug(3, "main.console_flush()") - self._session_check(session) + self.session_ping(session) if self.console_locked(session): self.mtda.debug(2, "console_flush(): console is locked") return None @@ -278,7 +278,7 @@ def console_flush(self, session=None): def console_head(self, session=None): self.mtda.debug(3, "main.console_head()") - self._session_check(session) + self.session_ping(session) result = None if self.console_logger is not None: result = self.console_logger.head() @@ -290,7 +290,7 @@ def console_head(self, session=None): def console_lines(self, session=None): self.mtda.debug(3, "main.console_lines()") - self._session_check(session) + self.session_ping(session) result = 0 if self.console_logger is not None: result = self.console_logger.lines() @@ -302,7 +302,7 @@ def console_lines(self, session=None): def console_locked(self, session=None): self.mtda.debug(3, "main.console_locked()") - self._session_check(session) + self.session_ping(session) result = self._check_locked(session) self.mtda.debug(3, f"main.console_locked(): {str(result)}") @@ -312,7 +312,7 @@ def console_locked(self, session=None): def console_print(self, data, session=None): self.mtda.debug(3, "main.console_print()") - self._session_check(session) + self.session_ping(session) result = None if self.console_logger is not None: result = self.console_logger.print(data) @@ -324,7 +324,7 @@ def console_print(self, data, session=None): def console_prompt(self, newPrompt=None, session=None): self.mtda.debug(3, "main.console_prompt()") - self._session_check(session) + self.session_ping(session) result = None if self.console_locked(session) is False and \ self.console_logger is not None: @@ -356,7 +356,7 @@ def console_remote(self, host, screen): def console_run(self, cmd, session=None): self.mtda.debug(3, "main.console_run()") - self._session_check(session) + self.session_ping(session) result = None if self.console_locked(session) is False and \ self.console_logger is not None: @@ -369,7 +369,7 @@ def console_run(self, cmd, session=None): def console_send(self, data, raw=False, session=None): self.mtda.debug(3, "main.console_send()") - self._session_check(session) + self.session_ping(session) result = None if self.console_locked(session) is False and \ self.console_logger is not None: @@ -382,7 +382,7 @@ def console_send(self, data, raw=False, session=None): def console_tail(self, session=None): self.mtda.debug(3, "main.console_tail()") - self._session_check(session) + self.session_ping(session) if self.console_locked(session) is False and \ self.console_logger is not None: result = self.console_logger.tail() @@ -395,7 +395,7 @@ def console_toggle(self, session=None): self.mtda.debug(3, "main.console_toggle()") result = None - self._session_check(session) + self.session_ping(session) if self.console_output is not None: self.console_output.toggle() if self.monitor_output is not None: @@ -408,7 +408,7 @@ def console_toggle(self, session=None): def console_wait(self, what, timeout=None, session=None): self.mtda.debug(3, "main.console_wait()") - self._session_check(session) + self.session_ping(session) result = None if session is not None and timeout is None: timeout = CONSTS.RPC.TIMEOUT @@ -474,7 +474,7 @@ def env_set(self, name, value, session=None): def keyboard_write(self, what, session=None): self.mtda.debug(3, "main.keyboard_write()") - self._session_check(session) + self.session_ping(session) result = None if self.keyboard is not None: special_keys = { @@ -539,7 +539,7 @@ def monitor_remote(self, host, screen): def monitor_send(self, data, raw=False, session=None): self.mtda.debug(3, "main.monitor_send()") - self._session_check(session) + self.session_ping(session) result = None if self.console_locked(session) is False and \ self.monitor_logger is not None: @@ -552,7 +552,7 @@ def monitor_send(self, data, raw=False, session=None): def monitor_wait(self, what, timeout=None, session=None): self.mtda.debug(3, "main.monitor_wait()") - self._session_check(session) + self.session_ping(session) result = None if session is not None and timeout is None: timeout = CONSTS.RPC.TIMEOUT @@ -577,7 +577,7 @@ def pastebin_endpoint(self): def power_locked(self, session=None): self.mtda.debug(3, "main.power_locked()") - self._session_check(session) + self.session_ping(session) if self.power is None: result = True else: @@ -606,7 +606,7 @@ def _storage_event(self, status, reason=""): def storage_bytes_written(self, session=None): self.mtda.debug(3, "main.storage_bytes_written()") - self._session_check(session) + self.session_ping(session) result = self._writer.written self.mtda.debug(3, f"main.storage_bytes_written(): {str(result)}") @@ -616,7 +616,7 @@ def storage_bytes_written(self, session=None): def storage_compression(self, compression, session=None): self.mtda.debug(3, "main.storage_compression()") - self._session_check(session) + self.session_ping(session) if self.storage is None: result = None else: @@ -630,7 +630,7 @@ def storage_compression(self, compression, session=None): def storage_bmap_dict(self, bmapDict, session=None): self.mtda.debug(3, "main.storage_bmap_dict()") - self._session_check(session) + self.session_ping(session) if self.storage is None: result = None else: @@ -642,7 +642,7 @@ def storage_bmap_dict(self, bmapDict, session=None): def storage_close(self, session=None): self.mtda.debug(3, "main.storage_close()") - self._session_check(session) + self.session_ping(session) if self.storage is None: result = False else: @@ -668,7 +668,7 @@ def storage_close(self, session=None): def storage_locked(self, session=None): self.mtda.debug(3, "main.storage_locked()") - self._session_check(session) + self.session_ping(session) result = False reason = "unsure" if self._check_locked(session): @@ -717,7 +717,7 @@ def storage_locked(self, session=None): def storage_mount(self, part=None, session=None): self.mtda.debug(3, "main.storage_mount()") - self._session_check(session) + self.session_ping(session) if self.storage.is_storage_mounted is True: self.mtda.debug(4, "storage_mount(): already mounted") result = True @@ -738,7 +738,7 @@ def storage_mount(self, part=None, session=None): def storage_update(self, dst, offset, session=None): self.mtda.debug(3, "main.storage_update()") - self._session_check(session) + self.session_ping(session) result = False if self.storage is None: self.mtda.debug(4, "storage_update(): no shared storage device") @@ -756,7 +756,7 @@ def storage_network(self, session=None): self.mtda.debug(3, "main.storage_network()") result = False - self._session_check(session) + self.session_ping(session) if self.storage_locked(session) is False: if self.storage.to_host() is True: conf = os.path.join(NBD_CONF_DIR, NBD_CONF_FILE) @@ -788,7 +788,7 @@ def storage_network(self, session=None): def storage_open(self, session=None): self.mtda.debug(3, 'main.storage_open()') - self._session_check(session) + self.session_ping(session) owner = self._storage_owner status, _, _ = self.storage_status() @@ -813,7 +813,7 @@ def storage_open(self, session=None): def storage_status(self, session=None): self.mtda.debug(3, "main.storage_status()") - self._session_check(session) + self.session_ping(session) if self.storage is None: self.mtda.debug(4, "storage_status(): no shared storage device") result = CONSTS.STORAGE.UNKNOWN, False, 0 @@ -829,7 +829,7 @@ def storage_status(self, session=None): def storage_to_host(self, session=None): self.mtda.debug(3, "main.storage_to_host()") - self._session_check(session) + self.session_ping(session) if self.storage_locked(session) is False: result = self.storage.to_host() if result is True: @@ -845,7 +845,7 @@ def storage_to_host(self, session=None): def storage_to_target(self, session=None): self.mtda.debug(3, "main.storage_to_target()") - self._session_check(session) + self.session_ping(session) if self.storage_locked(session) is False: self.storage_close() result = self.storage.to_target() @@ -862,7 +862,7 @@ def storage_to_target(self, session=None): def storage_swap(self, session=None): self.mtda.debug(3, "main.storage_swap()") - self._session_check(session) + self.session_ping(session) if self.storage_locked(session) is False: result, writing, written = self.storage_status(session) if result in [CONSTS.STORAGE.ON_HOST, CONSTS.STORAGE.ON_NETWORK]: @@ -881,7 +881,7 @@ def storage_swap(self, session=None): def storage_write(self, data, session=None): self.mtda.debug(3, "main.storage_write()") - self._session_check(session) + self.session_ping(session) if self.storage is None: raise RuntimeError('no shared storage') elif self._storage_opened is False: @@ -1089,7 +1089,7 @@ def target_on(self, session=None): self.mtda.debug(3, "main.target_on()") result = True - self._session_check(session) + self.session_ping(session) with self._power_lock: status = self._target_status() if status != CONSTS.POWER.ON: @@ -1152,7 +1152,7 @@ def target_off(self, session=None): self.mtda.debug(3, "main.target_off()") result = True - self._session_check(session) + self.session_ping(session) with self._power_lock: status = self._target_status() if status != CONSTS.POWER.OFF: @@ -1189,7 +1189,7 @@ def target_toggle(self, session=None): self.mtda.debug(3, "main.target_toggle()") result = CONSTS.POWER.UNSURE - self._session_check(session) + self.session_ping(session) with self._power_lock: if self.power_locked(session) is False: status = self._target_status(session) @@ -1231,7 +1231,7 @@ def target_uptime(self, session=None): def usb_find_by_class(self, className, session=None): self.mtda.debug(3, "main.usb_find_by_class()") - self._session_check(session) + self.session_ping(session) ports = len(self.usb_switches) ndx = 0 while ndx < ports: @@ -1245,7 +1245,7 @@ def usb_find_by_class(self, className, session=None): def usb_has_class(self, className, session=None): self.mtda.debug(3, "main.usb_has_class()") - self._session_check(session) + self.session_ping(session) usb_switch = self.usb_find_by_class(className, session) return usb_switch is not None @@ -1253,7 +1253,7 @@ def usb_has_class(self, className, session=None): def usb_off(self, ndx, session=None): self.mtda.debug(3, "main.usb_off()") - self._session_check(session) + self.session_ping(session) try: if ndx > 0: usb_switch = self.usb_switches[ndx-1] @@ -1265,7 +1265,7 @@ def usb_off(self, ndx, session=None): def usb_off_by_class(self, className, session=None): self.mtda.debug(3, "main.usb_off_by_class()") - self._session_check(session) + self.session_ping(session) usb_switch = self.usb_find_by_class(className, session) if usb_switch is not None: return usb_switch.off() @@ -1275,7 +1275,7 @@ def usb_off_by_class(self, className, session=None): def usb_on(self, ndx, session=None): self.mtda.debug(3, "main.usb_on()") - self._session_check(session) + self.session_ping(session) try: if ndx > 0: usb_switch = self.usb_switches[ndx-1] @@ -1287,7 +1287,7 @@ def usb_on(self, ndx, session=None): def usb_on_by_class(self, className, session=None): self.mtda.debug(3, "main.usb_on_by_class()") - self._session_check(session) + self.session_ping(session) usb_switch = self.usb_find_by_class(className, session) if usb_switch is not None: return usb_switch.on() @@ -1297,14 +1297,14 @@ def usb_on_by_class(self, className, session=None): def usb_ports(self, session=None): self.mtda.debug(3, "main.usb_ports()") - self._session_check(session) + self.session_ping(session) return len(self.usb_switches) @Pyro4.expose def usb_status(self, ndx, session=None): self.mtda.debug(3, "main.usb_status()") - self._session_check(session) + self.session_ping(session) try: if ndx > 0: usb_switch = self.usb_switches[ndx-1] @@ -1324,7 +1324,7 @@ def usb_status(self, ndx, session=None): def usb_toggle(self, ndx, session=None): self.mtda.debug(3, "main.usb_toggle()") - self._session_check(session) + self.session_ping(session) try: if ndx > 0: usb_switch = self.usb_switches[ndx-1] @@ -1678,7 +1678,7 @@ def start(self): if self.is_server is True: from mtda.utils import RepeatTimer - handler = self._session_check + handler = self.session_ping self._session_timer = RepeatTimer(10, handler) self._session_timer.start() @@ -1763,8 +1763,8 @@ def session_event(self, info): self._storage_event(CONSTS.STORAGE.CORRUPTED) self.storage_close(None) - def _session_check(self, session=None): - self.mtda.debug(4, f"main._session_check({session})") + def session_ping(self, session=None): + self.mtda.debug(4, f"main.session_ping({session})") result = None if self._session_manager is not None: @@ -1777,7 +1777,7 @@ def _session_check(self, session=None): self.mtda.debug(2, "device powered down after " f"{self._power_timeout} seconds of inactivity") - self.mtda.debug(4, f"main._session_check: {result}") + self.mtda.debug(4, f"main.session_ping: {result}") return result def _check_locked(self, session): From eaf594157deeecd1173b143e3095d62c49399548 Mon Sep 17 00:00:00 2001 From: Cedric Hombourger Date: Thu, 16 Jan 2025 14:29:02 +0100 Subject: [PATCH 2/2] perf(storage): use zmq for storage write Replace RPC calls with ZeroMQ messages for a more efficient transfer of image chunks. Signed-off-by: Cedric Hombourger --- docs/config.rst | 4 ++ mtda/client.py | 38 +++++++-------- mtda/constants.py | 3 +- mtda/main.py | 62 +++++++++--------------- mtda/storage/writer.py | 105 +++++++++++++++++++++++++---------------- 5 files changed, 110 insertions(+), 102 deletions(-) diff --git a/docs/config.rst b/docs/config.rst index 39e04d73..408a0b87 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -98,6 +98,10 @@ General settings Remote port to connect to in order to get console messages (defaults to ``5557``). + * ``data``: integer [optional] + Remote port for data transfers between the client and agent (defaults to + ``0`` for a dynamic port assignment). + * ``host``: string [optional] Remote host name or ip to connect to as a client to interact with the MTDA agent (defaults to ``localhost``). diff --git a/mtda/client.py b/mtda/client.py index 64ff8834..d09c13b6 100644 --- a/mtda/client.py +++ b/mtda/client.py @@ -17,6 +17,7 @@ import subprocess import tempfile import time +import zmq import zstandard as zstd from mtda.main import MultiTenantDeviceAccess @@ -44,6 +45,7 @@ def __init__(self, host=None, session=None, config_files=None, else: self._impl = agent self._agent = agent + self._data = None if session is None: HOST = socket.gethostname() @@ -115,7 +117,12 @@ def storage_open(self): while tries > 0: tries = tries - 1 try: - self._impl.storage_open(self._session) + host = self.remote() + port = self._impl.storage_open(self._session) + context = zmq.Context() + socket = context.socket(zmq.PUSH) + socket.connect(f'tcp://{host}:{port}') + self._data = socket return except Exception: if tries > 0: @@ -201,7 +208,7 @@ def storage_write_image(self, path, callback=None): try: # Prepare for download/copy - file.prepare(image_size) + file.prepare(self._data, image_size) # Copy image to shared storage file.copy() @@ -298,6 +305,7 @@ def __init__(self, path, agent, session, blksz, callback=None): self._path = path self._session = session self._totalread = 0 + self._totalsent = 0 def bmap(self, path): return None @@ -324,21 +332,25 @@ def flush(self): inputsize = self._inputsize totalread = self._totalread outputsize = self._outputsize + + agent.storage_flush(self._totalsent) while True: - status, writing, written = agent.storage_status(self._session) + status, writing, written = agent.storage_status() if callback is not None: callback(imgname, totalread, inputsize, written, outputsize) if writing is False: break time.sleep(0.5) + self._socket.close() def path(self): return self._path - def prepare(self, output_size=None, compression=None): + def prepare(self, socket, output_size=None, compression=None): compr = self.compression() if compression is None else compression self._inputsize = self.size() self._outputsize = output_size + self._socket = socket # if image is uncompressed, we compress on the fly if compr == CONSTS.IMAGE.RAW.value: compr = CONSTS.IMAGE.ZST.value @@ -362,22 +374,8 @@ def size(self): return None def _write_to_storage(self, data): - max_tries = int(CONSTS.STORAGE.TIMEOUT / CONSTS.STORAGE.RETRY_INTERVAL) - - for _ in range(max_tries): - result = self._agent.storage_write(data, self._session) - if result != 0: - break - time.sleep(CONSTS.STORAGE.RETRY_INTERVAL) - - if result > 0: - return result - elif result < 0: - exc = 'write or decompression error from shared storage' - raise IOError(exc) - else: - exc = 'timeout from shared storage' - raise IOError(exc) + self._socket.send(data) + self._totalsent += len(data) class ImageLocal(ImageFile): diff --git a/mtda/constants.py b/mtda/constants.py index 001f8225..d6300835 100644 --- a/mtda/constants.py +++ b/mtda/constants.py @@ -77,7 +77,6 @@ class STORAGE: class WRITER: - QUEUE_SLOTS = 16 - QUEUE_TIMEOUT = 5 + RECV_TIMEOUT = 5 READ_SIZE = 1*1024*1024 WRITE_SIZE = 1*1024*1024 diff --git a/mtda/main.py b/mtda/main.py index 71fa80e9..120ca994 100644 --- a/mtda/main.py +++ b/mtda/main.py @@ -84,6 +84,7 @@ def __init__(self): self.usb_switches = [] self.ctrlport = 5556 self.conport = 5557 + self.dataport = 0 self.prefix_key = self._prefix_key_code(DEFAULT_PREFIX_KEY) self.is_remote = False self.is_server = False @@ -661,7 +662,20 @@ def storage_close(self, session=None): if self.storage is not None: self.storage_locked() - self.mtda.debug(3, f"main.storage_close(): {str(result)}") + self.mtda.debug(3, f"main.storage_close(): {result}") + return result + + @Pyro4.expose + def storage_flush(self, size, session=None): + self.mtda.debug(3, "main.storage_flush()") + + self.session_ping(session) + if self.storage is None: + result = False + else: + result = self._writer.flush(size) + + self.mtda.debug(3, f"main.storage_flush(): {result}") return result @Pyro4.expose @@ -790,6 +804,7 @@ def storage_open(self, session=None): self.session_ping(session) owner = self._storage_owner + result = None status, _, _ = self.storage_status() if self.storage is None: @@ -802,12 +817,13 @@ def storage_open(self, session=None): self.storage.open() self._storage_opened = True self._storage_owner = session - self._writer.start() + result = self._writer.start(session) self._storage_event(CONSTS.STORAGE.OPENED, session) if self.storage is not None: self.storage_locked() - self.mtda.debug(3, 'main.storage_open(): success') + self.mtda.debug(3, f'main.storage_open(): {result}') + return result @Pyro4.expose def storage_status(self, session=None): @@ -877,42 +893,6 @@ def storage_swap(self, session=None): self.mtda.debug(3, f"main.storage_swap(): {str(result)}") return result - @Pyro4.expose - def storage_write(self, data, session=None): - self.mtda.debug(3, "main.storage_write()") - - self.session_ping(session) - if self.storage is None: - raise RuntimeError('no shared storage') - elif self._storage_opened is False: - raise RuntimeError('shared storage was not opened') - elif self._writer.failed is True: - raise RuntimeError('write or decompression error ' - 'from shared storage') - elif session != self._storage_owner: - raise RuntimeError('shared storage in use') - - import queue - try: - if len(data) == 0: - self.mtda.debug(2, "main.storage_write(): " - "using queued data") - data = self._writer_data - self._writer_data = data - self._writer.put(data, timeout=10) - result = self.blksz - except queue.Full: - self.mtda.debug(2, "main.storage_write(): " - "queue is full") - result = 0 - - if self._writer.failed is True: - self.error('storage_write failed: write or decompression error') - result = -1 - - self.mtda.debug(3, f"main.storage_write(): {str(result)}") - return result - def systemd_configure(self): from filecmp import dircmp @@ -1467,7 +1447,7 @@ def post_configure_storage(self, storage, config, parser): self.mtda.debug(3, "main.post_configure_storage()") from mtda.storage.writer import AsyncImageWriter - self._writer = AsyncImageWriter(self, storage) + self._writer = AsyncImageWriter(self, storage, self.dataport) import atexit atexit.register(self.storage_close) @@ -1479,6 +1459,8 @@ def load_remote_config(self, parser): parser.get('remote', 'console', fallback=self.conport)) self.ctrlport = int( parser.get('remote', 'control', fallback=self.ctrlport)) + self.dataport = int( + parser.get('remote', 'data', fallback=self.dataport)) if self.is_server is False: if self.remote is None: # Load remote setting from the configuration diff --git a/mtda/storage/writer.py b/mtda/storage/writer.py index 4d4d35ca..fbff532b 100644 --- a/mtda/storage/writer.py +++ b/mtda/storage/writer.py @@ -10,28 +10,32 @@ # --------------------------------------------------------------------------- import bz2 -import queue import threading import mtda.constants as CONSTS import zlib import zstandard as zstd import lzma +import zmq -class AsyncImageWriter(queue.Queue): +class AsyncImageWriter: - def __init__(self, mtda, storage, compression=CONSTS.IMAGE.RAW): + def __init__(self, mtda, storage, dataport, compression=CONSTS.IMAGE.RAW): self.mtda = mtda self.storage = storage self.compression = compression self._blksz = CONSTS.WRITER.WRITE_SIZE + self._dataport = dataport self._exiting = False self._failed = False + self._session = None + self._size = 0 + self._socket = None self._thread = None + self._receiving = False self._writing = False self._written = 0 self._zdec = None - super().__init__(maxsize=CONSTS.WRITER.QUEUE_SLOTS) @property def compression(self): @@ -68,75 +72,96 @@ def compression(self, compression): def failed(self): return self._failed - def put(self, chunk, block=True, timeout=None): - self.mtda.debug(3, "storage.writer.put()") - - if self.storage is None: - self.mtda.debug(1, "storage.writer.put(): no storage!") - raise IOError("no storage!") - result = super().put(chunk, block, timeout) - # if thread is started and put data is not empty - if len(chunk) > 0 and self._exiting is False: - self._writing = True - self.mtda.debug(3, f"storage.writer.put(): {str(result)}") + def flush(self, size): + self.mtda.debug(3, "mtda.storage.writer.flush()") + + result = None + self._receiving = False + self._size = size + + self.mtda.debug(3, f"storage.writer.flush(): {result}") return result - def start(self): + def start(self, session): self.mtda.debug(3, "mtda.storage.writer.start()") - result = None + self._session = session + context = zmq.Context() + timeout = CONSTS.WRITER.RECV_TIMEOUT * 1000 + + self._socket = context.socket(zmq.PULL) + self._socket.bind(f"tcp://*:{self._dataport}") + self._socket.setsockopt(zmq.RCVTIMEO, timeout) + + endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + result = int(endpoint.split(":")[-1]) + self._thread = threading.Thread(target=self.worker, daemon=True, name='writer') self._thread.start() - self.mtda.debug(3, f"storage.writer.start(): {str(result)}") + self.mtda.debug(3, f"storage.writer.start(): {result}") return result def stop(self): self.mtda.debug(3, "storage.writer.stop()") result = None - self.mtda.debug(2, "storage.writer.stop(): waiting on queue...") - self.join() + self._exiting = True if self._thread is not None: self.mtda.debug(2, "storage.writer.stop(): waiting on thread...") - self._exiting = True - self.put(b'') self._thread.join() self.mtda.debug(2, "storage.writer.stop(): all done") + self._thread = None self._zdec = None - self.mtda.debug(3, f"storage.writer.stop(): {str(result)}") + self.mtda.debug(3, f"storage.writer.stop(): {result}") return result def worker(self): self.mtda.debug(3, "storage.writer.worker()") + mtda = self.mtda + received = 0 result = None self._exiting = False self._failed = False + self._receiving = True self._written = 0 + self._writing = True while self._exiting is False: - if self.empty(): - self._writing = False - chunk = self.get() - if self._exiting is False: - try: - self._write(chunk) - except Exception as e: - self.mtda.debug(1, f"storage.writer.worker(): {e}") - self._failed = True - self._writing = False - pass - self.task_done() - if self._failed is True: - self.mtda.debug(1, "storage.writer.worker(): " - "write or decompression error!") - - self.mtda.debug(3, f"storage.writer.worker(): {str(result)}") + try: + chunk = self._socket.recv() + received += len(chunk) + mtda.session_ping(self._session) + self._write(chunk) + except zmq.Again: + if self._receiving is False: + if self._size > 0 and received == self._size: + mtda.debug(1, "storage.writer.worker(): transfer complete") + break + self._failed = True + mtda.debug(1, "storage.writer.worker(): timeout " + f"(recv'd {received} / {self._size})") + except Exception as e: + self._failed = True + mtda.debug(1, f"storage.writer.worker(): {e}") + break + + self._receiving = False + self._writing = False + if self._failed is True: + mtda.debug(1, "storage.writer.worker(): " + "write or decompression error!") + + if self._socket: + self._socket.close() + self._socket = None + + mtda.debug(3, f"storage.writer.worker(): {result}") return result def write_raw(self, data):