Skip to content

Commit

Permalink
Refactor message serialisation and deserialisation (#197)
Browse files Browse the repository at this point in the history
* Refactor message serialization and deserialization

Addiing `Message` and `SerialisedMessage` classes in attempt to improve information hiding and decoupling.

* Rename `utils.py` -> `message.py`

* Add `decode()` method for `SerialisedMessage`

* Update docstring

* Use new classes in message testing

* Refactor message processing in the CLI

* Refactor `process_message` to use `SerialisedMessage` class in EHR API

* Refactor `process_message` to use `SerialisedMessage` class in imaging API

* Fix `ImagingStudy` initalisation in  `ImagingStudy.from_message()`

* Fix imports

* Fix test: access serialised message bodies

* Turn `Message` into a `dataclass`

* Fix failing tests

* Use `jsonpickle` for (de)serializing messages

This also removes the need for the `SerialisedMessage` class

* Fix `test_deserialise_datetime()` so it uses the `Message` class to assert the `study_datetime`

* Add `study_datetime` property for `Message`

* No need to test deserialising individual fields, already covered by `test_deserialise()` which deserialises the entire object

* Remove `study_date_from_serialised()`, use the class attribute `study_datetime` instead

* Revert "Add `study_datetime` property for `Message`"

This reverts commit 8719153.

* Remove `Messages` class, use `list[Message]` instead

* Add type checking for messages parsed from parquet input

* Update `test_messages_from_parquet()` to use JSON strings instead of bytes

* Update `PixlProducer.publish()` to use a list of Message objects and handle serialisation

* Convert JSON string to bytes when serialising

* Revert "Update `test_messages_from_parquet()` to use JSON strings instead of bytes"

This reverts commit 0e4fce4.

* `PixlProducer.publish()` should take a `list[Message]` as input in tests

* Update EHR API to use new `Message` design

* Update imaging API to use new `Message` design

* Update deserialise function to accept bytes-encoded JSON string

* Assert messages against list of `Message`s

* Print dataclass in logs

* `jsonpickle.decode()` can handle bytes so no need to decode first

Also add a note about why we ignore ruff rule S301

* Make `deserialisable` a keyword only argument

* Copilot forgot to convert dates to datetimes 🥲

* Refactor PixlConsumer run method to accept Message object as callback parameter and deserialise

* Update consumer in `test_subscriber` to accept Message object instead of bytes
  • Loading branch information
milanmlft authored Dec 20, 2023
1 parent bf39b46 commit 90b6672
Show file tree
Hide file tree
Showing 15 changed files with 253 additions and 241 deletions.
82 changes: 29 additions & 53 deletions cli/src/pixl_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
import datetime
import json
import os
from operator import attrgetter
from pathlib import Path
from typing import Any, Optional

import click
import pandas as pd
import requests
import yaml
from core.patient_queue.message import Message, deserialise
from core.patient_queue.producer import PixlProducer
from core.patient_queue.subscriber import PixlBlockingConsumer
from core.patient_queue.utils import deserialise, serialise

from ._logging import logger, set_log_level
from ._utils import clear_file, remove_file_if_it_exists, string_is_non_empty
Expand Down Expand Up @@ -84,12 +85,12 @@ def populate(queues: str, *, restart: bool, parquet_dir: Path) -> None:
if state_filepath.exists() and restart:
logger.info(f"Extracting messages from state: {state_filepath}")
inform_user_that_queue_will_be_populated_from(state_filepath)
messages = Messages.from_state_file(state_filepath)
messages = messages_from_state_file(state_filepath)
elif parquet_dir is not None:
messages = messages_from_parquet(parquet_dir)

remove_file_if_it_exists(state_filepath) # will be stale
producer.publish(sorted(messages, key=study_date_from_serialised))
producer.publish(sorted(messages, key=attrgetter("study_datetime")))


@cli.command()
Expand Down Expand Up @@ -273,41 +274,26 @@ def state_filepath_for_queue(queue_name: str) -> Path:
return Path(f"{queue_name.replace('/', '_')}.state")


class Messages(list):
def messages_from_state_file(filepath: Path) -> list[Message]:
"""
Class to represent messages
Return messages from a state file path
Methods
-------
from_state_file(cls, filepath)
Return messages from a state file path
:param filepath: Path for state file to be read
:return: A list of Message objects containing all the messages from the state file
"""
logger.info(f"Creating messages from {filepath}")
if not filepath.exists():
raise FileNotFoundError
if filepath.suffix != ".state":
msg = f"Invalid file suffix for {filepath}. Expected .state"
raise ValueError(msg)

@classmethod
def from_state_file(cls, filepath: Path) -> "Messages":
"""
Return messages from a state file path
:param filepath: Path for state file to be read
:return: A Messages object containing all the messages from the state file
"""
logger.info(f"Creating messages from {filepath}")
if not filepath.exists():
raise FileNotFoundError
if filepath.suffix != ".state":
msg = f"Invalid file suffix for {filepath}. Expected .state"
raise ValueError(msg)

return cls(
[
line.encode("utf-8")
for line in Path.open(filepath).readlines()
if string_is_non_empty(line)
]
)
return [
deserialise(line) for line in Path.open(filepath).readlines() if string_is_non_empty(line)
]


def messages_from_parquet(dir_path: Path) -> Messages:
def messages_from_parquet(dir_path: Path) -> list[Message]:
"""
Reads patient information from parquet files within directory structure
and transforms that into messages.
Expand Down Expand Up @@ -345,9 +331,6 @@ def messages_from_parquet(dir_path: Path) -> Messages:
f"{expected_col_names}"
)

# First line is column names
messages = Messages()

for col in expected_col_names:
if col not in list(cohort_data.columns):
msg = f"csv file expected to have at least {expected_col_names} as " f"column names"
Expand All @@ -367,17 +350,19 @@ def messages_from_parquet(dir_path: Path) -> Messages:
project_name = logs["settings"]["cdm_source_name"]
omop_es_timestamp = datetime.datetime.fromisoformat(logs["datetime"])

messages = []

for _, row in cohort_data.iterrows():
messages.append(
serialise(
mrn=row[mrn_col_name],
accession_number=row[acc_num_col_name],
study_datetime=row[dt_col_name],
procedure_occurrence_id=row[procedure_occurrence_id],
project_name=project_name,
omop_es_timestamp=omop_es_timestamp,
)
# Create new dict to initialise message
message = Message(
mrn=row[mrn_col_name],
accession_number=row[acc_num_col_name],
study_datetime=row[dt_col_name],
procedure_occurrence_id=row[procedure_occurrence_id],
project_name=project_name,
omop_es_timestamp=omop_es_timestamp,
)
messages.append(message)

if len(messages) == 0:
msg = f"Failed to find any messages in {dir_path}"
Expand Down Expand Up @@ -446,12 +431,3 @@ def api_config_for_queue(queue_name: str) -> APIConfig:
raise ValueError(msg)

return APIConfig(config[config_key])


def study_date_from_serialised(message: bytes) -> datetime.datetime:
"""Get the study date from a serialised message as a datetime"""
result = deserialise(message)["study_datetime"]
if not isinstance(result, datetime.datetime):
msg = "Expected study date to be a datetime. Got %s"
raise TypeError(msg, type(result))
return result
48 changes: 36 additions & 12 deletions cli/tests/test_messages_from_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
"""Unit tests for reading cohorts from parquet files."""

import datetime
from pathlib import Path

from core.patient_queue.message import Message
from pixl_cli.main import messages_from_parquet


Expand All @@ -25,19 +27,41 @@ def test_messages_from_parquet(resources: Path) -> None:
"""
omop_parquet_dir = resources / "omop"
messages = messages_from_parquet(omop_parquet_dir)
assert all(isinstance(msg, Message) for msg in messages)

expected_messages = [
b'{"mrn": "12345678", "accession_number": "12345678", "study_datetime": "2021-07-01", '
b'"procedure_occurrence_id": 1, "project_name": "Test Extract - UCLH OMOP CDM", '
b'"omop_es_timestamp": "2023-12-07T14:08:58"}',
b'{"mrn": "12345678", "accession_number": "ABC1234567", "study_datetime": "2021-07-01", '
b'"procedure_occurrence_id": 2, "project_name": "Test Extract - UCLH OMOP CDM", '
b'"omop_es_timestamp": "2023-12-07T14:08:58"}',
b'{"mrn": "987654321", "accession_number": "ABC1234560", "study_datetime": "2020-05-01", '
b'"procedure_occurrence_id": 3, "project_name": "Test Extract - UCLH OMOP CDM", '
b'"omop_es_timestamp": "2023-12-07T14:08:58"}',
b'{"mrn": "5020765", "accession_number": "MIG0234560", "study_datetime": "2015-05-01", '
b'"procedure_occurrence_id": 4, "project_name": "Test Extract - UCLH OMOP CDM", '
b'"omop_es_timestamp": "2023-12-07T14:08:58"}',
Message(
mrn="12345678",
accession_number="12345678",
study_datetime=datetime.date.fromisoformat("2021-07-01"),
procedure_occurrence_id=1,
project_name="Test Extract - UCLH OMOP CDM",
omop_es_timestamp=datetime.datetime.fromisoformat("2023-12-07T14:08:58"),
),
Message(
mrn="12345678",
accession_number="ABC1234567",
study_datetime=datetime.date.fromisoformat("2021-07-01"),
procedure_occurrence_id=2,
project_name="Test Extract - UCLH OMOP CDM",
omop_es_timestamp=datetime.datetime.fromisoformat("2023-12-07T14:08:58"),
),
Message(
mrn="987654321",
accession_number="ABC1234560",
study_datetime=datetime.date.fromisoformat("2020-05-01"),
procedure_occurrence_id=3,
project_name="Test Extract - UCLH OMOP CDM",
omop_es_timestamp=datetime.datetime.fromisoformat("2023-12-07T14:08:58"),
),
Message(
mrn="5020765",
accession_number="MIG0234560",
study_datetime=datetime.date.fromisoformat("2015-05-01"),
procedure_occurrence_id=4,
project_name="Test Extract - UCLH OMOP CDM",
omop_es_timestamp=datetime.datetime.fromisoformat("2023-12-07T14:08:58"),
),
]

assert messages == expected_messages
3 changes: 2 additions & 1 deletion pixl_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dependencies = [
"pika==1.3.1",
"aio_pika==8.2.4",
"environs==9.5.0",
"requests==2.31.0"
"requests==2.31.0",
"jsonpickle==3.0.2"
]

[project.optional-dependencies]
Expand Down
74 changes: 74 additions & 0 deletions pixl_core/src/core/patient_queue/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2022 University College London Hospitals NHS Foundation Trust
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes to represent messages in the patient queue."""

import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from jsonpickle import decode, encode

logger = logging.getLogger(__name__)


@dataclass
class Message:
"""Class to represent a message containing the relevant information for a study."""

mrn: str
accession_number: str
study_datetime: datetime
procedure_occurrence_id: str
project_name: str
omop_es_timestamp: datetime

def serialise(self, *, deserialisable: bool = True) -> bytes:
"""
Serialise the message into a JSON string and convert to bytes.
:param deserialisable: If True, the serialised message will be deserialisable, by setting
the unpicklable flag to False in jsonpickle.encode(), meaning that the original Message
object can be recovered by `deserialise()`. If False, calling `deserialise()` on the
serialised message will return a dictionary.
"""
msg = (
"Serialising message with\n"
" * patient id: %s\n"
" * accession number: %s\n"
" * timestamp: %s\n"
" * procedure_occurrence_id: %s\n",
" * project_name: %s\n * omop_es_timestamp: %s",
self.mrn,
self.accession_number,
self.study_datetime,
self.procedure_occurrence_id,
self.project_name,
self.omop_es_timestamp,
)
logger.debug(msg)

return str.encode(encode(self, unpicklable=deserialisable))


def deserialise(serialised_msg: bytes) -> Any:
"""
Deserialise a message from a bytes-encoded JSON string.
If the message was serialised with `deserialisable=True`, the original Message object will be
returned. Otherwise, a dictionary will be returned.
:param serialised_msg: The serialised message.
"""
return decode(serialised_msg) # noqa: S301, since we control the input, so no security risks
12 changes: 9 additions & 3 deletions pixl_core/src/core/patient_queue/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import logging
from time import sleep

from core.patient_queue.message import Message

from ._base import PixlBlockingInterface

LOGGER = logging.getLogger(__name__)
Expand All @@ -24,19 +26,23 @@
class PixlProducer(PixlBlockingInterface):
"""Generic publisher for RabbitMQ"""

def publish(self, messages: list[bytes]) -> None:
def publish(self, messages: list[Message]) -> None:
"""
Sends a list of serialised messages to a queue.
:param messages: list of messages to be sent to queue
"""
LOGGER.debug("Publishing %i messages to queue: %s", len(messages), self.queue_name)
if len(messages) > 0:
for msg in messages:
LOGGER.debug("Serialising message")
serialised_msg = msg.serialise()
LOGGER.debug("Preparing to publish")
self._channel.basic_publish(exchange="", routing_key=self.queue_name, body=msg)
self._channel.basic_publish(
exchange="", routing_key=self.queue_name, body=serialised_msg
)
# RabbitMQ can miss-order messages if there is not a sufficient delay
sleep(0.1)
LOGGER.debug("Message %s published to queue %s", msg.decode(), self.queue_name)
LOGGER.debug("Message %s published to queue %s", msg, self.queue_name)
else:
LOGGER.debug("List of messages is empty so nothing will be published to queue.")

Expand Down
5 changes: 3 additions & 2 deletions pixl_core/src/core/patient_queue/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import aio_pika

from core.patient_queue.message import Message, deserialise
from core.token_buffer.tokens import TokenBucket

from ._base import PixlBlockingInterface, PixlQueueInterface
Expand Down Expand Up @@ -52,7 +53,7 @@ async def __aenter__(self) -> "PixlConsumer":
self._queue = await self._channel.declare_queue(self.queue_name)
return self

async def run(self, callback: Callable[[bytes], Awaitable[None]]) -> None:
async def run(self, callback: Callable[[Message], Awaitable[None]]) -> None:
"""
Creates loop that waits for messages from producer and processes them as
they appear.
Expand All @@ -73,7 +74,7 @@ async def run(self, callback: Callable[[bytes], Awaitable[None]]) -> None:

try:
await asyncio.sleep(0.01) # Avoid very fast callbacks
await callback(message.body)
await callback(deserialise(message.body))
except Exception:
LOGGER.exception(
"Failed to process %s" "Not re-queuing message",
Expand Down
Loading

0 comments on commit 90b6672

Please sign in to comment.