Skip to content

Commit

Permalink
fix: replace np.array instead of np.asarray in dataclass constructors
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Dec 25, 2024
1 parent aadb39c commit 642d119
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 49 deletions.
4 changes: 2 additions & 2 deletions t4_devkit/common/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray

__all__ = ["as_quaternion"]
__all__ = ["to_quaternion"]


def as_quaternion(value: ArrayLike | NDArray) -> Quaternion:
def to_quaternion(value: ArrayLike | NDArray) -> Quaternion:
"""Convert input rotation like array to `Quaternion`.
Args:
Expand Down
8 changes: 4 additions & 4 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_quaternion
from t4_devkit.common.converter import to_quaternion

from .roi import Roi
from .trajectory import to_trajectories
Expand Down Expand Up @@ -111,10 +111,10 @@ class Box3D(BaseBox):
... )
"""

position: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
position: TranslationType = field(converter=np.array)
rotation: RotationType = field(converter=to_quaternion)
shape: Shape
velocity: VelocityType | None = field(default=None, converter=optional(np.asarray))
velocity: VelocityType | None = field(default=None, converter=optional(np.array))
num_points: int | None = field(default=None)

# additional attributes: set by `with_**`
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat = field(converter=np.asarray)
points: NDArrayFloat = field(converter=np.array)

@points.validator
def check_dims(self, attribute, value) -> None:
def _check_dims(self, attribute, value) -> None:
if value.shape[0] != self.num_dims():
raise ValueError(
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
Expand Down Expand Up @@ -211,7 +211,7 @@ class SegmentationPointCloud(PointCloud):
labels (NDArrayU8): Label matrix.
"""

labels: NDArrayU8 = field(converter=lambda x: np.asarray(x, dtype=np.uint8))
labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8))

@staticmethod
def num_dims() -> int:
Expand Down
10 changes: 6 additions & 4 deletions t4_devkit/dataclass/roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ class Roi:

roi: RoiType = field(converter=tuple)

def __attrs_post_init__(self) -> None:
assert len(self.roi) == 4, (
"Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
)
@roi.validator
def _check_dims(self, attribute, value) -> None:
if len(value) != 4:
raise ValueError(
f"Expected {attribute.name} is (x, y, width, height), but got length with {value}."
)

@property
def offset(self) -> tuple[int, int]:
Expand Down
2 changes: 1 addition & 1 deletion t4_devkit/dataclass/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Shape:
"""

shape_type: ShapeType
size: SizeType = field(converter=np.asarray)
size: SizeType = field(converter=np.array)
footprint: Polygon = field(default=None)

def __attrs_post_init__(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions t4_devkit/dataclass/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ class Trajectory:
[2. 2. 2.]
"""

waypoints: TrajectoryType = field(converter=np.asarray)
waypoints: TrajectoryType = field(converter=np.array)
confidence: float = field(default=1.0)

def __attrs_post_init__(self) -> None:
if self.waypoints.shape[1] != 3:
raise ValueError("Trajectory dimension must be 3.")
@waypoints.validator
def _check_dims(self, attribute, value) -> None:
if value.shape[1] != 3:
raise ValueError(f"{attribute.name} dimension must be 3.")

def __len__(self) -> int:
return len(self.waypoints)
Expand Down
8 changes: 4 additions & 4 deletions t4_devkit/dataclass/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyquaternion import Quaternion
from typing_extensions import Self

from t4_devkit.common.converter import as_quaternion
from t4_devkit.common.converter import to_quaternion
from t4_devkit.typing import NDArray, RotationType

if TYPE_CHECKING:
Expand Down Expand Up @@ -108,8 +108,8 @@ def do_transform(self, src: str, dst: str, *args, **kwargs) -> TransformItemLike

@define
class HomogeneousMatrix:
position: TranslationType = field(converter=np.asarray)
rotation: Quaternion = field(converter=as_quaternion)
position: TranslationType = field(converter=np.array)
rotation: Quaternion = field(converter=to_quaternion)
src: str
dst: str
matrix: NDArray = field(init=False)
Expand Down Expand Up @@ -498,7 +498,7 @@ def _generate_homogeneous_matrix(
A 4x4 homogeneous matrix.
"""
position = np.asarray(position)
rotation = as_quaternion(rotation)
rotation = to_quaternion(rotation)

matrix = np.eye(4)
matrix[:3, 3] = position
Expand Down
10 changes: 5 additions & 5 deletions t4_devkit/schema/tables/calibrated_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from attrs import define, field

from t4_devkit.common.converter import as_quaternion
from t4_devkit.common.converter import to_quaternion

from ..name import SchemaName
from .base import SchemaBase
Expand All @@ -32,7 +32,7 @@ class CalibratedSensor(SchemaBase):
"""

sensor_token: str
translation: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
camera_intrinsic: CamIntrinsicType = field(converter=np.asarray)
camera_distortion: CamDistortionType = field(converter=np.asarray)
translation: TranslationType = field(converter=np.array)
rotation: RotationType = field(converter=to_quaternion)
camera_intrinsic: CamIntrinsicType = field(converter=np.array)
camera_distortion: CamDistortionType = field(converter=np.array)
19 changes: 7 additions & 12 deletions t4_devkit/schema/tables/ego_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import numpy as np
from attrs import define, field
from attrs.converters import optional

from t4_devkit.common.converter import as_quaternion
from t4_devkit.common.converter import to_quaternion

from ..name import SchemaName
from .base import SchemaBase
Expand Down Expand Up @@ -42,15 +43,9 @@ class EgoPose(SchemaBase):
(latitude, longitude, altitude) in degrees and meters.
"""

translation: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
translation: TranslationType = field(converter=np.array)
rotation: RotationType = field(converter=to_quaternion)
timestamp: int
twist: TwistType | None = field(
default=None, converter=lambda x: np.asarray(x) if x is not None else x
)
acceleration: AccelerationType | None = field(
default=None, converter=lambda x: np.asarray(x) if x is not None else x
)
geocoordinate: GeoCoordinateType | None = field(
default=None, converter=lambda x: np.asarray(x) if x is not None else x
)
twist: TwistType | None = field(default=None, converter=optional(np.array))
acceleration: AccelerationType | None = field(default=None, converter=optional(np.array))
geocoordinate: GeoCoordinateType | None = field(default=None, converter=optional(np.array))
2 changes: 1 addition & 1 deletion t4_devkit/schema/tables/keypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ class Keypoint(SchemaBase):
sample_data_token: str
instance_token: str
category_tokens: list[str]
keypoints: KeypointType = field(converter=np.asarray)
keypoints: KeypointType = field(converter=np.array)
num_keypoints: int
12 changes: 6 additions & 6 deletions t4_devkit/schema/tables/sample_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from attrs import define, field
from attrs.converters import optional

from t4_devkit.common.converter import as_quaternion
from t4_devkit.common.converter import to_quaternion

from ..name import SchemaName
from .base import SchemaBase
Expand Down Expand Up @@ -58,15 +58,15 @@ class SampleAnnotation(SchemaBase):
instance_token: str
attribute_tokens: list[str]
visibility_token: str
translation: TranslationType = field(converter=np.asarray)
size: SizeType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
translation: TranslationType = field(converter=np.array)
size: SizeType = field(converter=np.array)
rotation: RotationType = field(converter=to_quaternion)
num_lidar_pts: int
num_radar_pts: int
next: str # noqa: A003
prev: str
velocity: VelocityType | None = field(default=None, converter=optional(np.asarray))
acceleration: AccelerationType | None = field(default=None, converter=optional(np.asarray))
velocity: VelocityType | None = field(default=None, converter=optional(np.array))
acceleration: AccelerationType | None = field(default=None, converter=optional(np.array))

# shortcuts
category_name: str = field(init=False, factory=str)
5 changes: 2 additions & 3 deletions t4_devkit/schema/tables/vehicle_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum, unique

from attrs import define, field
from attrs.converters import optional

from ..name import SchemaName
from .base import SchemaBase
Expand Down Expand Up @@ -83,9 +84,7 @@ class VehicleState(SchemaBase):
steer_pedal: float | None = field(default=None)
steering_tire_angle: float | None = field(default=None)
steering_wheel_angle: float | None = field(default=None)
shift_state: ShiftState | None = field(
default=None, converter=lambda x: None if x is None else ShiftState(x)
)
shift_state: ShiftState | None = field(default=None, converter=optional(ShiftState))
indicators: Indicators | None = field(
default=None, converter=lambda x: Indicators(**x) if isinstance(x, dict) else x
)
Expand Down

0 comments on commit 642d119

Please sign in to comment.