Skip to content

Commit

Permalink
feat(dataclass): add serialization function (#68)
Browse files Browse the repository at this point in the history
* feat: add a function to serialize dataclass

Signed-off-by: ktro2828 <[email protected]>

* refactor: remove unused shortcuts method

Signed-off-by: ktro2828 <[email protected]>

---------

Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 authored Dec 11, 2024
1 parent 8d9061c commit 2c9be73
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 56 deletions.
4 changes: 4 additions & 0 deletions docs/apis/common.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# `common`

<!-- prettier-ignore-start -->
::: t4_devkit.common.converter

::: t4_devkit.common.geometry

::: t4_devkit.common.io

::: t4_devkit.common.serialize

::: t4_devkit.common.timestamp
<!-- prettier-ignore-end -->
35 changes: 35 additions & 0 deletions t4_devkit/common/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from enum import Enum
from typing import Any

import numpy as np
from attrs import asdict, fields, filters
from pyquaternion import Quaternion


def serialize_dataclass(data: Any) -> dict[str, Any]:
"""Serialize attrs' dataclasses into dict.
Note that all fields specified with `init=False` will be skipped to be serialized.
Args:
data (Any): Dataclass object.
Returns:
dict[str, Any]: Serialized dict.
"""
excludes = filters.exclude(*[a for a in fields(data.__class__) if not a.init])
return asdict(data, filter=excludes, value_serializer=_value_serializer)


def _value_serializer(data: Any, attribute: Any, value: Any) -> Any:
if isinstance(value, np.ndarray):
return value.tolist()
elif isinstance(value, Quaternion):
return value.q.tolist()
elif isinstance(value, tuple):
return list(value)
elif isinstance(value, Enum):
return value.value
return value
22 changes: 3 additions & 19 deletions t4_devkit/schema/serialize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import numpy as np
from attrs import asdict, filters
from pyquaternion import Quaternion
from t4_devkit.common.serialize import serialize_dataclass

if TYPE_CHECKING:
from .tables import SchemaTable
Expand Down Expand Up @@ -35,17 +32,4 @@ def serialize_schema(data: SchemaTable) -> dict:
Returns:
Serialized dict data.
"""
excludes = filters.exclude(*data.shortcuts()) if data.shortcuts() is not None else None
return asdict(data, filter=excludes, value_serializer=_value_serializer)


def _value_serializer(data: SchemaTable, attr: Any, value: Any) -> Any:
if isinstance(value, np.ndarray):
return value.tolist()
elif isinstance(value, Quaternion):
return value.q.tolist()
elif isinstance(value, tuple):
return list(value)
elif isinstance(value, Enum):
return value.value
return value
return serialize_dataclass(data)
9 changes: 0 additions & 9 deletions t4_devkit/schema/tables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ class SchemaBase(ABC):

token: str

@staticmethod
def shortcuts() -> tuple[str, ...] | None:
"""Return a sequence of shortcut field names.
Returns:
Returns None if there is no shortcut. Otherwise, returns sequence of shortcut field names.
"""
return None

@classmethod
def from_json(cls, filepath: str) -> list[SchemaTable]:
"""Construct dataclass from json file.
Expand Down
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,3 @@ class Log(SchemaBase):

# shortcuts
map_token: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str]:
return ("map_token",)
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/object_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ class ObjectAnn(SchemaBase):
# shortcuts
category_name: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str]:
return ("category_name",)

@property
def width(self) -> int:
"""Return the width of the bounding box.
Expand Down
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,3 @@ class Sample(SchemaBase):
ann_3ds: list[str] = field(factory=list, init=False)
ann_2ds: list[str] = field(factory=list, init=False)
surface_anns: list[str] = field(factory=list, init=False)

@staticmethod
def shortcuts() -> tuple[str, str, str, str]:
return ("data", "ann_3ds", "ann_2ds", "surface_anns")
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/sample_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,3 @@ class SampleAnnotation(SchemaBase):

# shortcuts
category_name: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str]:
return ("category_name",)
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,3 @@ class SampleData(SchemaBase):
# shortcuts
modality: SensorModality | None = field(init=False, default=None)
channel: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str, str]:
return ("modality", "channel")
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,3 @@ class Sensor(SchemaBase):

# shortcuts
first_sd_token: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str] | None:
return ("first_sd_token",)
4 changes: 0 additions & 4 deletions t4_devkit/schema/tables/surface_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ class SurfaceAnn(SchemaBase):
# shortcuts
category_name: str = field(init=False, factory=str)

@staticmethod
def shortcuts() -> tuple[str]:
return ("category_name",)

@property
def bbox(self) -> RoiType:
"""Return a bounding box corners calculated from polygon vertices.
Expand Down

0 comments on commit 2c9be73

Please sign in to comment.