diff --git a/src/episode_segmenter/episode_player.py b/src/episode_segmenter/episode_player.py index 4613575..37829c2 100644 --- a/src/episode_segmenter/episode_player.py +++ b/src/episode_segmenter/episode_player.py @@ -12,7 +12,7 @@ from tf.transformations import (quaternion_from_matrix, euler_matrix, quaternion_matrix, quaternion_multiply, euler_from_quaternion, quaternion_inverse, quaternion_from_euler, euler_from_matrix) from trimesh import Geometry -from typing_extensions import List, Tuple, Dict, Optional, Union +from typing_extensions import List, Tuple, Dict, Optional, Union, Type import pycrap from episode_segmenter.utils import calculate_quaternion_difference @@ -79,7 +79,9 @@ class FileEpisodePlayer(EpisodePlayer): def __init__(self, json_file: str, scene_id: int = 1, world: Optional[World] = None, mesh_scale: float = 0.001, time_between_frames: datetime.timedelta = datetime.timedelta(milliseconds=50), - objects_to_ignore: Optional[List[int]] = None): + objects_to_ignore: Optional[List[int]] = None, + obj_id_to_name: Optional[Dict[int, str]] = None, + obj_id_to_type: Optional[Dict[int, Type[pycrap.PhysicalObject]]] = None): """ Initializes the FAMEEpisodePlayer with the specified json file and scene id. @@ -97,6 +99,8 @@ def __init__(self, json_file: str, scene_id: int = 1, world: Optional[World] = N self.data_frames = {int(frame_id): objects_data for frame_id, objects_data in self.data_frames.items()} if objects_to_ignore is not None: self._remove_ignored_objects(objects_to_ignore) + self.obj_id_to_name: Optional[Dict[int, str]] = obj_id_to_name + self.obj_id_to_type: Optional[Dict[int, Type[pycrap.PhysicalObject]]] = obj_id_to_type self.data_frames = dict(sorted(self.data_frames.items(), key=lambda x: x[0])) self.world = world if world is not None else World.current_world self.mesh_scale = mesh_scale @@ -142,12 +146,13 @@ def process_objects_data(self, objects_data: dict): pose = self.get_pose_and_transform_to_map_frame(object_poses_data[0]) # Get the object and mesh names - obj_name = self.get_object_name(object_id) + obj_name = self.get_object_name(int(object_id)) + obj_type = self.get_object_type(int(object_id)) mesh_name = self.get_mesh_name(object_id) # Create the object if it does not exist in the world and set its pose if obj_name not in self.world.get_object_names(): - obj = Object(obj_name, pycrap.PhysicalObject, mesh_name, + obj = Object(obj_name, obj_type, mesh_name, pose=Pose(pose.position_as_list()), scale_mesh=self.mesh_scale) quat_diff = calculate_quaternion_difference(pose.orientation_as_list(), [0, 0, 0, 1]) euler_diff = euler_from_quaternion(quat_diff) @@ -352,9 +357,17 @@ def transform_pose_to_map_frame(self, position: List[float], quaternion: List[fl def camera_frame_name(self) -> str: return "episode_camera_frame" - @staticmethod - def get_object_name(object_id: str) -> str: - return f"episode_object_{object_id}" + def get_object_name(self, object_id: int) -> str: + if self.obj_id_to_name is not None and object_id in self.obj_id_to_name: + return self.obj_id_to_name[object_id] + else: + return f"object_{object_id}" + + def get_object_type(self, object_id: int) -> Type[pycrap.PhysicalObject]: + if self.obj_id_to_type is not None and object_id in self.obj_id_to_type: + return self.obj_id_to_type[int(object_id)] + else: + return pycrap.PhysicalObject @staticmethod def get_mesh_name(object_id: str) -> str: diff --git a/src/episode_segmenter/episode_segmenter.py b/src/episode_segmenter/episode_segmenter.py index 091d6df..25c20c4 100644 --- a/src/episode_segmenter/episode_segmenter.py +++ b/src/episode_segmenter/episode_segmenter.py @@ -86,6 +86,8 @@ def run_event_detectors(self) -> None: self.process_event(next_event) + self.logger.plot_events() + self.join() def process_event(self, event: EventUnion) -> None: diff --git a/src/episode_segmenter/event_detectors.py b/src/episode_segmenter/event_detectors.py index 93b16ae..d808a88 100644 --- a/src/episode_segmenter/event_detectors.py +++ b/src/episode_segmenter/event_detectors.py @@ -216,7 +216,7 @@ def join(self, timeout=None): """ World.current_world.remove_callback_on_add_object(self.on_add_object) self.new_object_queue.join() - super().join(timeout) + # super().join(timeout) class AbstractContactDetector(PrimitiveEventDetector, ABC): diff --git a/src/episode_segmenter/event_logger.py b/src/episode_segmenter/event_logger.py index 4eb510c..bbce349 100644 --- a/src/episode_segmenter/event_logger.py +++ b/src/episode_segmenter/event_logger.py @@ -47,6 +47,57 @@ def log_event(self, event: Event): self.timeline_per_thread[thread_id].append(event) self.timeline.append(event) + def plot_events(self): + """ + Plot all events that have been logged in a timeline. + """ + loginfo("Plotting events:") + # construct a dataframe with the events + import pandas as pd + import plotly.express as px + import plotly.graph_objects as go + import numpy as np + + data_dict = {'start': [], 'end': [], 'event': [], 'object': [], 'obj_type': []} + for tracker in ObjectTrackerFactory.get_all_trackers(): + for event in tracker.get_event_history(): + if hasattr(event, 'end_timestamp') and event.end_timestamp is not None: + data_dict['end'].append(event.end_timestamp) + else: + data_dict['end'].append(event.timestamp + timedelta(seconds=0.1).total_seconds()) + data_dict['start'].append(event.timestamp) + data_dict['event'].append(event.__class__.__name__) + data_dict['object'].append(tracker.obj.name) + data_dict['obj_type'].append(tracker.obj.obj_type.name) + # subtract the start time from all timestamps + min_start = min(data_dict['start']) + data_dict['start'] = [x - min_start for x in data_dict['start']] + data_dict['end'] = [x - min_start for x in data_dict['end']] + df = pd.DataFrame(data_dict) + + fig = go.Figure() + + fig = px.timeline(df, x_start=pd.to_datetime(df[f'start'], unit='s'), + x_end=pd.to_datetime(df[f'end'], unit='s'), + y=f'event', + color=f'event', + hover_data={'object': True, 'obj_type': True}, + # text=f'object', + title=f"Events Timeline") + fig.update_xaxes(tickvals=pd.to_datetime(df[f'start'], unit='s'), tickformat='%S') + fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightPink') + fig.update_layout( + font_family="Courier New", + font_color="black", + font_size=20, + title_font_family="Times New Roman", + title_font_color="black", + title_font_size=30, + legend_title_font_color="black", + legend_title_font_size=24, + ) + fig.show() + def print_events(self): """ Print all events that have been logged. diff --git a/src/episode_segmenter/object_tracker.py b/src/episode_segmenter/object_tracker.py index 47908cb..e725097 100644 --- a/src/episode_segmenter/object_tracker.py +++ b/src/episode_segmenter/object_tracker.py @@ -147,6 +147,11 @@ class ObjectTrackerFactory: _trackers: Dict[Object, ObjectTracker] = {} _lock: RLock = RLock() + @classmethod + def get_all_trackers(cls) -> List[ObjectTracker]: + with cls._lock: + return list(cls._trackers.values()) + @classmethod def get_tracker(cls, obj: Object) -> ObjectTracker: with cls._lock: diff --git a/test/test_file_episode_segmenter.py b/test/test_file_episode_segmenter.py index 3e73472..8961412 100644 --- a/test/test_file_episode_segmenter.py +++ b/test/test_file_episode_segmenter.py @@ -1,6 +1,7 @@ import datetime from unittest import TestCase +import pycrap from pycram.datastructures.world import World from pycram.datastructures.enums import WorldMode from episode_segmenter.episode_player import FileEpisodePlayer @@ -26,9 +27,13 @@ def setUpClass(cls): simulator = BulletWorld annotate_events = True if simulator == BulletWorld else False cls.world = simulator(WorldMode.GUI) + obj_id_to_name = {1: "chips", 3: "bowl", 4: "cup"} + obj_id_to_type = {1: pycrap.Container, 3: pycrap.Bowl, 4: pycrap.Cup} cls.file_player = FileEpisodePlayer(json_file, world=cls.world, time_between_frames=datetime.timedelta(milliseconds=50), - objects_to_ignore=[5]) + objects_to_ignore=[5], + obj_id_to_name=obj_id_to_name, + obj_id_to_type=obj_id_to_type) cls.episode_segmenter = NoAgentEpisodeSegmenter(cls.file_player, annotate_events=annotate_events) @classmethod