From 2c9be73b58d4b0a91609eef15776bbddba5016f3 Mon Sep 17 00:00:00 2001 From: Kotaro Uetake <60615504+ktro2828@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:53:24 +0900 Subject: [PATCH] feat(dataclass): add serialization function (#68) * feat: add a function to serialize dataclass Signed-off-by: ktro2828 * refactor: remove unused shortcuts method Signed-off-by: ktro2828 --------- Signed-off-by: ktro2828 --- docs/apis/common.md | 4 +++ t4_devkit/common/serialize.py | 35 ++++++++++++++++++++ t4_devkit/schema/serialize.py | 22 ++---------- t4_devkit/schema/tables/base.py | 9 ----- t4_devkit/schema/tables/log.py | 4 --- t4_devkit/schema/tables/object_ann.py | 4 --- t4_devkit/schema/tables/sample.py | 4 --- t4_devkit/schema/tables/sample_annotation.py | 4 --- t4_devkit/schema/tables/sample_data.py | 4 --- t4_devkit/schema/tables/sensor.py | 4 --- t4_devkit/schema/tables/surface_ann.py | 4 --- 11 files changed, 42 insertions(+), 56 deletions(-) create mode 100644 t4_devkit/common/serialize.py diff --git a/docs/apis/common.md b/docs/apis/common.md index aae2712..eff2b8d 100644 --- a/docs/apis/common.md +++ b/docs/apis/common.md @@ -1,9 +1,13 @@ # `common` +::: t4_devkit.common.converter + ::: t4_devkit.common.geometry ::: t4_devkit.common.io +::: t4_devkit.common.serialize + ::: t4_devkit.common.timestamp diff --git a/t4_devkit/common/serialize.py b/t4_devkit/common/serialize.py new file mode 100644 index 0000000..f52c71b --- /dev/null +++ b/t4_devkit/common/serialize.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any + +import numpy as np +from attrs import asdict, fields, filters +from pyquaternion import Quaternion + + +def serialize_dataclass(data: Any) -> dict[str, Any]: + """Serialize attrs' dataclasses into dict. + + Note that all fields specified with `init=False` will be skipped to be serialized. + + Args: + data (Any): Dataclass object. + + Returns: + dict[str, Any]: Serialized dict. + """ + excludes = filters.exclude(*[a for a in fields(data.__class__) if not a.init]) + return asdict(data, filter=excludes, value_serializer=_value_serializer) + + +def _value_serializer(data: Any, attribute: Any, value: Any) -> Any: + if isinstance(value, np.ndarray): + return value.tolist() + elif isinstance(value, Quaternion): + return value.q.tolist() + elif isinstance(value, tuple): + return list(value) + elif isinstance(value, Enum): + return value.value + return value diff --git a/t4_devkit/schema/serialize.py b/t4_devkit/schema/serialize.py index d8cee1e..dfea46e 100644 --- a/t4_devkit/schema/serialize.py +++ b/t4_devkit/schema/serialize.py @@ -1,11 +1,8 @@ from __future__ import annotations -from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import numpy as np -from attrs import asdict, filters -from pyquaternion import Quaternion +from t4_devkit.common.serialize import serialize_dataclass if TYPE_CHECKING: from .tables import SchemaTable @@ -35,17 +32,4 @@ def serialize_schema(data: SchemaTable) -> dict: Returns: Serialized dict data. """ - excludes = filters.exclude(*data.shortcuts()) if data.shortcuts() is not None else None - return asdict(data, filter=excludes, value_serializer=_value_serializer) - - -def _value_serializer(data: SchemaTable, attr: Any, value: Any) -> Any: - if isinstance(value, np.ndarray): - return value.tolist() - elif isinstance(value, Quaternion): - return value.q.tolist() - elif isinstance(value, tuple): - return list(value) - elif isinstance(value, Enum): - return value.value - return value + return serialize_dataclass(data) diff --git a/t4_devkit/schema/tables/base.py b/t4_devkit/schema/tables/base.py index 3bf8c17..4317bf7 100644 --- a/t4_devkit/schema/tables/base.py +++ b/t4_devkit/schema/tables/base.py @@ -17,15 +17,6 @@ class SchemaBase(ABC): token: str - @staticmethod - def shortcuts() -> tuple[str, ...] | None: - """Return a sequence of shortcut field names. - - Returns: - Returns None if there is no shortcut. Otherwise, returns sequence of shortcut field names. - """ - return None - @classmethod def from_json(cls, filepath: str) -> list[SchemaTable]: """Construct dataclass from json file. diff --git a/t4_devkit/schema/tables/log.py b/t4_devkit/schema/tables/log.py index deed631..e920be1 100644 --- a/t4_devkit/schema/tables/log.py +++ b/t4_devkit/schema/tables/log.py @@ -34,7 +34,3 @@ class Log(SchemaBase): # shortcuts map_token: str = field(init=False, factory=str) - - @staticmethod - def shortcuts() -> tuple[str]: - return ("map_token",) diff --git a/t4_devkit/schema/tables/object_ann.py b/t4_devkit/schema/tables/object_ann.py index ff5dbf0..da4354d 100644 --- a/t4_devkit/schema/tables/object_ann.py +++ b/t4_devkit/schema/tables/object_ann.py @@ -76,10 +76,6 @@ class ObjectAnn(SchemaBase): # shortcuts category_name: str = field(init=False, factory=str) - @staticmethod - def shortcuts() -> tuple[str]: - return ("category_name",) - @property def width(self) -> int: """Return the width of the bounding box. diff --git a/t4_devkit/schema/tables/sample.py b/t4_devkit/schema/tables/sample.py index ca83564..f44577a 100644 --- a/t4_devkit/schema/tables/sample.py +++ b/t4_devkit/schema/tables/sample.py @@ -43,7 +43,3 @@ class Sample(SchemaBase): ann_3ds: list[str] = field(factory=list, init=False) ann_2ds: list[str] = field(factory=list, init=False) surface_anns: list[str] = field(factory=list, init=False) - - @staticmethod - def shortcuts() -> tuple[str, str, str, str]: - return ("data", "ann_3ds", "ann_2ds", "surface_anns") diff --git a/t4_devkit/schema/tables/sample_annotation.py b/t4_devkit/schema/tables/sample_annotation.py index 7fc7926..6bf32aa 100644 --- a/t4_devkit/schema/tables/sample_annotation.py +++ b/t4_devkit/schema/tables/sample_annotation.py @@ -70,7 +70,3 @@ class SampleAnnotation(SchemaBase): # shortcuts category_name: str = field(init=False, factory=str) - - @staticmethod - def shortcuts() -> tuple[str]: - return ("category_name",) diff --git a/t4_devkit/schema/tables/sample_data.py b/t4_devkit/schema/tables/sample_data.py index 41994a6..adc2199 100644 --- a/t4_devkit/schema/tables/sample_data.py +++ b/t4_devkit/schema/tables/sample_data.py @@ -106,7 +106,3 @@ class SampleData(SchemaBase): # shortcuts modality: SensorModality | None = field(init=False, default=None) channel: str = field(init=False, factory=str) - - @staticmethod - def shortcuts() -> tuple[str, str]: - return ("modality", "channel") diff --git a/t4_devkit/schema/tables/sensor.py b/t4_devkit/schema/tables/sensor.py index 19fe895..8e684b9 100644 --- a/t4_devkit/schema/tables/sensor.py +++ b/t4_devkit/schema/tables/sensor.py @@ -45,7 +45,3 @@ class Sensor(SchemaBase): # shortcuts first_sd_token: str = field(init=False, factory=str) - - @staticmethod - def shortcuts() -> tuple[str] | None: - return ("first_sd_token",) diff --git a/t4_devkit/schema/tables/surface_ann.py b/t4_devkit/schema/tables/surface_ann.py index f032ce3..085d23e 100644 --- a/t4_devkit/schema/tables/surface_ann.py +++ b/t4_devkit/schema/tables/surface_ann.py @@ -39,10 +39,6 @@ class SurfaceAnn(SchemaBase): # shortcuts category_name: str = field(init=False, factory=str) - @staticmethod - def shortcuts() -> tuple[str]: - return ("category_name",) - @property def bbox(self) -> RoiType: """Return a bounding box corners calculated from polygon vertices.