Skip to content

Commit

Permalink
[ObjectTracker] corrected start and end time stamps of picking and pl…
Browse files Browse the repository at this point in the history
…acing, increased window size for motion detection.
  • Loading branch information
AbdelrhmanBassiouny committed Dec 28, 2024
1 parent b2fa63a commit fc10c8c
Showing 4 changed files with 80 additions and 18 deletions.
84 changes: 73 additions & 11 deletions src/episode_segmenter/event_detectors.py
Original file line number Diff line number Diff line change
@@ -419,11 +419,18 @@ def __init__(self, logger: EventLogger, starter_event: NewObjectEvent, distance_
self.use_decay: bool = False
self.gamma: float = 0.99
self.cut_off_frequency: float = 2
self.use_low_pass_filter: bool = False
self.use_average_distance: bool = False
self.use_consistent_gradient: bool = True

self.window_size: int = ceil(timedelta(milliseconds=300).total_seconds() / time_between_frames.total_seconds())
self.window_size: int = ceil(timedelta(milliseconds=500).total_seconds() /
self.measure_timestep.total_seconds())

self.latest_distances: List[float] = []
self.latest_poses: List[Pose] = []
self.latest_times: List[float] = []
self.event_time: float = self.latest_time
self.start_pose: Pose = self.latest_pose

# for plotting purposes
self.original_distances: List[List[List[float]]] = []
@@ -433,7 +440,7 @@ def __init__(self, logger: EventLogger, starter_event: NewObjectEvent, distance_
self.avg_distances: List[float] = []
self.times: List[float] = []

self.plot: bool = True
self.plot: bool = False
self.plot_distance_windows: bool = False
self.plot_frequencies: bool = False

@@ -505,6 +512,7 @@ def update_latest_distances_and_times(self, distance: float):
:param distance: The distance to add to the list of distances.
"""
self.latest_poses.append(self.latest_pose)
self.latest_distances.append(distance)
if self.use_decay:
self._apply_decay_to_distances()
@@ -531,16 +539,58 @@ def _reset_distances_and_times(self):

@property
def _is_motion_condition_met(self):
self.times.append(self.latest_time)

distances = self.latest_distances

if self.use_low_pass_filter:
self._apply_low_pass_filter()
distances = self.all_filtered_distances[-1]

if self.use_average_distance:
return self._check_motion_using_average_distance(distances)

elif self.use_consistent_gradient:
return self._check_motion_using_consistent_gradient(distances)

def _check_motion_using_consistent_gradient(self, distances: List[List[float]]) -> bool:
"""
Check if the object is moving using the consistent gradient.
:param distances: The distances.
:return: A boolean value that represents if the object is moving.
"""
distance_arr = np.array(distances)
x, y, z = distance_arr[:, 0], distance_arr[:, 1], distance_arr[:, 2]
is_moving = any(np.all(axes > 0) or np.all(axes < 0) for axes in [x, y, z])
self.start_pose = self.latest_poses[-self.window_size]
self.event_time = self.latest_times[-self.window_size]
return is_moving

def _check_motion_using_average_distance(self, distances: List[List[float]]) -> bool:
"""
Check if the object is moving using the average distance.
:param distances: The distances.
:return: A boolean value that represents if the object is moving.
"""
avg_distance = np.linalg.norm(np.sum(distances))
self.avg_distances.append(avg_distance)
is_moving = avg_distance > self.distance_threshold
self.start_pose = self.latest_poses[-int(self.window_size / 2)]
self.event_time = self.latest_times[-int(self.window_size / 2)]
return is_moving

def _apply_low_pass_filter(self) -> None:
"""
Apply a low pass filter to the distances.
"""
latest_distances_arr = np.array(self.latest_distances)
filtered_distances = butter_lowpass_filter(latest_distances_arr, self.cut_off_frequency,
1/self.measure_timestep.total_seconds())
1 / self.measure_timestep.total_seconds())
self.original_distances.append(self.latest_distances)
self.all_filtered_distances.append(filtered_distances)
self.all_times.append(self.latest_times)
avg_distance = np.linalg.norm(np.sum(filtered_distances))
self.avg_distances.append(avg_distance)
self.times.append(self.latest_time)
return avg_distance > self.distance_threshold

def create_event(self) -> Union[TranslationEvent, StopTranslationEvent]:
"""
@@ -550,7 +600,7 @@ def create_event(self) -> Union[TranslationEvent, StopTranslationEvent]:
"""
current_pose, current_time = self.get_current_pose_and_time()
event_type = self.get_event_type()
event = event_type(self.tracked_object, self.latest_pose, current_pose, timestamp=current_time)
event = event_type(self.tracked_object, self.start_pose, current_pose, timestamp=self.event_time)
return event

@abstractmethod
@@ -578,6 +628,8 @@ def _plot_and_show_avg_distances(self) -> None:
"""
Plot the average distances.
"""
if len(self.avg_distances) == 0:
return
plt.plot([t - self.times[0] for t in self.times], self.avg_distances[:len(self.times)])
plt.title(f"Results of {self.__class__.__name__} for {self.tracked_object.name}")
plt.show()
@@ -698,6 +750,7 @@ def __init__(self, logger: EventLogger, starter_event: EventUnion, wait_time: Op
"""
super().__init__(logger, wait_time, *args, **kwargs)
self.starter_event: EventUnion = starter_event
self._start_timestamp = self.starter_event.timestamp

@classmethod
@abstractmethod
@@ -763,7 +816,11 @@ def _no_event_found_log(self, event_type: Type[Event]):

@property
def start_timestamp(self) -> float:
return self.starter_event.timestamp
return self._start_timestamp

@start_timestamp.setter
def start_timestamp(self, timestamp: float):
self._start_timestamp = timestamp


class LiftingDetector(EventDetector):
@@ -889,6 +946,7 @@ def detect_events(self) -> List[EventUnion]:
time.sleep(0.01)
continue

self.interaction_event.timestamp = self.start_timestamp
self.interaction_event.end_timestamp = self.end_timestamp
event = self.interaction_event
break
@@ -1024,12 +1082,13 @@ def interaction_checks(self) -> bool:
"""
print(f"checking if {self.tracked_object.name} was picked up")
# wait for the object to be lifted TODO: Should be replaced with a wait on a lifting event
dt = timedelta(milliseconds=400)
dt = timedelta(milliseconds=1000)
time.sleep(dt.total_seconds())
print(f"checking for translation event for {self.tracked_object.name}")
latest_event = self.check_for_event_near_starter_event(TranslationEvent, dt)

if latest_event:
self.start_timestamp = min(latest_event.timestamp, self.start_timestamp)
self.end_timestamp = max(latest_event.timestamp, self.start_timestamp)
return True

@@ -1051,7 +1110,10 @@ def interaction_checks(self) -> bool:
dt = timedelta(milliseconds=1000)
event = self.check_for_event_near_starter_event(StopMotionEvent, dt)
if event is not None:
self.end_timestamp = event.timestamp
# start_motion_event_type = TranslationEvent if isinstance(event, StopTranslationEvent) else RotationEvent
start_motion_event = self.object_tracker.get_first_event_of_type_before_event(MotionEvent, event)
self.start_timestamp = min(start_motion_event.timestamp, self.start_timestamp)
self.end_timestamp = max(event.timestamp, self.starter_event.timestamp)
return True
elif time.time() - self.start_timestamp > dt.total_seconds():
self.kill_event.set()
7 changes: 3 additions & 4 deletions src/episode_segmenter/event_logger.py
Original file line number Diff line number Diff line change
@@ -56,15 +56,14 @@ def plot_events(self):
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

data_dict = {'start': [], 'end': [], 'event': [], 'object': [], 'obj_type': []}
for tracker in ObjectTrackerFactory.get_all_trackers():
for event in tracker.get_event_history():
end_timestamp = event.timestamp + timedelta(seconds=0.1).total_seconds()
if hasattr(event, 'end_timestamp') and event.end_timestamp is not None:
data_dict['end'].append(event.end_timestamp)
else:
data_dict['end'].append(event.timestamp + timedelta(seconds=0.1).total_seconds())
end_timestamp = max(event.end_timestamp, end_timestamp)
data_dict['end'].append(end_timestamp)
data_dict['start'].append(event.timestamp)
data_dict['event'].append(event.__class__.__name__)
data_dict['object'].append(tracker.obj.name)
5 changes: 3 additions & 2 deletions src/episode_segmenter/events.py
Original file line number Diff line number Diff line change
@@ -363,9 +363,10 @@ class AbstractAgentObjectInteractionEvent(HasTwoTrackedObjects, ABC):

def __init__(self, participating_object: Object,
agent: Optional[Object] = None,
timestamp: Optional[float] = None):
timestamp: Optional[float] = None,
end_timestamp: Optional[float] = None):
HasTwoTrackedObjects.__init__(self, participating_object, agent, timestamp)
self.end_timestamp: Optional[float] = None
self.end_timestamp: Optional[float] = end_timestamp
self.text_id: Optional[int] = None

@property
2 changes: 1 addition & 1 deletion src/episode_segmenter/object_tracker.py
Original file line number Diff line number Diff line change
@@ -110,7 +110,7 @@ def get_first_event_of_type_before_timestamp(self, event_type: Type[Event], time
with self._lock:
start_index = self.get_index_of_first_event_before(timestamp)
if start_index is not None:
for event in reversed(self._event_history[:start_index]):
for event in reversed(self._event_history[:min(start_index+1, len(self._event_history))]):
if isinstance(event, event_type):
return event

0 comments on commit fc10c8c

Please sign in to comment.