Skip to content

Commit

Permalink
[ObjectTracker] use low pass filter.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Dec 28, 2024
1 parent 3d5c0a0 commit c6af34d
Showing 1 changed file with 119 additions and 21 deletions.
140 changes: 119 additions & 21 deletions src/episode_segmenter/event_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import rospy
from matplotlib import pyplot as plt
from tf.transformations import euler_from_quaternion
from typing_extensions import Optional, List, Union, Type, Tuple

Expand All @@ -29,6 +30,18 @@
check_if_object_is_supported, check_if_object_is_supported_using_contact_points, \
check_if_object_is_supported_by_another_object, check_if_in_contact_with_support, calculate_translation

from scipy.signal import butter, lfilter, sosfilt


def butter_lowpass(cutoff, fs, order=5):
return butter(order, cutoff, fs=fs, btype='low', output='sos', analog=False)


def butter_lowpass_filter(data, cutoff, fs, order=5):
sos = butter_lowpass(cutoff, fs, order=order)
y = sosfilt(sos, data, axis=0)
return y


class PrimitiveEventDetector(threading.Thread, ABC):
"""
Expand Down Expand Up @@ -374,33 +387,95 @@ class MotionDetector(PrimitiveEventDetector, ABC):
A string that is used as a prefix for the thread ID.
"""

def __init__(self, logger: EventLogger, starter_event: NewObjectEvent, velocity_threshold: float = 0.25,
def __init__(self, logger: EventLogger, starter_event: NewObjectEvent, distance_threshold: float = 0.04,
wait_time: Optional[float] = 0.1,
time_between_frames: Optional[timedelta] = timedelta(milliseconds=50),
*args, **kwargs):
"""
:param logger: An instance of the EventLogger class that is used to log the events.
:param starter_event: An instance of the NewObjectEvent class that represents the event to start the event.
:param velocity_threshold: An optional float value that represents the velocity threshold for the object to be
considered as moving.
:param distance_threshold: An optional float value that represents the distance threshold to consider the object
as moving.
:param wait_time: An optional float value that introduces a delay between calls to the event detector.
:param time_between_frames: An optional timedelta value that represents the time between frames.
"""
super().__init__(logger, wait_time, *args, **kwargs)
self.tracked_object = starter_event.tracked_object
self.latest_pose = self.tracked_object.pose
self.latest_time = time.time()
self.velocity_threshold = velocity_threshold
self.distance_threshold = distance_threshold
self.time_between_frames: Optional[timedelta] = time_between_frames

self.measure_timestep: timedelta = timedelta(milliseconds=300)
self.measure_timestep: timedelta = timedelta(seconds=max(timedelta(milliseconds=50).total_seconds(),
self.wait_time))
# frames per measure timestep
self.measure_frame_rate: float = ceil(self.measure_timestep.total_seconds() /
time_between_frames.total_seconds()) + 0.5
time_between_frames.total_seconds())
self.measure_timestep = time_between_frames * self.measure_frame_rate

self.distance_threshold: float = self.velocity_threshold * self.measure_timestep.total_seconds()
self.velocity_threshold: float = self.distance_threshold * self.measure_timestep.total_seconds()
self.was_moving: bool = False
self.gamma: float = 1

self.window_size: int = ceil(timedelta(milliseconds=300).total_seconds() / time_between_frames.total_seconds())

self.latest_distances: List[float] = []
self.latest_times: List[float] = []

self.original_distances: List[List[List[float]]] = []
self.all_filtered_distances: List[np.ndarray] = []
self.all_times: List[List[float]] = []

self.avg_distances: List[float] = []
self.times: List[float] = []

self.plot: bool = True

def stop(self):
"""
Stop the event detector.
"""
# plot the distances
plot_parts: bool = False
plot_freq: bool = False

if self.plot:
self._plot_avg_distances()

if plot_parts:
self._plot_parts(plot_freq)

super().stop()

def _plot_avg_distances(self):
plt.plot([t - self.times[0] for t in self.times], self.avg_distances[:len(self.times)])
plt.title(f"Results of {self.__class__.__name__} for {self.tracked_object.name}")
plt.show()

def _plot_parts(self, plot_freq=False):
plot_cols: int = 2 if plot_freq else 1
ax_labels: List[str] = ["x", "y", "z"]
for i, window_time in enumerate(self.all_times):
orig_distances = np.array(self.original_distances[i])
if (np.mean(orig_distances) <= 1e-3 and self.__class__ == TranslationDetector) \
or (np.mean(orig_distances) <= 1e-4 and self.__class__ == RotationDetector):
continue
filtered_distances = self.all_filtered_distances[i]
times = [t - window_time[0] for t in window_time]
fig, axes = plt.subplots(3, plot_cols, figsize=(10, 10))
for j, ax in enumerate(axes[:, 0] if plot_freq else axes):
original = orig_distances[:, j]
filtered = filtered_distances[:, j]
ax.plot(times, original, label=f"original_{ax_labels[j]}")
ax.plot(times[:len(filtered)], filtered, label=f"filtered_{ax_labels[j]}")
if plot_freq:
for j, ax in enumerate(axes[:, 1]):
xmag = np.fft.fft(orig_distances[:, j])
freqs = np.fft.fftfreq(len(xmag), d=self.measure_timestep.total_seconds())
ax.bar(freqs[:len(xmag) // 2], np.abs(xmag)[:len(xmag) // 2], width=0.1)
for ax in axes.flatten():
ax.legend()
plt.show()

def update_latest_pose_and_time(self):
"""
Expand All @@ -420,15 +495,16 @@ def detect_events(self) -> List[Union[TranslationEvent, StopTranslationEvent]]:
:return: An instance of the TranslationEvent class that represents the event if the object is moving, else None.
"""
if self.time_since_last_event < self.measure_timestep.total_seconds():
time.sleep(self.measure_timestep.total_seconds() - self.time_since_last_event)
events = []
# if self.time_since_last_event < self.measure_timestep.total_seconds():
# time.sleep(self.measure_timestep.total_seconds() - self.time_since_last_event)
is_moving = self.is_moving()
if is_moving is not None and is_moving != self.was_moving:
self.update_object_motion_state(is_moving)
self.was_moving = not self.was_moving
events.append(self.create_event())
self.update_latest_pose_and_time()
if is_moving is not None:
if is_moving != self.was_moving:
self.update_object_motion_state(is_moving)
self.was_moving = not self.was_moving
events.append(self.create_event())
self.update_latest_pose_and_time()
return events

@property
Expand All @@ -450,16 +526,36 @@ def is_moving(self) -> Optional[bool]:
"""
distance = self.calculate_distance(self.tracked_object.pose)
self.latest_distances.append(distance)
self.latest_distances = list(map(lambda x: [x_i*0.9 for x_i in x], self.latest_distances))
if len(self.latest_distances) < self.measure_frame_rate:
self.latest_times.append(self.latest_time)
if self.gamma < 1:
self.latest_distances = list(map(lambda x: [x_i*self.gamma for x_i in x], self.latest_distances))
if (len(self.latest_distances) < self.window_size
or self.time_since_last_event < self.measure_timestep.total_seconds()):
return None
else:
self.latest_distances = self.latest_distances[-int(self.measure_frame_rate):]
self.latest_distances = self.latest_distances[-self.window_size:]
self.latest_times = self.latest_times[-self.window_size:]
return self._is_motion_condition_met

@property
def _is_motion_condition_met(self):
return np.linalg.norm(np.sum(self.latest_distances)) > self.distance_threshold
assert len(self.latest_distances) == self.window_size
assert len(self.latest_times) == self.window_size
latest_distances_arr = np.array(self.latest_distances)
assert latest_distances_arr.shape == (self.window_size, 3), \
f"latest_distances_arr.shape: {latest_distances_arr.shape}"
filtered_distances = butter_lowpass_filter(latest_distances_arr, 2,
1/self.measure_timestep.total_seconds())
assert filtered_distances.shape == (self.window_size, 3), f"filtered_distances.shape: {filtered_distances.shape}"
self.original_distances.append(self.latest_distances)
self.all_filtered_distances.append(filtered_distances)
self.all_times.append(self.latest_times)
avg_distance = np.linalg.norm(np.sum(filtered_distances))
self.avg_distances.append(avg_distance)
self.times.append(self.latest_time)
self.latest_distances = []
self.latest_times = []
return avg_distance > self.distance_threshold

def create_event(self) -> Union[TranslationEvent, StopTranslationEvent]:
"""
Expand Down Expand Up @@ -507,9 +603,9 @@ class RotationDetector(MotionDetector):
thread_prefix = "rotation_"

def __init__(self, logger: EventLogger, starter_event: NewObjectEvent,
angular_velocity_threshold: float = 30 * np.pi / 180,
angular_threshold: float = 30 * np.pi / 180,
*args, **kwargs):
super().__init__(logger, starter_event, velocity_threshold=angular_velocity_threshold, *args, **kwargs)
super().__init__(logger, starter_event, distance_threshold=angular_threshold, *args, **kwargs)

def update_object_motion_state(self, moving: bool) -> None:
"""
Expand All @@ -526,7 +622,9 @@ def calculate_distance(self, current_pose: Pose):
quat_diff = calculate_quaternion_difference(self.latest_pose.orientation_as_list(),
current_pose.orientation_as_list())
# angle = 2 * np.arccos(quat_diff[0])
return euler_from_quaternion(quat_diff)[:2]
euler_diff = list(euler_from_quaternion(quat_diff))
euler_diff[2] = 0
return euler_diff

def get_event_type(self):
return RotationEvent if self.was_moving else StopRotationEvent
Expand Down

0 comments on commit c6af34d

Please sign in to comment.