From 642d1197c544b33b7045fa1e89eb67bb815f999d Mon Sep 17 00:00:00 2001 From: ktro2828 Date: Wed, 25 Dec 2024 20:18:22 +0900 Subject: [PATCH] fix: replace np.array instead of np.asarray in dataclass constructors Signed-off-by: ktro2828 --- t4_devkit/common/converter.py | 4 ++-- t4_devkit/dataclass/box.py | 8 ++++---- t4_devkit/dataclass/pointcloud.py | 6 +++--- t4_devkit/dataclass/roi.py | 10 ++++++---- t4_devkit/dataclass/shape.py | 2 +- t4_devkit/dataclass/trajectory.py | 9 +++++---- t4_devkit/dataclass/transform.py | 8 ++++---- t4_devkit/schema/tables/calibrated_sensor.py | 10 +++++----- t4_devkit/schema/tables/ego_pose.py | 19 +++++++------------ t4_devkit/schema/tables/keypoint.py | 2 +- t4_devkit/schema/tables/sample_annotation.py | 12 ++++++------ t4_devkit/schema/tables/vehicle_state.py | 5 ++--- 12 files changed, 46 insertions(+), 49 deletions(-) diff --git a/t4_devkit/common/converter.py b/t4_devkit/common/converter.py index 071bb07..8f74a38 100644 --- a/t4_devkit/common/converter.py +++ b/t4_devkit/common/converter.py @@ -8,10 +8,10 @@ if TYPE_CHECKING: from t4_devkit.typing import ArrayLike, NDArray -__all__ = ["as_quaternion"] +__all__ = ["to_quaternion"] -def as_quaternion(value: ArrayLike | NDArray) -> Quaternion: +def to_quaternion(value: ArrayLike | NDArray) -> Quaternion: """Convert input rotation like array to `Quaternion`. Args: diff --git a/t4_devkit/dataclass/box.py b/t4_devkit/dataclass/box.py index 4a141e9..04456a3 100644 --- a/t4_devkit/dataclass/box.py +++ b/t4_devkit/dataclass/box.py @@ -8,7 +8,7 @@ from shapely.geometry import Polygon from typing_extensions import Self -from t4_devkit.common.converter import as_quaternion +from t4_devkit.common.converter import to_quaternion from .roi import Roi from .trajectory import to_trajectories @@ -111,10 +111,10 @@ class Box3D(BaseBox): ... ) """ - position: TranslationType = field(converter=np.asarray) - rotation: RotationType = field(converter=as_quaternion) + position: TranslationType = field(converter=np.array) + rotation: RotationType = field(converter=to_quaternion) shape: Shape - velocity: VelocityType | None = field(default=None, converter=optional(np.asarray)) + velocity: VelocityType | None = field(default=None, converter=optional(np.array)) num_points: int | None = field(default=None) # additional attributes: set by `with_**` diff --git a/t4_devkit/dataclass/pointcloud.py b/t4_devkit/dataclass/pointcloud.py index d0a64dd..9f2d93d 100644 --- a/t4_devkit/dataclass/pointcloud.py +++ b/t4_devkit/dataclass/pointcloud.py @@ -25,10 +25,10 @@ class PointCloud: """Abstract base dataclass for pointcloud data.""" - points: NDArrayFloat = field(converter=np.asarray) + points: NDArrayFloat = field(converter=np.array) @points.validator - def check_dims(self, attribute, value) -> None: + def _check_dims(self, attribute, value) -> None: if value.shape[0] != self.num_dims(): raise ValueError( f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}" @@ -211,7 +211,7 @@ class SegmentationPointCloud(PointCloud): labels (NDArrayU8): Label matrix. """ - labels: NDArrayU8 = field(converter=lambda x: np.asarray(x, dtype=np.uint8)) + labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8)) @staticmethod def num_dims() -> int: diff --git a/t4_devkit/dataclass/roi.py b/t4_devkit/dataclass/roi.py index 4d12c3e..32ae58b 100644 --- a/t4_devkit/dataclass/roi.py +++ b/t4_devkit/dataclass/roi.py @@ -20,10 +20,12 @@ class Roi: roi: RoiType = field(converter=tuple) - def __attrs_post_init__(self) -> None: - assert len(self.roi) == 4, ( - "Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}." - ) + @roi.validator + def _check_dims(self, attribute, value) -> None: + if len(value) != 4: + raise ValueError( + f"Expected {attribute.name} is (x, y, width, height), but got length with {value}." + ) @property def offset(self) -> tuple[int, int]: diff --git a/t4_devkit/dataclass/shape.py b/t4_devkit/dataclass/shape.py index 83d21b1..d01b19f 100644 --- a/t4_devkit/dataclass/shape.py +++ b/t4_devkit/dataclass/shape.py @@ -52,7 +52,7 @@ class Shape: """ shape_type: ShapeType - size: SizeType = field(converter=np.asarray) + size: SizeType = field(converter=np.array) footprint: Polygon = field(default=None) def __attrs_post_init__(self) -> None: diff --git a/t4_devkit/dataclass/trajectory.py b/t4_devkit/dataclass/trajectory.py index 3987a5f..e23a6b6 100644 --- a/t4_devkit/dataclass/trajectory.py +++ b/t4_devkit/dataclass/trajectory.py @@ -41,12 +41,13 @@ class Trajectory: [2. 2. 2.] """ - waypoints: TrajectoryType = field(converter=np.asarray) + waypoints: TrajectoryType = field(converter=np.array) confidence: float = field(default=1.0) - def __attrs_post_init__(self) -> None: - if self.waypoints.shape[1] != 3: - raise ValueError("Trajectory dimension must be 3.") + @waypoints.validator + def _check_dims(self, attribute, value) -> None: + if value.shape[1] != 3: + raise ValueError(f"{attribute.name} dimension must be 3.") def __len__(self) -> int: return len(self.waypoints) diff --git a/t4_devkit/dataclass/transform.py b/t4_devkit/dataclass/transform.py index ae9a400..fbf84ed 100644 --- a/t4_devkit/dataclass/transform.py +++ b/t4_devkit/dataclass/transform.py @@ -8,7 +8,7 @@ from pyquaternion import Quaternion from typing_extensions import Self -from t4_devkit.common.converter import as_quaternion +from t4_devkit.common.converter import to_quaternion from t4_devkit.typing import NDArray, RotationType if TYPE_CHECKING: @@ -108,8 +108,8 @@ def do_transform(self, src: str, dst: str, *args, **kwargs) -> TransformItemLike @define class HomogeneousMatrix: - position: TranslationType = field(converter=np.asarray) - rotation: Quaternion = field(converter=as_quaternion) + position: TranslationType = field(converter=np.array) + rotation: Quaternion = field(converter=to_quaternion) src: str dst: str matrix: NDArray = field(init=False) @@ -498,7 +498,7 @@ def _generate_homogeneous_matrix( A 4x4 homogeneous matrix. """ position = np.asarray(position) - rotation = as_quaternion(rotation) + rotation = to_quaternion(rotation) matrix = np.eye(4) matrix[:3, 3] = position diff --git a/t4_devkit/schema/tables/calibrated_sensor.py b/t4_devkit/schema/tables/calibrated_sensor.py index 92e1084..34b8fe9 100644 --- a/t4_devkit/schema/tables/calibrated_sensor.py +++ b/t4_devkit/schema/tables/calibrated_sensor.py @@ -5,7 +5,7 @@ import numpy as np from attrs import define, field -from t4_devkit.common.converter import as_quaternion +from t4_devkit.common.converter import to_quaternion from ..name import SchemaName from .base import SchemaBase @@ -32,7 +32,7 @@ class CalibratedSensor(SchemaBase): """ sensor_token: str - translation: TranslationType = field(converter=np.asarray) - rotation: RotationType = field(converter=as_quaternion) - camera_intrinsic: CamIntrinsicType = field(converter=np.asarray) - camera_distortion: CamDistortionType = field(converter=np.asarray) + translation: TranslationType = field(converter=np.array) + rotation: RotationType = field(converter=to_quaternion) + camera_intrinsic: CamIntrinsicType = field(converter=np.array) + camera_distortion: CamDistortionType = field(converter=np.array) diff --git a/t4_devkit/schema/tables/ego_pose.py b/t4_devkit/schema/tables/ego_pose.py index 9963d20..1be6444 100644 --- a/t4_devkit/schema/tables/ego_pose.py +++ b/t4_devkit/schema/tables/ego_pose.py @@ -4,8 +4,9 @@ import numpy as np from attrs import define, field +from attrs.converters import optional -from t4_devkit.common.converter import as_quaternion +from t4_devkit.common.converter import to_quaternion from ..name import SchemaName from .base import SchemaBase @@ -42,15 +43,9 @@ class EgoPose(SchemaBase): (latitude, longitude, altitude) in degrees and meters. """ - translation: TranslationType = field(converter=np.asarray) - rotation: RotationType = field(converter=as_quaternion) + translation: TranslationType = field(converter=np.array) + rotation: RotationType = field(converter=to_quaternion) timestamp: int - twist: TwistType | None = field( - default=None, converter=lambda x: np.asarray(x) if x is not None else x - ) - acceleration: AccelerationType | None = field( - default=None, converter=lambda x: np.asarray(x) if x is not None else x - ) - geocoordinate: GeoCoordinateType | None = field( - default=None, converter=lambda x: np.asarray(x) if x is not None else x - ) + twist: TwistType | None = field(default=None, converter=optional(np.array)) + acceleration: AccelerationType | None = field(default=None, converter=optional(np.array)) + geocoordinate: GeoCoordinateType | None = field(default=None, converter=optional(np.array)) diff --git a/t4_devkit/schema/tables/keypoint.py b/t4_devkit/schema/tables/keypoint.py index a8f794b..ad405a2 100644 --- a/t4_devkit/schema/tables/keypoint.py +++ b/t4_devkit/schema/tables/keypoint.py @@ -32,5 +32,5 @@ class Keypoint(SchemaBase): sample_data_token: str instance_token: str category_tokens: list[str] - keypoints: KeypointType = field(converter=np.asarray) + keypoints: KeypointType = field(converter=np.array) num_keypoints: int diff --git a/t4_devkit/schema/tables/sample_annotation.py b/t4_devkit/schema/tables/sample_annotation.py index 6bf32aa..c9d3ae8 100644 --- a/t4_devkit/schema/tables/sample_annotation.py +++ b/t4_devkit/schema/tables/sample_annotation.py @@ -6,7 +6,7 @@ from attrs import define, field from attrs.converters import optional -from t4_devkit.common.converter import as_quaternion +from t4_devkit.common.converter import to_quaternion from ..name import SchemaName from .base import SchemaBase @@ -58,15 +58,15 @@ class SampleAnnotation(SchemaBase): instance_token: str attribute_tokens: list[str] visibility_token: str - translation: TranslationType = field(converter=np.asarray) - size: SizeType = field(converter=np.asarray) - rotation: RotationType = field(converter=as_quaternion) + translation: TranslationType = field(converter=np.array) + size: SizeType = field(converter=np.array) + rotation: RotationType = field(converter=to_quaternion) num_lidar_pts: int num_radar_pts: int next: str # noqa: A003 prev: str - velocity: VelocityType | None = field(default=None, converter=optional(np.asarray)) - acceleration: AccelerationType | None = field(default=None, converter=optional(np.asarray)) + velocity: VelocityType | None = field(default=None, converter=optional(np.array)) + acceleration: AccelerationType | None = field(default=None, converter=optional(np.array)) # shortcuts category_name: str = field(init=False, factory=str) diff --git a/t4_devkit/schema/tables/vehicle_state.py b/t4_devkit/schema/tables/vehicle_state.py index a8cf3c9..56c1d28 100644 --- a/t4_devkit/schema/tables/vehicle_state.py +++ b/t4_devkit/schema/tables/vehicle_state.py @@ -3,6 +3,7 @@ from enum import Enum, unique from attrs import define, field +from attrs.converters import optional from ..name import SchemaName from .base import SchemaBase @@ -83,9 +84,7 @@ class VehicleState(SchemaBase): steer_pedal: float | None = field(default=None) steering_tire_angle: float | None = field(default=None) steering_wheel_angle: float | None = field(default=None) - shift_state: ShiftState | None = field( - default=None, converter=lambda x: None if x is None else ShiftState(x) - ) + shift_state: ShiftState | None = field(default=None, converter=optional(ShiftState)) indicators: Indicators | None = field( default=None, converter=lambda x: Indicators(**x) if isinstance(x, dict) else x )