Skip to content

Commit

Permalink
[ObjectTracker] added factory and used in event detectors.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Dec 4, 2024
1 parent d8e9adb commit fb18e6d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 28 deletions.
32 changes: 17 additions & 15 deletions src/episode_segmenter/event_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .events import Event, ContactEvent, LossOfContactEvent, PickUpEvent, AgentContactEvent, \
AgentLossOfContactEvent, EventUnion, LossOfSurfaceEvent, TranslationEvent, StopTranslationEvent, NewObjectEvent, \
RotationEvent, StopRotationEvent, PlacingEvent, MotionEvent, StopMotionEvent
from .object_tracker import ObjectTracker
from .object_tracker import ObjectTracker, ObjectTrackerFactory
from .utils import get_angle_between_vectors, calculate_euclidean_distance, calculate_quaternion_difference, \
check_if_object_is_supported

Expand Down Expand Up @@ -537,38 +537,40 @@ def start_condition_checker(cls, event: Event) -> bool:
"""
pass

def check_for_event_post_starter_event(self, event_detector: Type[PrimitiveEventDetector]) -> Optional[EventUnion]:
def check_for_event_post_starter_event(self, event_type: Type[Event]) -> Optional[EventUnion]:
"""
Check if the tracked_object was involved in an event after the starter event.
:param event_detector: The event detector class that is used to detect the event.
:param event_type: The event type to check for.
:return: The event if the tracked_object was involved in an event, else None.
"""
event = get_latest_event_of_detector_for_object(event_detector, self.tracked_object,
after_timestamp=self.start_timestamp)
event = self.object_tracker.get_first_event_of_type_after_event(event_type, self.starter_event)
if event is None:
logdebug(f"{event_detector.__name__} found no event after {self.start_timestamp} with object :"
logdebug(f"{event_type.__name__} found no event after {self.start_timestamp} with object :"
f" {self.tracked_object.name}")
return None

return event

def check_for_event_near_starter_event(self, event_detector: Type[PrimitiveEventDetector],
@property
def object_tracker(self) -> ObjectTracker:
return ObjectTrackerFactory.get_tracker(self.tracked_object)

def check_for_event_near_starter_event(self, event_type: Type[Event],
time_tolerance: timedelta) -> Optional[EventUnion]:
"""
Check if the tracked_object was involved in an event near the starter event (i.e. could be before or after).
:param event_detector: The event detector class that is used to detect the event.
:param event_type: The event type to check for.
:param time_tolerance: The time tolerance to consider the event as near the starter event.
:return: The event if the tracked_object was involved in an event, else None.
"""
event = get_nearest_event_of_detector_for_object(event_detector, self.tracked_object,
timestamp=self.start_timestamp, time_tolerance=time_tolerance)
event = self.object_tracker.get_nearest_event_of_type_to_event(self.starter_event,
tolerance=time_tolerance,
event_type=event_type)
if event is None:
logdebug(f"{event_detector.__name__} found no event after {self.start_timestamp} with object :"
logdebug(f"{event_type.__name__} found no event after {self.start_timestamp} with object :"
f" {self.tracked_object.name}")
return None

return event

@property
Expand Down Expand Up @@ -733,7 +735,7 @@ def interaction_checks(self) -> bool:
"""
Perform extra checks to determine if the object was picked up.
"""
loss_of_surface_event = self.check_for_event_post_starter_event(LossOfSurfaceDetector)
loss_of_surface_event = self.check_for_event_post_starter_event(LossOfSurfaceEvent)

if not loss_of_surface_event:
return False
Expand Down Expand Up @@ -828,7 +830,7 @@ def initial_interaction_checkers(self) -> bool:
"""
Perform initial checks to determine if the object was placed.
"""
contact_event = self.check_for_event_post_starter_event(ContactDetector)
contact_event = self.check_for_event_post_starter_event(ContactEvent)
print(f"contact_event: {contact_event}")
if contact_event and check_if_object_is_supported(self.tracked_object):
self.end_timestamp = contact_event.timestamp
Expand Down
85 changes: 72 additions & 13 deletions src/episode_segmenter/object_tracker.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from datetime import timedelta
from threading import RLock
from typing_extensions import List, Type, Optional, TYPE_CHECKING, Dict

import numpy as np

from pycram.world_concepts.world_object import Object
from threading import RLock
from typing_extensions import List, Type, Optional, TYPE_CHECKING
from pycram.ros.logging import logwarn

if TYPE_CHECKING:
from .events import Event
from .events import Event, EventUnion


class ObjectTracker:
Expand Down Expand Up @@ -44,28 +47,71 @@ def get_latest_event_of_type(self, event_type: Type[Event]) -> Optional[Event]:
return None

def get_first_event_before(self, timestamp: float) -> Optional[Event]:
with self._lock:
first_event_index = self.get_index_of_first_event_before(timestamp)
return self._event_history[first_event_index] if first_event_index is not None else None

def get_first_event_after(self, timestamp: float) -> Optional[Event]:
with self._lock:
first_event_index = self.get_index_of_first_event_after(timestamp)
return self._event_history[first_event_index] if first_event_index is not None else None

def get_nearest_event_of_type_to_event(self, event: Event, event_type: Type[Event],
tolerance: Optional[timedelta] = None) -> Optional[EventUnion]:
return self.get_nearest_event_of_type_to_timestamp(event.timestamp, event_type, tolerance)

def get_nearest_event_of_type_to_timestamp(self, timestamp: float, event_type: Type[Event],
tolerance: Optional[timedelta] = None) -> Optional[Event]:
with self._lock:
time_stamps = self.time_stamps_array
try:
first_event_index = np.where(time_stamps < timestamp)[0][-1]
return self._event_history[first_event_index]
except IndexError:
type_cond = np.array([isinstance(event, event_type) for event in self._event_history])
valid_indices = np.where(type_cond)[0]
time_stamps = time_stamps[valid_indices]
nearest_event_index = self._get_nearest_index(time_stamps, timestamp, tolerance)
return self._event_history[valid_indices[nearest_event_index]]

def get_nearest_event_to(self, timestamp: float, tolerance: Optional[timedelta] = None) -> Optional[Event]:
with self._lock:
time_stamps = self.time_stamps_array
nearest_event_index = self._get_nearest_index(time_stamps, timestamp, tolerance)
return self._event_history[nearest_event_index]

def _get_nearest_index(self, time_stamps: np.ndarray,
timestamp: float, tolerance: Optional[timedelta] = None) -> Optional[int]:
with self._lock:
nearest_event_index = np.argmin(np.abs(time_stamps - timestamp))
if tolerance is not None and abs(time_stamps[nearest_event_index] - timestamp) > tolerance.total_seconds():
return None
return nearest_event_index

def get_first_event_after(self, timestamp: float) -> Optional[Event]:
def get_first_event_of_type_after_event(self, event_type: Type[Event], event: Event) -> Optional[EventUnion]:
return self.get_first_event_of_type_after_timestamp(event_type, event.timestamp)

def get_first_event_of_type_after_timestamp(self, event_type: Type[Event], timestamp: float) -> Optional[Event]:
with self._lock:
start_index = self.get_index_of_first_event_after(timestamp)
for event in self._event_history[start_index:]:
if isinstance(event, event_type):
return event
return None

def get_index_of_first_event_after(self, timestamp: float) -> Optional[int]:
with self._lock:
time_stamps = self.time_stamps_array
try:
first_event_index = np.where(time_stamps > timestamp)[0][0]
return self._event_history[first_event_index]
return np.where(time_stamps > timestamp)[0][0]
except IndexError:
logwarn(f"No events after timestamp {timestamp}")
return None

def get_nearest_event_to(self, timestamp: float) -> Optional[Event]:
def get_index_of_first_event_before(self, timestamp: float) -> Optional[int]:
with self._lock:
time_stamps = self.time_stamps_array
nearest_event_index = np.argmin(np.abs(time_stamps - timestamp))
return self._event_history[nearest_event_index]
try:
return np.where(time_stamps < timestamp)[0][-1]
except IndexError:
logwarn(f"No events before timestamp {timestamp}")
return None

@property
def time_stamps_array(self) -> np.ndarray:
Expand All @@ -77,3 +123,16 @@ def time_stamps(self) -> List[float]:
return [event.timestamp for event in self._event_history]


class ObjectTrackerFactory:

_trackers: Dict[Object, ObjectTracker] = {}
_lock: RLock = RLock()

@classmethod
def get_tracker(cls, obj: Object) -> ObjectTracker:
with cls._lock:
if obj not in cls._trackers:
cls._trackers[obj] = ObjectTracker(obj)
return cls._trackers[obj]


0 comments on commit fb18e6d

Please sign in to comment.