Skip to content

Commit

Permalink
[EpisodeSegmenter] Placing in progress, need to change the design to …
Browse files Browse the repository at this point in the history
…focus on object states not events.
  • Loading branch information
AbdelrhmanBassiouny committed Oct 24, 2024
1 parent 719ce56 commit 80dab7c
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 221 deletions.
34 changes: 12 additions & 22 deletions src/episode_segmenter/episode_segmenter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime

from .event_detectors import EventDetectorUnion, TypeEventDetectorUnion, MotionPickUpDetector, TranslationDetector, \
RotationDetector
RotationDetector, PlacingDetector
import time
from abc import ABC, abstractmethod

Expand All @@ -18,7 +18,8 @@
AbstractContactDetector
from .event_logger import EventLogger
from .events import ContactEvent, Event, AgentContactEvent, PickUpEvent, EventUnion, StopMotionEvent, MotionEvent, \
NewObjectEvent, RotationEvent, StopRotationEvent
NewObjectEvent, RotationEvent, StopRotationEvent, PlacingEvent
from .utils import check_if_object_is_supported, add_imaginary_support_for_object


class EpisodeSegmenter(ABC):
Expand All @@ -35,7 +36,7 @@ def __init__(self, episode_player: EpisodePlayer,
"""
self.episode_player: EpisodePlayer = episode_player
self.detectors_to_start: List[Type[EventDetector]] = detectors_to_start
self.logger = EventLogger(annotate_events, [PickUpEvent, RotationEvent, StopRotationEvent])
self.logger = EventLogger(annotate_events, [PickUpEvent, PlacingEvent])
self.objects_to_avoid = ['particle', 'floor', 'kitchen'] # TODO: Make it a function, to be more general
self.tracked_objects: List[Object] = []
self.tracked_object_contacts: Dict[Object, List[Type[AbstractContactDetector]]] = {}
Expand Down Expand Up @@ -103,8 +104,7 @@ def start_triggered_detectors(self, event: EventUnion) -> None:
"""
for event_detector in self.detectors_to_start:
if event_detector.start_condition_checker(event):
filtered_event = event_detector.filter_event(event)
self.start_detector_thread_for_starter_event(filtered_event, event_detector)
self.start_detector_thread_for_starter_event(event, event_detector)

@abstractmethod
def _process_event(self, event: Event) -> None:
Expand Down Expand Up @@ -302,7 +302,7 @@ class NoAgentEpisodeSegmenter(EpisodeSegmenter):
def __init__(self, episode_player: EpisodePlayer, detectors_to_start: Optional[List[Type[EventDetector]]] = None,
annotate_events: bool = False):
if detectors_to_start is None:
detectors_to_start = [MotionPickUpDetector]
detectors_to_start = [MotionPickUpDetector, PlacingDetector]
super().__init__(episode_player, detectors_to_start=detectors_to_start, annotate_events=annotate_events)

def start_tracking_threads_for_new_object_and_event(self, new_object: Object, event: EventUnion):
Expand Down Expand Up @@ -333,27 +333,17 @@ def detect_missing_support_for_object(self, obj: Object) -> None:
:param obj: The object to check if it is supported.
"""
supported = True
support_name = f"imagined_support"
support_obj = World.current_world.get_object_by_name(support_name)
support_thickness = 0.005
with UseProspectionWorld():
prospection_obj = World.current_world.get_prospection_object_for_object(obj)
current_position = prospection_obj.get_position_as_list()
World.current_world.simulate(1)
new_position = prospection_obj.get_position_as_list()
if current_position[2] - new_position[2] >= 0.01:
logdebug(f"Object {obj.name} is not supported")
supported = False
supported = check_if_object_is_supported(obj)
if supported:
return
obj_base_position = obj.get_base_position_as_list()
if (not supported) and (support_obj is None):
support = GenericObjectDescription(support_name, [0, 0, 0], [1, 1, support_thickness])
support_obj = Object(support_name, ObjectType.IMAGINED_SURFACE, None, support)
support_position = obj_base_position.copy()
support_position[2] = obj_base_position[2] - support_thickness * 0.5
support_obj.set_position(support_position)
if support_obj is None:
support_obj = add_imaginary_support_for_object(obj, support_name, support_thickness)
self.start_contact_threads_for_object(support_obj)
elif (not supported) and (support_obj is not None):
else:
support_position = support_obj.get_position_as_list()
if obj_base_position[2] <= support_position[2]:
support_position[2] = obj_base_position[2] - support_thickness * 0.5
Expand Down
Loading

0 comments on commit 80dab7c

Please sign in to comment.