Skip to content

Commit

Permalink
[ObjectTracker] Added plotting, and object names and types.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Dec 17, 2024
1 parent a2cadb4 commit 395f09a
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 9 deletions.
27 changes: 20 additions & 7 deletions src/episode_segmenter/episode_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tf.transformations import (quaternion_from_matrix, euler_matrix, quaternion_matrix, quaternion_multiply,
euler_from_quaternion, quaternion_inverse, quaternion_from_euler, euler_from_matrix)
from trimesh import Geometry
from typing_extensions import List, Tuple, Dict, Optional, Union
from typing_extensions import List, Tuple, Dict, Optional, Union, Type

import pycrap
from episode_segmenter.utils import calculate_quaternion_difference
Expand Down Expand Up @@ -79,7 +79,9 @@ class FileEpisodePlayer(EpisodePlayer):
def __init__(self, json_file: str, scene_id: int = 1, world: Optional[World] = None,
mesh_scale: float = 0.001,
time_between_frames: datetime.timedelta = datetime.timedelta(milliseconds=50),
objects_to_ignore: Optional[List[int]] = None):
objects_to_ignore: Optional[List[int]] = None,
obj_id_to_name: Optional[Dict[int, str]] = None,
obj_id_to_type: Optional[Dict[int, Type[pycrap.PhysicalObject]]] = None):
"""
Initializes the FAMEEpisodePlayer with the specified json file and scene id.
Expand All @@ -97,6 +99,8 @@ def __init__(self, json_file: str, scene_id: int = 1, world: Optional[World] = N
self.data_frames = {int(frame_id): objects_data for frame_id, objects_data in self.data_frames.items()}
if objects_to_ignore is not None:
self._remove_ignored_objects(objects_to_ignore)
self.obj_id_to_name: Optional[Dict[int, str]] = obj_id_to_name
self.obj_id_to_type: Optional[Dict[int, Type[pycrap.PhysicalObject]]] = obj_id_to_type
self.data_frames = dict(sorted(self.data_frames.items(), key=lambda x: x[0]))
self.world = world if world is not None else World.current_world
self.mesh_scale = mesh_scale
Expand Down Expand Up @@ -142,12 +146,13 @@ def process_objects_data(self, objects_data: dict):
pose = self.get_pose_and_transform_to_map_frame(object_poses_data[0])

# Get the object and mesh names
obj_name = self.get_object_name(object_id)
obj_name = self.get_object_name(int(object_id))
obj_type = self.get_object_type(int(object_id))
mesh_name = self.get_mesh_name(object_id)

# Create the object if it does not exist in the world and set its pose
if obj_name not in self.world.get_object_names():
obj = Object(obj_name, pycrap.PhysicalObject, mesh_name,
obj = Object(obj_name, obj_type, mesh_name,
pose=Pose(pose.position_as_list()), scale_mesh=self.mesh_scale)
quat_diff = calculate_quaternion_difference(pose.orientation_as_list(), [0, 0, 0, 1])
euler_diff = euler_from_quaternion(quat_diff)
Expand Down Expand Up @@ -352,9 +357,17 @@ def transform_pose_to_map_frame(self, position: List[float], quaternion: List[fl
def camera_frame_name(self) -> str:
return "episode_camera_frame"

@staticmethod
def get_object_name(object_id: str) -> str:
return f"episode_object_{object_id}"
def get_object_name(self, object_id: int) -> str:
if self.obj_id_to_name is not None and object_id in self.obj_id_to_name:
return self.obj_id_to_name[object_id]
else:
return f"object_{object_id}"

def get_object_type(self, object_id: int) -> Type[pycrap.PhysicalObject]:
if self.obj_id_to_type is not None and object_id in self.obj_id_to_type:
return self.obj_id_to_type[int(object_id)]
else:
return pycrap.PhysicalObject

@staticmethod
def get_mesh_name(object_id: str) -> str:
Expand Down
2 changes: 2 additions & 0 deletions src/episode_segmenter/episode_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def run_event_detectors(self) -> None:

self.process_event(next_event)

self.logger.plot_events()

self.join()

def process_event(self, event: EventUnion) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/episode_segmenter/event_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def join(self, timeout=None):
"""
World.current_world.remove_callback_on_add_object(self.on_add_object)
self.new_object_queue.join()
super().join(timeout)
# super().join(timeout)


class AbstractContactDetector(PrimitiveEventDetector, ABC):
Expand Down
51 changes: 51 additions & 0 deletions src/episode_segmenter/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,57 @@ def log_event(self, event: Event):
self.timeline_per_thread[thread_id].append(event)
self.timeline.append(event)

def plot_events(self):
"""
Plot all events that have been logged in a timeline.
"""
loginfo("Plotting events:")
# construct a dataframe with the events
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

data_dict = {'start': [], 'end': [], 'event': [], 'object': [], 'obj_type': []}
for tracker in ObjectTrackerFactory.get_all_trackers():
for event in tracker.get_event_history():
if hasattr(event, 'end_timestamp') and event.end_timestamp is not None:
data_dict['end'].append(event.end_timestamp)
else:
data_dict['end'].append(event.timestamp + timedelta(seconds=0.1).total_seconds())
data_dict['start'].append(event.timestamp)
data_dict['event'].append(event.__class__.__name__)
data_dict['object'].append(tracker.obj.name)
data_dict['obj_type'].append(tracker.obj.obj_type.name)
# subtract the start time from all timestamps
min_start = min(data_dict['start'])
data_dict['start'] = [x - min_start for x in data_dict['start']]
data_dict['end'] = [x - min_start for x in data_dict['end']]
df = pd.DataFrame(data_dict)

fig = go.Figure()

fig = px.timeline(df, x_start=pd.to_datetime(df[f'start'], unit='s'),
x_end=pd.to_datetime(df[f'end'], unit='s'),
y=f'event',
color=f'event',
hover_data={'object': True, 'obj_type': True},
# text=f'object',
title=f"Events Timeline")
fig.update_xaxes(tickvals=pd.to_datetime(df[f'start'], unit='s'), tickformat='%S')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightPink')
fig.update_layout(
font_family="Courier New",
font_color="black",
font_size=20,
title_font_family="Times New Roman",
title_font_color="black",
title_font_size=30,
legend_title_font_color="black",
legend_title_font_size=24,
)
fig.show()

def print_events(self):
"""
Print all events that have been logged.
Expand Down
5 changes: 5 additions & 0 deletions src/episode_segmenter/object_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class ObjectTrackerFactory:
_trackers: Dict[Object, ObjectTracker] = {}
_lock: RLock = RLock()

@classmethod
def get_all_trackers(cls) -> List[ObjectTracker]:
with cls._lock:
return list(cls._trackers.values())

@classmethod
def get_tracker(cls, obj: Object) -> ObjectTracker:
with cls._lock:
Expand Down
7 changes: 6 additions & 1 deletion test/test_file_episode_segmenter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from unittest import TestCase

import pycrap
from pycram.datastructures.world import World
from pycram.datastructures.enums import WorldMode
from episode_segmenter.episode_player import FileEpisodePlayer
Expand All @@ -26,9 +27,13 @@ def setUpClass(cls):
simulator = BulletWorld
annotate_events = True if simulator == BulletWorld else False
cls.world = simulator(WorldMode.GUI)
obj_id_to_name = {1: "chips", 3: "bowl", 4: "cup"}
obj_id_to_type = {1: pycrap.Container, 3: pycrap.Bowl, 4: pycrap.Cup}
cls.file_player = FileEpisodePlayer(json_file, world=cls.world,
time_between_frames=datetime.timedelta(milliseconds=50),
objects_to_ignore=[5])
objects_to_ignore=[5],
obj_id_to_name=obj_id_to_name,
obj_id_to_type=obj_id_to_type)
cls.episode_segmenter = NoAgentEpisodeSegmenter(cls.file_player, annotate_events=annotate_events)

@classmethod
Expand Down

0 comments on commit 395f09a

Please sign in to comment.