From ccb1da8a56cb12293470433869e29d74da74ee8f Mon Sep 17 00:00:00 2001 From: Cedric Hombourger Date: Thu, 16 Jan 2025 14:29:02 +0100 Subject: [PATCH] perf(storage): use zmq for storage write Replace RPC calls with ZeroMQ messages for a more efficient transfer of image chunks. Signed-off-by: Cedric Hombourger --- docs/config.rst | 4 ++ mtda/client.py | 37 ++++++++---------- mtda/constants.py | 2 - mtda/main.py | 62 +++++++++++------------------ mtda/storage/writer.py | 89 +++++++++++++++++++++++++----------------- 5 files changed, 97 insertions(+), 97 deletions(-) diff --git a/docs/config.rst b/docs/config.rst index 39e04d73..408a0b87 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -98,6 +98,10 @@ General settings Remote port to connect to in order to get console messages (defaults to ``5557``). + * ``data``: integer [optional] + Remote port for data transfers between the client and agent (defaults to + ``0`` for a dynamic port assignment). + * ``host``: string [optional] Remote host name or ip to connect to as a client to interact with the MTDA agent (defaults to ``localhost``). diff --git a/mtda/client.py b/mtda/client.py index 64ff8834..3633772e 100644 --- a/mtda/client.py +++ b/mtda/client.py @@ -17,6 +17,7 @@ import subprocess import tempfile import time +import zmq import zstandard as zstd from mtda.main import MultiTenantDeviceAccess @@ -44,6 +45,7 @@ def __init__(self, host=None, session=None, config_files=None, else: self._impl = agent self._agent = agent + self._data = None if session is None: HOST = socket.gethostname() @@ -115,7 +117,12 @@ def storage_open(self): while tries > 0: tries = tries - 1 try: - self._impl.storage_open(self._session) + host = self.remote() + port = self._impl.storage_open(self._session) + context = zmq.Context() + socket = context.socket(zmq.PUSH) + socket.connect(f'tcp://{host}:{port}') + self._data = socket return except Exception: if tries > 0: @@ -201,7 +208,7 @@ def storage_write_image(self, path, callback=None): try: # Prepare for download/copy - file.prepare(image_size) + file.prepare(self._data, image_size) # Copy image to shared storage file.copy() @@ -324,8 +331,12 @@ def flush(self): inputsize = self._inputsize totalread = self._totalread outputsize = self._outputsize + + agent.storage_flush(totalread) + self._socket.close() + while True: - status, writing, written = agent.storage_status(self._session) + status, writing, written = agent.storage_status() if callback is not None: callback(imgname, totalread, inputsize, written, outputsize) if writing is False: @@ -335,10 +346,11 @@ def flush(self): def path(self): return self._path - def prepare(self, output_size=None, compression=None): + def prepare(self, socket, output_size=None, compression=None): compr = self.compression() if compression is None else compression self._inputsize = self.size() self._outputsize = output_size + self._socket = socket # if image is uncompressed, we compress on the fly if compr == CONSTS.IMAGE.RAW.value: compr = CONSTS.IMAGE.ZST.value @@ -362,22 +374,7 @@ def size(self): return None def _write_to_storage(self, data): - max_tries = int(CONSTS.STORAGE.TIMEOUT / CONSTS.STORAGE.RETRY_INTERVAL) - - for _ in range(max_tries): - result = self._agent.storage_write(data, self._session) - if result != 0: - break - time.sleep(CONSTS.STORAGE.RETRY_INTERVAL) - - if result > 0: - return result - elif result < 0: - exc = 'write or decompression error from shared storage' - raise IOError(exc) - else: - exc = 'timeout from shared storage' - raise IOError(exc) + self._socket.send(data) class ImageLocal(ImageFile): diff --git a/mtda/constants.py b/mtda/constants.py index 001f8225..a337bc0a 100644 --- a/mtda/constants.py +++ b/mtda/constants.py @@ -77,7 +77,5 @@ class STORAGE: class WRITER: - QUEUE_SLOTS = 16 - QUEUE_TIMEOUT = 5 READ_SIZE = 1*1024*1024 WRITE_SIZE = 1*1024*1024 diff --git a/mtda/main.py b/mtda/main.py index f34a9328..de7f056b 100644 --- a/mtda/main.py +++ b/mtda/main.py @@ -84,6 +84,7 @@ def __init__(self): self.usb_switches = [] self.ctrlport = 5556 self.conport = 5557 + self.dataport = 0 self.prefix_key = self._prefix_key_code(DEFAULT_PREFIX_KEY) self.is_remote = False self.is_server = False @@ -661,7 +662,20 @@ def storage_close(self, session=None): if self.storage is not None: self.storage_locked() - self.mtda.debug(3, f"main.storage_close(): {str(result)}") + self.mtda.debug(3, f"main.storage_close(): {result}") + return result + + @Pyro4.expose + def storage_flush(self, size, session=None): + self.mtda.debug(3, "main.storage_flush()") + + self._session_check(session) + if self.storage is None: + result = False + else: + result = self._writer.flush(size) + + self.mtda.debug(3, f"main.storage_flush(): {result}") return result @Pyro4.expose @@ -790,6 +804,7 @@ def storage_open(self, session=None): self._session_check(session) owner = self._storage_owner + result = None status, _, _ = self.storage_status() if self.storage is None: @@ -802,12 +817,13 @@ def storage_open(self, session=None): self.storage.open() self._storage_opened = True self._storage_owner = session - self._writer.start() + result = self._writer.start() self._storage_event(CONSTS.STORAGE.OPENED, session) if self.storage is not None: self.storage_locked() - self.mtda.debug(3, 'main.storage_open(): success') + self.mtda.debug(3, f'main.storage_open(): {result}') + return result @Pyro4.expose def storage_status(self, session=None): @@ -877,42 +893,6 @@ def storage_swap(self, session=None): self.mtda.debug(3, f"main.storage_swap(): {str(result)}") return result - @Pyro4.expose - def storage_write(self, data, session=None): - self.mtda.debug(3, "main.storage_write()") - - self._session_check(session) - if self.storage is None: - raise RuntimeError('no shared storage') - elif self._storage_opened is False: - raise RuntimeError('shared storage was not opened') - elif self._writer.failed is True: - raise RuntimeError('write or decompression error ' - 'from shared storage') - elif session != self._storage_owner: - raise RuntimeError('shared storage in use') - - import queue - try: - if len(data) == 0: - self.mtda.debug(2, "main.storage_write(): " - "using queued data") - data = self._writer_data - self._writer_data = data - self._writer.put(data, timeout=10) - result = self.blksz - except queue.Full: - self.mtda.debug(2, "main.storage_write(): " - "queue is full") - result = 0 - - if self._writer.failed is True: - self.error('storage_write failed: write or decompression error') - result = -1 - - self.mtda.debug(3, f"main.storage_write(): {str(result)}") - return result - def systemd_configure(self): from filecmp import dircmp @@ -1467,7 +1447,7 @@ def post_configure_storage(self, storage, config, parser): self.mtda.debug(3, "main.post_configure_storage()") from mtda.storage.writer import AsyncImageWriter - self._writer = AsyncImageWriter(self, storage) + self._writer = AsyncImageWriter(self, storage, self.dataport) import atexit atexit.register(self.storage_close) @@ -1479,6 +1459,8 @@ def load_remote_config(self, parser): parser.get('remote', 'console', fallback=self.conport)) self.ctrlport = int( parser.get('remote', 'control', fallback=self.ctrlport)) + self.dataport = int( + parser.get('remote', 'data', fallback=self.dataport)) if self.is_server is False: if self.remote is None: # Load remote setting from the configuration diff --git a/mtda/storage/writer.py b/mtda/storage/writer.py index 4d4d35ca..0306f285 100644 --- a/mtda/storage/writer.py +++ b/mtda/storage/writer.py @@ -10,28 +10,31 @@ # --------------------------------------------------------------------------- import bz2 -import queue import threading import mtda.constants as CONSTS import zlib import zstandard as zstd import lzma +import zmq -class AsyncImageWriter(queue.Queue): +class AsyncImageWriter: - def __init__(self, mtda, storage, compression=CONSTS.IMAGE.RAW): + def __init__(self, mtda, storage, dataport, compression=CONSTS.IMAGE.RAW): self.mtda = mtda self.storage = storage self.compression = compression self._blksz = CONSTS.WRITER.WRITE_SIZE + self._dataport = dataport self._exiting = False self._failed = False + self._size = 0 + self._socket = None self._thread = None + self._receiving = False self._writing = False self._written = 0 self._zdec = None - super().__init__(maxsize=CONSTS.WRITER.QUEUE_SLOTS) @property def compression(self): @@ -68,75 +71,91 @@ def compression(self, compression): def failed(self): return self._failed - def put(self, chunk, block=True, timeout=None): - self.mtda.debug(3, "storage.writer.put()") - - if self.storage is None: - self.mtda.debug(1, "storage.writer.put(): no storage!") - raise IOError("no storage!") - result = super().put(chunk, block, timeout) - # if thread is started and put data is not empty - if len(chunk) > 0 and self._exiting is False: - self._writing = True - self.mtda.debug(3, f"storage.writer.put(): {str(result)}") + def flush(self, size): + self.mtda.debug(3, "mtda.storage.writer.flush()") + + result = None + self._receiving = False + self._size = size + + self.mtda.debug(3, f"storage.writer.flush(): {result}") return result def start(self): self.mtda.debug(3, "mtda.storage.writer.start()") - result = None + context = zmq.Context() + timeout = CONSTS.WRITER.RECV_TIMEOUT * 1000 + + self._socket = context.socket(zmq.PULL) + self._socket.bind(f"tcp://*:{self._dataport}") + self._socket.setsockopt(zmq.RCVTIMEO, timeout) + + endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + result = int(endpoint.split(":")[-1]) + self._thread = threading.Thread(target=self.worker, daemon=True, name='writer') self._thread.start() - self.mtda.debug(3, f"storage.writer.start(): {str(result)}") + self.mtda.debug(3, f"storage.writer.start(): {result}") return result def stop(self): self.mtda.debug(3, "storage.writer.stop()") result = None - self.mtda.debug(2, "storage.writer.stop(): waiting on queue...") - self.join() + self._exiting = True if self._thread is not None: self.mtda.debug(2, "storage.writer.stop(): waiting on thread...") - self._exiting = True - self.put(b'') self._thread.join() self.mtda.debug(2, "storage.writer.stop(): all done") + self._thread = None self._zdec = None - self.mtda.debug(3, f"storage.writer.stop(): {str(result)}") + self.mtda.debug(3, f"storage.writer.stop(): {result}") return result def worker(self): self.mtda.debug(3, "storage.writer.worker()") + received = 0 result = None self._exiting = False self._failed = False + self._receiving = True self._written = 0 + self._writing = True while self._exiting is False: - if self.empty(): - self._writing = False - chunk = self.get() - if self._exiting is False: - try: - self._write(chunk) - except Exception as e: - self.mtda.debug(1, f"storage.writer.worker(): {e}") - self._failed = True - self._writing = False - pass - self.task_done() + try: + chunk = self._socket.recv() + received += len(chunk) + self._write(chunk) + except zmq.Again: + if self._receiving is False: + if self._size > 0 and received == self._size: + self.mtda.debug(1, "storage.writer.worker(): transfer complete") + break + self._failed = True + self.mtda.debug(1, "storage.writer.worker(): timeout") + except Exception as e: + self._failed = True + self.mtda.debug(1, f"storage.writer.worker(): {e}") + break if self._failed is True: self.mtda.debug(1, "storage.writer.worker(): " "write or decompression error!") - self.mtda.debug(3, f"storage.writer.worker(): {str(result)}") + if self._socket: + self._socket.close() + self._socket = None + self._receiving = False + self._writing = False + + self.mtda.debug(3, f"storage.writer.worker(): {result}") return result def write_raw(self, data):