Skip to content

Commit

Permalink
refactor(viewer): use attrs for rendering data (#77)
Browse files Browse the repository at this point in the history
* refactor: update the condition of visualizing velocity

Signed-off-by: ktro2828 <[email protected]>

* refactor: use attrs for rendering data

Signed-off-by: ktro2828 <[email protected]>

---------

Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 authored Dec 27, 2024
1 parent 5decb29 commit 77b2d5d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 78 deletions.
1 change: 0 additions & 1 deletion t4_devkit/tier4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,4 +1241,3 @@ def _append_mask(
camera_masks[camera]["class_ids"] = [class_id]
camera_masks[camera]["uuids"] = [class_id]
return camera_masks
return camera_masks
129 changes: 68 additions & 61 deletions t4_devkit/viewer/rendering_data/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import rerun as rr
from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.dataclass import Box2D, Box3D
Expand All @@ -12,23 +13,28 @@
__all__ = ["BoxData3D", "BoxData2D"]


@define
class BoxData3D:
"""A class to store 3D boxes data for rendering."""

def __init__(self, label2id: dict[str, int] | None = None) -> None:
"""Construct a new object.
Args:
label2id (dict[str, int] | None, optional): Key-value mapping which maps label name to its class ID.
"""
self._centers: list[TranslationType] = []
self._rotations: list[rr.Quaternion] = []
self._sizes: list[SizeType] = []
self._class_ids: list[int] = []
self._uuids: list[int] = []
self._velocities: list[VelocityType] = []

self._label2id: dict[str, int] = {} if label2id is None else label2id
"""A class to store 3D boxes data for rendering.
Attributes:
label2id (dict[str, int]): Key-value of map of label name and its ID.
centers (list[TranslationType]): List of 3D center positions in the order of (x, y, z).
rotations (list[rr.Quaternion]): List of quaternions.
sizes (list[SizeType]): List of 3D box dimensions in the order of (width, length, height).
class_ids (list[int]): List of label class IDs.
uuids (list[str]): List of unique identifier IDs.
velocities (list[Velocities]): List of velocities in the order of (vx, vy, vz).
"""

label2id: dict[str, int] = field(factory=dict)

centers: list[TranslationType] = field(init=False, factory=list)
rotations: list[rr.Quaternion] = field(init=False, factory=list)
sizes: list[SizeType] = field(init=False, factory=list)
class_ids: list[int] = field(init=False, factory=list)
uuids: list[str] = field(init=False, factory=list)
velocities: list[VelocityType] = field(init=False, factory=list)

@overload
def append(self, box: Box3D) -> None:
Expand Down Expand Up @@ -56,8 +62,8 @@ def append(
rotation (RotationType): Quaternion.
size (SizeType): Box size in the order of (width, height, length).
class_id (int): Class ID.
velocity (VelocityType | None, optional): Box velocity. Defaults to None.
uuid (str | None, optional): Unique identifier.
velocity (VelocityType | None, optional): Box velocity. Defaults to None.
"""
pass

Expand All @@ -68,24 +74,24 @@ def append(self, *args, **kwargs) -> None:
self._append_with_elements(*args, **kwargs)

def _append_with_box(self, box: Box3D) -> None:
self._centers.append(box.position)
self.centers.append(box.position)

rotation_xyzw = np.roll(box.rotation.q, shift=-1)
self._rotations.append(rr.Quaternion(xyzw=rotation_xyzw))
self.rotations.append(rr.Quaternion(xyzw=rotation_xyzw))

width, length, height = box.size
self._sizes.append((length, width, height))
self.sizes.append((length, width, height))

if box.semantic_label.name not in self._label2id:
self._label2id[box.semantic_label.name] = len(self._label2id)
if box.semantic_label.name not in self.label2id:
self.label2id[box.semantic_label.name] = len(self.label2id)

self._class_ids.append(self._label2id[box.semantic_label.name])
self.class_ids.append(self.label2id[box.semantic_label.name])

if box.velocity is not None:
self._velocities.append(box.velocity)
self.velocities.append(box.velocity)

if box.uuid is not None:
self._uuids.append(box.uuid[:6])
self.uuids.append(box.uuid[:6])

def _append_with_elements(
self,
Expand All @@ -96,35 +102,35 @@ def _append_with_elements(
velocity: VelocityType | None = None,
uuid: str | None = None,
) -> None:
self._centers.append(center)
self.centers.append(center)

rotation_xyzw = np.roll(rotation.q, shift=-1)
self._rotations.append(rr.Quaternion(xyzw=rotation_xyzw))
self.rotations.append(rr.Quaternion(xyzw=rotation_xyzw))

width, length, height = size
self._sizes.append((length, width, height))
self.sizes.append((length, width, height))

self._class_ids.append(class_id)
self.class_ids.append(class_id)

if velocity is not None:
self._velocities.append(velocity)
self.velocities.append(velocity)

if uuid is not None:
self._uuids.append(uuid)
self.uuids.append(uuid)

def as_boxes3d(self) -> rr.Boxes3D:
"""Return 3D boxes data as a `rr.Boxes3D`.
Returns:
`rr.Boxes3D` object.
"""
labels = None if len(self._uuids) == 0 else self._uuids
labels = None if len(self.uuids) == 0 else self.uuids
return rr.Boxes3D(
sizes=self._sizes,
centers=self._centers,
rotations=self._rotations,
sizes=self.sizes,
centers=self.centers,
rotations=self.rotations,
labels=labels,
class_ids=self._class_ids,
class_ids=self.class_ids,
)

def as_arrows3d(self) -> rr.Arrows3D:
Expand All @@ -134,26 +140,27 @@ def as_arrows3d(self) -> rr.Arrows3D:
`rr.Arrows3D` object.
"""
return rr.Arrows3D(
vectors=self._velocities,
origins=self._centers,
class_ids=self._class_ids,
vectors=self.velocities,
origins=self.centers,
class_ids=self.class_ids,
)


@define
class BoxData2D:
"""A class to store 2D boxes data for rendering."""

def __init__(self, label2id: dict[str, int] | None = None) -> None:
"""Construct a new object.
"""A class to store 2D boxes data for rendering.
Args:
label2id (dict[str, int] | None, optional): Key-value mapping which maps label name to its class ID.
"""
self._rois: list[RoiType] = []
self._uuids: list[str] = []
self._class_ids: list[int] = []
Attributes:
label2id (dict[str, int]): Key-value of map of label name and its ID.
rois (list[RoiType]): List of ROIs in the order of (xmin, ymin, xmax, ymax).
class_ids (list[int]): List of label class IDs.
uuids (list[str]): List of unique identifier IDs.
"""

self._label2id: dict[str, int] = {} if label2id is None else label2id
label2id: dict[str, int] = field(factory=dict)
rois: list[RoiType] = field(init=False, factory=list)
class_ids: list[int] = field(init=False, factory=list)
uuids: list[str] = field(init=False, factory=list)

@overload
def append(self, box: Box2D) -> None:
Expand Down Expand Up @@ -182,34 +189,34 @@ def append(self, *args, **kwargs) -> None:
self._append_with_elements(*args, **kwargs)

def _append_with_box(self, box: Box2D) -> None:
self._rois.append(box.roi.roi)
self.rois.append(box.roi.roi)

if box.semantic_label.name not in self._label2id:
self._label2id[box.semantic_label.name] = len(self._label2id)
if box.semantic_label.name not in self.label2id:
self.label2id[box.semantic_label.name] = len(self.label2id)

self._class_ids.append(self._label2id[box.semantic_label.name])
self.class_ids.append(self.label2id[box.semantic_label.name])

if box.uuid is not None:
self._uuids.append(box.uuid)
self.uuids.append(box.uuid)

def _append_with_elements(self, roi: RoiType, class_id: int, uuid: str | None = None) -> None:
self._rois.append(roi)
self.rois.append(roi)

self._class_ids.append(class_id)
self.class_ids.append(class_id)

if uuid is not None:
self._uuids.append(uuid)
self.uuids.append(uuid)

def as_boxes2d(self) -> rr.Boxes2D:
"""Return 2D boxes data as a `rr.Boxes2D`.
Returns:
`rr.Boxes2D` object.
"""
labels = None if len(self._uuids) == 0 else self._uuids
labels = None if len(self.uuids) == 0 else self.uuids
return rr.Boxes2D(
array=self._rois,
array=self.rois,
array_format=rr.Box2DFormat.XYXY,
labels=labels,
class_ids=self._class_ids,
class_ids=self.class_ids,
)
37 changes: 22 additions & 15 deletions t4_devkit/viewer/rendering_data/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,29 @@

import numpy as np
import rerun as rr
from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import NDArrayU8

__all__ = ["SegmentationData2D"]


@define
class SegmentationData2D:
"""A class to store 2D segmentation image data for rendering."""
"""A class to store 2D segmentation image data for rendering.
def __init__(self) -> None:
self._masks: list[NDArrayU8] = []
self._class_ids: list[int] = []
self._uuids: list[str | None] = []
Attributes:
masks (list[NDArray]): List of segmentation masks in the shape of (H, W).
class_ids (list[int]): List of label class IDs.
uuids (list[str]): List of unique identifier IDs.
size (tuple[int, int] | None): Size of image in the order of (height, width).
"""

self._size: tuple[int, int] = None # (height, width)
masks: list[NDArrayU8] = field(init=False, factory=list)
class_ids: list[int] = field(init=False, factory=list)
uuids: list[str] = field(init=False, factory=list)
size: tuple[int, int] | None = field(init=False, default=None)

def append(self, mask: NDArrayU8, class_id: int, uuid: str | None = None) -> None:
"""Append a segmentation mask and its class ID.
Expand All @@ -34,19 +41,19 @@ def append(self, mask: NDArrayU8, class_id: int, uuid: str | None = None) -> Non
- Expecting all masks are 2D array (H, W).
- Expecting all masks are the same shape (H, W).
"""
if self._size is None:
if self.size is None:
if mask.ndim != 2:
raise ValueError("Expected the mask is 2D array (H, W).")
self._size = mask.shape
self.size = mask.shape
else:
if self._size != mask.shape:
if self.size != mask.shape:
raise ValueError(
f"All masks must be the same size. Expected: {self._size}, "
f"All masks must be the same size. Expected: {self.size}, "
f"but got {mask.shape}."
)
self._masks.append(mask)
self._class_ids.append(class_id)
self._uuids.append(uuid)
self.masks.append(mask)
self.class_ids.append(class_id)
self.uuids.append(uuid)

def as_segmentation_image(self) -> rr.SegmentationImage:
"""Return mask data as a `rr.SegmentationImage`.
Expand All @@ -57,9 +64,9 @@ def as_segmentation_image(self) -> rr.SegmentationImage:
TODO:
Add support of instance segmentation.
"""
image = np.zeros(self._size, dtype=np.uint8)
image = np.zeros(self.size, dtype=np.uint8)

for mask, class_id in zip(self._masks, self._class_ids, strict=True):
for mask, class_id in zip(self.masks, self.class_ids, strict=True):
image[mask == 1] = class_id

return rr.SegmentationImage(image=image)
5 changes: 4 additions & 1 deletion t4_devkit/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def _render_box3ds_with_elements(

if velocities is None:
velocities = [None] * len(centers)
show_arrows = False
else:
show_arrows = True

box_data = BoxData3D(label2id=self.label2id)
for center, rotation, size, class_id, velocity, uuid in zip(
Expand All @@ -293,7 +296,7 @@ def _render_box3ds_with_elements(

rr.log(format_entity(self.ego_entity, "box"), box_data.as_boxes3d())

if velocities is not None:
if show_arrows:
rr.log(format_entity(self.ego_entity, "velocity"), box_data.as_arrows3d())

@overload
Expand Down

0 comments on commit 77b2d5d

Please sign in to comment.