diff --git a/src/episode_segmenter/episode_segmenter.py b/src/episode_segmenter/episode_segmenter.py index b1c64eb..d510019 100644 --- a/src/episode_segmenter/episode_segmenter.py +++ b/src/episode_segmenter/episode_segmenter.py @@ -1,7 +1,7 @@ import datetime from .event_detectors import EventDetectorUnion, TypeEventDetectorUnion, MotionPickUpDetector, TranslationDetector, \ - RotationDetector + RotationDetector, PlacingDetector import time from abc import ABC, abstractmethod @@ -18,7 +18,8 @@ AbstractContactDetector from .event_logger import EventLogger from .events import ContactEvent, Event, AgentContactEvent, PickUpEvent, EventUnion, StopMotionEvent, MotionEvent, \ - NewObjectEvent, RotationEvent, StopRotationEvent + NewObjectEvent, RotationEvent, StopRotationEvent, PlacingEvent +from .utils import check_if_object_is_supported, add_imaginary_support_for_object class EpisodeSegmenter(ABC): @@ -35,7 +36,7 @@ def __init__(self, episode_player: EpisodePlayer, """ self.episode_player: EpisodePlayer = episode_player self.detectors_to_start: List[Type[EventDetector]] = detectors_to_start - self.logger = EventLogger(annotate_events, [PickUpEvent, RotationEvent, StopRotationEvent]) + self.logger = EventLogger(annotate_events, [PickUpEvent, PlacingEvent]) self.objects_to_avoid = ['particle', 'floor', 'kitchen'] # TODO: Make it a function, to be more general self.tracked_objects: List[Object] = [] self.tracked_object_contacts: Dict[Object, List[Type[AbstractContactDetector]]] = {} @@ -103,8 +104,7 @@ def start_triggered_detectors(self, event: EventUnion) -> None: """ for event_detector in self.detectors_to_start: if event_detector.start_condition_checker(event): - filtered_event = event_detector.filter_event(event) - self.start_detector_thread_for_starter_event(filtered_event, event_detector) + self.start_detector_thread_for_starter_event(event, event_detector) @abstractmethod def _process_event(self, event: Event) -> None: @@ -302,7 +302,7 @@ class NoAgentEpisodeSegmenter(EpisodeSegmenter): def __init__(self, episode_player: EpisodePlayer, detectors_to_start: Optional[List[Type[EventDetector]]] = None, annotate_events: bool = False): if detectors_to_start is None: - detectors_to_start = [MotionPickUpDetector] + detectors_to_start = [MotionPickUpDetector, PlacingDetector] super().__init__(episode_player, detectors_to_start=detectors_to_start, annotate_events=annotate_events) def start_tracking_threads_for_new_object_and_event(self, new_object: Object, event: EventUnion): @@ -333,27 +333,17 @@ def detect_missing_support_for_object(self, obj: Object) -> None: :param obj: The object to check if it is supported. """ - supported = True support_name = f"imagined_support" support_obj = World.current_world.get_object_by_name(support_name) support_thickness = 0.005 - with UseProspectionWorld(): - prospection_obj = World.current_world.get_prospection_object_for_object(obj) - current_position = prospection_obj.get_position_as_list() - World.current_world.simulate(1) - new_position = prospection_obj.get_position_as_list() - if current_position[2] - new_position[2] >= 0.01: - logdebug(f"Object {obj.name} is not supported") - supported = False + supported = check_if_object_is_supported(obj) + if supported: + return obj_base_position = obj.get_base_position_as_list() - if (not supported) and (support_obj is None): - support = GenericObjectDescription(support_name, [0, 0, 0], [1, 1, support_thickness]) - support_obj = Object(support_name, ObjectType.IMAGINED_SURFACE, None, support) - support_position = obj_base_position.copy() - support_position[2] = obj_base_position[2] - support_thickness * 0.5 - support_obj.set_position(support_position) + if support_obj is None: + support_obj = add_imaginary_support_for_object(obj, support_name, support_thickness) self.start_contact_threads_for_object(support_obj) - elif (not supported) and (support_obj is not None): + else: support_position = support_obj.get_position_as_list() if obj_base_position[2] <= support_position[2]: support_position[2] = obj_base_position[2] - support_thickness * 0.5 diff --git a/src/episode_segmenter/event_detectors.py b/src/episode_segmenter/event_detectors.py index 93cf926..b55bdf6 100644 --- a/src/episode_segmenter/event_detectors.py +++ b/src/episode_segmenter/event_detectors.py @@ -1,5 +1,4 @@ import threading -import threading import time from abc import ABC, abstractmethod from datetime import timedelta @@ -16,13 +15,13 @@ from pycram.datastructures.enums import ObjectType from pycram.datastructures.pose import Pose from pycram.ros.logging import logdebug -from pycram.world_concepts.world_object import Object, Link +from pycram.world_concepts.world_object import Object from .event_logger import EventLogger from .events import Event, ContactEvent, LossOfContactEvent, PickUpEvent, AgentContactEvent, \ - AgentLossOfContactEvent, EventUnion, LossOfSurfaceEvent, MotionEvent, StopMotionEvent, NewObjectEvent, \ - RotationEvent, StopRotationEvent -from .utils import get_angle_between_vectors, calculate_euclidean_distance, calculate_angle_between_quaternions, \ - calculate_quaternion_difference + AgentLossOfContactEvent, EventUnion, LossOfSurfaceEvent, TranslationEvent, StopTranslationEvent, NewObjectEvent, \ + RotationEvent, StopRotationEvent, PlacingEvent, MotionEvent, StopMotionEvent +from .utils import get_angle_between_vectors, calculate_euclidean_distance, calculate_quaternion_difference, \ + check_if_object_is_supported class PrimitiveEventDetector(threading.Thread, ABC): @@ -358,7 +357,7 @@ def trigger_events(self, contact_points: ContactPointsList) -> List[LossOfSurfac class MotionDetector(PrimitiveEventDetector, ABC): """ - A thread that detects if the object starts or stops moving and logs the MotionEvent or StopMotionEvent. + A thread that detects if the object starts or stops moving and logs the TranslationEvent or StopTranslationEvent. """ thread_prefix = "motion_" @@ -405,11 +404,11 @@ def get_current_pose_and_time(self) -> Tuple[Pose, float]: """ return self.tracked_object.pose, time.time() - def detect_events(self) -> List[Union[MotionEvent, StopMotionEvent]]: + def detect_events(self) -> List[Union[TranslationEvent, StopTranslationEvent]]: """ Detect if the object starts or stops moving. - :return: An instance of the MotionEvent class that represents the event if the object is moving, else None. + :return: An instance of the TranslationEvent class that represents the event if the object is moving, else None. """ if self.time_since_last_event < self.measure_timestep.total_seconds(): time.sleep(self.measure_timestep.total_seconds() - self.time_since_last_event) @@ -432,11 +431,11 @@ def is_moving(self) -> bool: distance = self.calculate_distance(self.tracked_object.pose) return distance > self.distance_threshold - def create_event(self) -> Union[MotionEvent, StopMotionEvent]: + def create_event(self) -> Union[TranslationEvent, StopTranslationEvent]: """ Create a motion event. - :return: An instance of the MotionEvent class that represents the event. + :return: An instance of the TranslationEvent class that represents the event. """ current_pose, current_time = self.get_current_pose_and_time() event_type = self.get_event_type() @@ -464,7 +463,7 @@ def calculate_distance(self, current_pose: Pose): return calculate_euclidean_distance(self.latest_pose.position_as_list(), current_pose.position_as_list()) def get_event_type(self): - return MotionEvent if self.was_moving else StopMotionEvent + return TranslationEvent if self.was_moving else StopTranslationEvent class RotationDetector(MotionDetector): @@ -515,28 +514,78 @@ def start_condition_checker(cls, event: Event) -> bool: """ pass + def check_for_event_post_starter_event(self, event_detector: Type[PrimitiveEventDetector]) -> Optional[EventUnion]: + """ + Check if the tracked_object was involved in an event after the starter event. + + :param event_detector: The event detector class that is used to detect the event. + :return: The event if the tracked_object was involved in an event, else None. + """ + event = get_latest_event_of_detector_for_object(event_detector, self.tracked_object, + after_timestamp=self.start_timestamp) + if event is None: + logdebug(f"{event_detector.__name__} found no event after {self.start_timestamp} with object :" + f" {self.tracked_object.name}") + return None + + return event + + def check_for_event_near_starter_event(self, event_detector: Type[PrimitiveEventDetector], + time_tolerance: timedelta) -> Optional[EventUnion]: + """ + Check if the tracked_object was involved in an event near the starter event (i.e. could be before or after). + + :param event_detector: The event detector class that is used to detect the event. + :param time_tolerance: The time tolerance to consider the event as near the starter event. + :return: The event if the tracked_object was involved in an event, else None. + """ + event = get_nearest_event_of_detector_for_object(event_detector, self.tracked_object, + timestamp=self.start_timestamp, time_tolerance=time_tolerance) + if event is None: + logdebug(f"{event_detector.__name__} found no event after {self.start_timestamp} with object :" + f" {self.tracked_object.name}") + return None + + return event + @property def start_timestamp(self) -> float: return self.starter_event.timestamp + +class MotionPlacingDetector(EventDetector, ABC): + """ + A detector that detects if the tracked_object was placed on a surface by using motion and contact. + """ + + thread_prefix = "placing_" + + def __init__(self, logger: EventLogger, starter_event: EventUnion, *args, **kwargs): + """ + :param logger: An instance of the EventLogger class that is used to log the events. + :param starter_event: An instance of a type of Event that represents the event to + start the event detector. + """ + super().__init__(logger, starter_event, *args, **kwargs) + self.tracked_object = self.get_object_to_place_from_event(starter_event) + self.placing_event: Optional[PlacingEvent] = None + self.run_once = True + @classmethod @abstractmethod - def filter_event(cls, event: EventUnion) -> EventUnion: + def get_object_to_place_from_event(cls, event: Event) -> Object: """ - Filter the event before logging/using it. - - :param event: An object that represents the event. - :return: An object that represents the filtered event. + Get the tracked_object to place from the event. """ pass -class AbstractPickUpDetector(EventDetector, ABC): +class AbstractAgentObjectInteractionDetector(EventDetector, ABC): """ - An abstract detector that detects if the tracked_object was picked up. + An abstract detector that detects an interaction between the agent and an object. """ - thread_prefix = "pick_up_" + thread_prefix = "agent_object_interaction_" """ A string that is used as a prefix for the thread ID. """ @@ -548,105 +597,71 @@ def __init__(self, logger: EventLogger, starter_event: EventUnion, *args, **kwar start the event detector. """ super().__init__(logger, starter_event, *args, **kwargs) - self.tracked_object = self.get_object_to_pick_from_event(starter_event) - self.objects_to_track = [self.tracked_object] - self.pick_up_event: Optional[PickUpEvent] = None + self.tracked_object = self.get_object_to_track_from_starter_event(starter_event) + self.interaction_event: EventUnion = self._init_interaction_event() + self.end_timestamp: Optional[float] = None self.run_once = True - self.break_loop = False - def detect_events(self) -> List[PickUpEvent]: + @abstractmethod + def _init_interaction_event(self) -> EventUnion: + """ + Initialize the interaction event. """ - Detect if the tracked_object was picked up by the hand. - Used Features are: - 1. The hand should still be in contact with the tracked_object. - 2. While the tracked_object that is picked should lose contact with the surface. - Other features that can be used: Grasping Type, Object Type, and Object Motion. + pass - :return: An instance of the PickUpEvent class that represents the event if the tracked_object was picked up, - else None. + def detect_events(self) -> List[EventUnion]: """ + Detect if the tracked_object was interacted with by the agent. - self.pick_up_event = PickUpEvent(self.tracked_object, timestamp=self.start_timestamp) + :return: An instance of the interaction event if the tracked_object was interacted with, else None. + """ while not self.kill_event.is_set(): - if not self.extra_checks(): - if self.break_loop: - break - else: - time.sleep(0.01) - continue + if not self.interaction_checks(): + time.sleep(0.01) + continue - self.pick_up_event.end_timestamp = self.get_end_timestamp() + self.interaction_event.end_timestamp = self.end_timestamp break - rospy.loginfo(f"Object picked up: {self.tracked_object.name}") - - self.fill_pick_up_event() + rospy.loginfo(f"{self.__class__.__name__} detected an interaction with: {self.tracked_object.name}") - return [self.pick_up_event] + return [self.interaction_event] @abstractmethod - def get_end_timestamp(self) -> float: + def interaction_checks(self) -> bool: """ - Get the end timestamp of the pickup event. - """ - pass + Perform checks to determine if the object was interacted with. - @abstractmethod - def fill_pick_up_event(self): - """ - Fill the pickup event with the necessary information. + :return: A boolean value that represents if all the checks passed and the object was interacted with. """ pass + @classmethod @abstractmethod - def extra_checks(self) -> bool: + def get_object_to_track_from_starter_event(cls, starter_event: EventUnion) -> Object: """ - Perform extra checks to determine if the object was picked up. + Get the object to track for interaction from the possible starter event. - :return: A boolean value that represents if all the checks passed and the object was picked up. + :param starter_event: The possible starter event that can be used to get the object to track. """ pass - def check_object_lost_contact_with_surface(self) -> Union[Tuple[Optional[LossOfSurfaceEvent], - Optional[List[Object]]]]: - """ - Check if the tracked_object lost contact with the surface. - - :return: A list of Object instances that represent the objects that lost contact with the tracked_object. - """ - loss_of_surface_event = get_latest_event_of_detector_for_object(LossOfSurfaceDetector, - self.tracked_object, - after_timestamp=self.start_timestamp - ) - if loss_of_surface_event is None: - logdebug(f"continue, tracked_object: {self.tracked_object.name}") - return None, None - - objects_that_lost_contact = loss_of_surface_event.latest_objects_that_got_removed - return loss_of_surface_event, objects_that_lost_contact - - @classmethod - @abstractmethod - def get_object_to_pick_from_event(cls, event: Event) -> Object: - """ - Get the tracked_object to pick up from the event. - """ - pass +class AbstractPickUpDetector(AbstractAgentObjectInteractionDetector, ABC): + """ + An abstract detector that detects if the tracked_object was picked up. + """ - @staticmethod - def select_pickable_objects(objects: List[Object]) -> List[Object]: - """ - Select the objects that can be picked up. + thread_prefix = "pick_up_" + """ + A string that is used as a prefix for the thread ID. + """ - :param objects: A list of Object instances. - """ - return [obj for obj in objects - if obj.obj_type not in [ObjectType.HUMAN, ObjectType.ROBOT, ObjectType.ENVIRONMENT, - ObjectType.IMAGINED_SURFACE]] + def _init_interaction_event(self) -> EventUnion: + return PickUpEvent(self.tracked_object, timestamp=self.start_timestamp) class AgentPickUpDetector(AbstractPickUpDetector): @@ -664,46 +679,11 @@ def __init__(self, logger: EventLogger, starter_event: AgentContactEvent, *args, self.surface_detector = LossOfSurfaceDetector(logger, self.starter_event) self.surface_detector.start() self.agent = starter_event.agent - self.agent_link = starter_event.agent_link - self.object_link = self.get_object_link_from_event(starter_event) - self.end_timestamp: Optional[float] = None - - def get_end_timestamp(self) -> float: - return self.end_timestamp - - def fill_pick_up_event(self): - self.pick_up_event.agent = self.agent + self.interaction_event.agent = self.agent @classmethod - def filter_event(cls, event: AgentContactEvent) -> Event: - """ - Filter the event by removing objects that are not in the list of objects to track. - - :param event: An object that represents the event. - :return: An object that represents the filtered event. - """ - event.with_object = cls.get_object_to_pick_from_event(event) - return event - - @classmethod - def get_object_to_pick_from_event(cls, event: AgentContactEvent) -> Object: - """ - Get the tracked_object link from the event. - - :param event: The AgentContactEvent instance that represents the contact event. - """ - return cls.get_object_link_from_event(event).object - - @classmethod - def get_object_link_from_event(cls, event: AgentContactEvent) -> Link: - """ - Get the tracked_object link from the event. - - :param event: The AgentContactEvent instance that represents the contact event. - """ - pickable_objects = cls.find_pickable_objects_from_contact_event(event) - links_in_contact = event.links - return [link for link in links_in_contact if link.object in pickable_objects][0] + def get_object_to_track_from_starter_event(cls, event: AgentContactEvent) -> Object: + return select_transportable_objects_from_contact_event(event)[0] @classmethod def start_condition_checker(cls, event: Event) -> bool: @@ -712,41 +692,26 @@ def start_condition_checker(cls, event: Event) -> bool: :param event: The ContactEvent instance that represents the contact event. """ - return isinstance(event, AgentContactEvent) and any(cls.find_pickable_objects_from_contact_event(event)) - - @classmethod - def find_pickable_objects_from_contact_event(cls, event: AgentContactEvent) -> List[Object]: - """ - Find the pickable objects from the contact event. - - :param event: The AgentContactEvent instance that represents the contact event. - """ - contacted_objects = event.contact_points.get_objects_that_have_points() - return cls.select_pickable_objects(contacted_objects) + return isinstance(event, AgentContactEvent) and any(select_transportable_objects_from_contact_event(event)) - def extra_checks(self) -> bool: + def interaction_checks(self) -> bool: """ Perform extra checks to determine if the object was picked up. """ - loss_of_surface_event, objects_that_lost_contact = self.check_object_lost_contact_with_surface() + loss_of_surface_event = self.check_for_event_post_starter_event(LossOfSurfaceDetector) - if objects_that_lost_contact is None: - time.sleep(0.01) + if not loss_of_surface_event: return False - if self.agent in objects_that_lost_contact: + if self.agent in loss_of_surface_event.latest_objects_that_got_removed: rospy.logdebug(f"Agent lost contact with tracked_object: {self.tracked_object.name}") - self.break_loop = True + self.kill_event.set() return False self.end_timestamp = loss_of_surface_event.timestamp return True - @property - def start_contact_points(self) -> ContactPointsList: - return self.starter_event.contact_points - def stop(self, timeout: Optional[float] = None): self.surface_detector.stop() self.surface_detector.join(timeout) @@ -762,21 +727,10 @@ def __init__(self, logger: EventLogger, starter_event: LossOfContactEvent, *args detector. """ super().__init__(logger, starter_event, *args, **kwargs) - self.end_timestamp: Optional[float] = None - - def get_end_timestamp(self) -> float: - return self.end_timestamp @classmethod - def get_object_to_pick_from_event(cls, event: LossOfSurfaceEvent) -> Object: - return cls.find_pickable_objects_from_contact_event(event)[0] - - def fill_pick_up_event(self): - pass - - @classmethod - def filter_event(cls, event: LossOfContactEvent) -> Event: - return event + def get_object_to_track_from_starter_event(cls, event: LossOfContactEvent) -> Object: + return select_transportable_objects_from_loss_of_contact_event(event)[0] @classmethod def start_condition_checker(cls, event: Event) -> bool: @@ -785,25 +739,16 @@ def start_condition_checker(cls, event: Event) -> bool: :param event: The ContactEvent instance that represents the contact event. """ - return isinstance(event, LossOfContactEvent) and any(cls.find_pickable_objects_from_contact_event(event)) + return (isinstance(event, LossOfContactEvent) + and any(select_transportable_objects_from_loss_of_contact_event(event))) - @classmethod - def find_pickable_objects_from_contact_event(cls, event: LossOfContactEvent) -> List[Object]: - """ - Find the pickable objects from the contact event. - - :param event: The AgentContactEvent instance that represents the contact event. - """ - return cls.select_pickable_objects(event.latest_objects_that_got_removed + [event.tracked_object]) - - def extra_checks(self) -> bool: + def interaction_checks(self) -> bool: """ Check for upward motion after the object lost contact with the surface. """ - latest_event = get_nearest_event_of_detector_for_object(TranslationDetector, self.tracked_object, - self.start_timestamp, timedelta(milliseconds=1000)) + latest_event = self.check_for_event_near_starter_event(TranslationDetector, timedelta(milliseconds=1000)) - if latest_event is None: + if not latest_event: return False z_motion = latest_event.current_pose.position.z - latest_event.start_pose.position.z @@ -814,6 +759,50 @@ def extra_checks(self) -> bool: return True +class PlacingDetector(AbstractAgentObjectInteractionDetector): + """ + An abstract detector that detects if the tracked_object was placed by the agent. + """ + + thread_prefix = "placing_" + + def _init_interaction_event(self) -> EventUnion: + return PlacingEvent(self.tracked_object, timestamp=self.start_timestamp) + + def interaction_checks(self) -> bool: + return self.initial_interaction_checkers() + + @classmethod + def get_object_to_track_from_starter_event(cls, starter_event: MotionEvent) -> Object: + return starter_event.tracked_object + + @classmethod + def start_condition_checker(cls, event: Event) -> bool: + """ + Check if an agent is in contact with the tracked_object. + + :param event: The ContactEvent instance that represents the contact event. + """ + if isinstance(event, MotionEvent) and any(select_transportable_objects([event.tracked_object])): + if not check_if_object_is_supported(event.tracked_object): + print('new placing detector') + return True + return False + + def initial_interaction_checkers(self) -> bool: + """ + Perform initial checks to determine if the object was placed. + """ + contact_event = self.check_for_event_post_starter_event(ContactDetector) + print(f"contact_event: {contact_event}") + if contact_event and check_if_object_is_supported(self.tracked_object): + self.end_timestamp = contact_event.timestamp + print(f"end_timestamp: {self.end_timestamp}") + return True + + return False + + def check_for_supporting_surface(objects_that_lost_contact: List[Object], initial_contact_points: ContactPointsList) -> Optional[Object]: """ @@ -878,6 +867,37 @@ def get_nearest_event_of_detector_for_object(detector_type: Type[PrimitiveEventD return event +def select_transportable_objects_from_contact_event(event: Union[ContactEvent, AgentContactEvent]) -> List[Object]: + """ + Select the objects that can be transported from the contact event. + + :param event: The contact event + """ + contacted_objects = event.contact_points.get_objects_that_have_points() + return select_transportable_objects(contacted_objects + [event.tracked_object]) + + +def select_transportable_objects_from_loss_of_contact_event(event: Union[LossOfContactEvent, + AgentLossOfContactEvent, + LossOfSurfaceEvent]) -> List[Object]: + """ + Select the objects that can be transported from the loss of contact event. + """ + objects_that_lost_contact = event.latest_objects_that_got_removed + return select_transportable_objects(objects_that_lost_contact + [event.tracked_object]) + + +def select_transportable_objects(objects: List[Object]) -> List[Object]: + """ + Select the objects that can be transported + + :param objects: A list of Object instances. + """ + return [obj for obj in objects + if obj.obj_type not in [ObjectType.HUMAN, ObjectType.ROBOT, ObjectType.ENVIRONMENT, + ObjectType.IMAGINED_SURFACE]] + + EventDetectorUnion = Union[ContactDetector, LossOfContactDetector, LossOfSurfaceDetector, MotionDetector, TranslationDetector, RotationDetector, NewObjectDetector, AgentPickUpDetector, MotionPickUpDetector, EventDetector] TypeEventDetectorUnion = Union[Type[ContactDetector], Type[LossOfContactDetector], Type[LossOfSurfaceDetector], diff --git a/src/episode_segmenter/events.py b/src/episode_segmenter/events.py index 65ee7f0..a6d50a0 100644 --- a/src/episode_segmenter/events.py +++ b/src/episode_segmenter/events.py @@ -96,7 +96,7 @@ def __str__(self): return f"{self.__class__.__name__}: {self.tracked_object.name}" -class MotionEvent(Event): +class MotionEvent(Event, ABC): """ The MotionEvent class is used to represent an event that involves an object that was stationary and then moved or vice versa. @@ -117,24 +117,20 @@ def __eq__(self, other): and self.timestamp == other.timestamp) def __hash__(self): - return hash((self.tracked_object, self.start_pose, self.timestamp)) + return hash((self.tracked_object, self.timestamp)) def set_color(self, color: Optional[Color] = None): color = color if color is not None else self.color self.tracked_object.set_color(color) - @property - def color(self) -> Color: - return Color(0, 1, 1, 1) - def __str__(self): return f"{self.__class__.__name__}: {self.tracked_object.name} - {self.timestamp}" -class StopMotionEvent(MotionEvent): +class TranslationEvent(MotionEvent): @property def color(self) -> Color: - return Color(1, 1, 1, 1) + return Color(0, 1, 1, 1) class RotationEvent(MotionEvent): @@ -143,6 +139,16 @@ def color(self) -> Color: return Color(1, 1, 0, 1) +class StopMotionEvent(MotionEvent): + @property + def color(self) -> Color: + return Color(1, 1, 1, 1) + + +class StopTranslationEvent(StopMotionEvent): + ... + + class StopRotationEvent(StopMotionEvent): ... @@ -308,24 +314,24 @@ def __init__(self, contact_points: ContactPointsList, self.surface: Optional[Object] = surface -class PickUpEvent(Event): +class AbstractAgentObjectInteractionEvent(Event, ABC): - def __init__(self, picked_object: Object, + def __init__(self, participating_object: Object, agent: Optional[Object] = None, timestamp: Optional[float] = None): super().__init__(timestamp) self.agent: Optional[Object] = agent - self.picked_object: Object = picked_object + self.participating_object: Object = participating_object self.end_timestamp: Optional[float] = None self.text_id: Optional[int] = None def __eq__(self, other): if not isinstance(other, self.__class__): return False - return self.agent == other.agent and self.picked_object == other.picked_object + return self.agent == other.agent and self.participating_object == other.participating_object def __hash__(self): - return hash((self.agent, self.picked_object, self.__class__)) + return hash((self.agent, self.participating_object, self.__class__)) def record_end_timestamp(self): self.end_timestamp = time.time() @@ -339,20 +345,30 @@ def set_color(self, color: Optional[Color] = None): color = color if color is not None else self.color if self.agent is not None: self.agent.set_color(color) - self.picked_object.set_color(color) - - @property - def color(self) -> Color: - return Color(0, 1, 0, 1) + self.participating_object.set_color(color) def __str__(self): - return f"Pick up event: Object: {self.picked_object.name}, Timestamp: {self.timestamp}" + \ + return f"{self.__class__.__name__}: Object: {self.participating_object.name}, Timestamp: {self.timestamp}" + \ (f", Agent: {self.agent.name}" if self.agent is not None else "") def __repr__(self): return self.__str__() +class PickUpEvent(AbstractAgentObjectInteractionEvent): + + @property + def color(self) -> Color: + return Color(0, 1, 0, 1) + + +class PlacingEvent(AbstractAgentObjectInteractionEvent): + + @property + def color(self) -> Color: + return Color(1, 0, 1, 1) + + # Create a type that is the union of all event types EventUnion = Union[NewObjectEvent, MotionEvent, @@ -362,4 +378,5 @@ def __repr__(self): AgentContactEvent, AgentLossOfContactEvent, LossOfSurfaceEvent, - PickUpEvent] + PickUpEvent, + PlacingEvent] diff --git a/src/episode_segmenter/utils.py b/src/episode_segmenter/utils.py index eb4aef7..ebb6b21 100644 --- a/src/episode_segmenter/utils.py +++ b/src/episode_segmenter/utils.py @@ -1,8 +1,52 @@ import numpy as np -from tf.transformations import quaternion_inverse, quaternion_multiply, euler_from_quaternion -from typing_extensions import List +from tf.transformations import quaternion_inverse, quaternion_multiply +from typing_extensions import List, Optional from pycram.datastructures.pose import Transform +from pycram.datastructures.world import World, UseProspectionWorld +from pycram.datastructures.enums import ObjectType +from pycram.world_concepts.world_object import Object +from pycram.ros.logging import logdebug +from pycram.object_descriptors.generic import ObjectDescription as GenericObjectDescription + + +def check_if_object_is_supported(obj: Object) -> bool: + """ + Check if the object is supported by any other object. + + :param obj: The object to check if it is supported. + :return: True if the object is supported, False otherwise. + """ + supported = True + with UseProspectionWorld(): + prospection_obj = World.current_world.get_prospection_object_for_object(obj) + current_position = prospection_obj.get_position_as_list() + World.current_world.simulate(1) + new_position = prospection_obj.get_position_as_list() + if current_position[2] - new_position[2] >= 0.2: + logdebug(f"Object {obj.name} is not supported") + supported = False + return supported + + +def add_imaginary_support_for_object(obj: Object, + support_name: Optional[str] = f"imagined_support", + support_thickness: Optional[float] = 0.005) -> Object: + """ + Add an imaginary support for the object. + + :param obj: The object for which the support should be added. + :param support_name: The name of the support object. + :param support_thickness: The thickness of the support. + :return: The support object. + """ + obj_base_position = obj.get_base_position_as_list() + support = GenericObjectDescription(support_name, [0, 0, 0], [1, 1, obj_base_position[2]*0.5]) + support_obj = Object(support_name, ObjectType.IMAGINED_SURFACE, None, support) + support_position = obj_base_position.copy() + support_position[2] = obj_base_position[2] * 0.5 + support_obj.set_position(support_position) + return support_obj def get_angle_between_vectors(vector_1: List[float], vector_2: List[float]) -> float: diff --git a/test/test_neem_segmenter.py b/test/test_neem_segmenter.py index 31d8d19..145b002 100644 --- a/test/test_neem_segmenter.py +++ b/test/test_neem_segmenter.py @@ -3,7 +3,7 @@ from neem_pycram_interface import PyCRAMNEEMInterface -from episode_segmenter.event_detectors import AgentPickUpDetector +from episode_segmenter.event_detectors import AgentPickUpDetector, PlacingDetector from episode_segmenter.neem_segmenter import NEEMSegmenter from unittest import TestCase @@ -19,7 +19,7 @@ class TestNEEMSegmentor(TestCase): @classmethod def setUpClass(cls): - BulletWorld(WorldMode.DIRECT) + BulletWorld(WorldMode.GUI) pni = PyCRAMNEEMInterface(f'mysql+pymysql://{os.environ["my_maria_uri"]}') cls.ns = NEEMSegmenter(pni, detectors_to_start=[AgentPickUpDetector], annotate_events=True) cls.viz_mark_publisher = VizMarkerPublisher()