Skip to content

Commit

Permalink
[EpisodeSegmenter] adjusting frequency of detector, logger, and player.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Oct 25, 2024
1 parent 80dab7c commit 83926e1
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/episode_segmenter/episode_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run_event_detectors(self) -> None:

for detector_thread in self.detector_threads_list:
detector_thread.stop()
print(f"Joining {detector_thread.thread_id}, {detector_thread.name}")
logdebug(f"Joining {detector_thread.thread_id}, {detector_thread.name}")
detector_thread.join()
closed_threads = True

Expand Down Expand Up @@ -256,7 +256,7 @@ def join(self):
"""
self.logger.print_events()
self.logger.join()
print("All threads joined.")
logdebug("All threads joined.")


class AgentBasedEpisodeSegmenter(EpisodeSegmenter):
Expand Down
20 changes: 16 additions & 4 deletions src/episode_segmenter/event_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def join(self, timeout=None):

class AbstractContactDetector(PrimitiveEventDetector, ABC):
def __init__(self, logger: EventLogger, starter_event: EventUnion, with_object: Optional[Object] = None,
max_closeness_distance: Optional[float] = 0.05, wait_time: Optional[float] = 0.1,
max_closeness_distance: Optional[float] = 0.05, wait_time: Optional[float] = 0.01,
*args, **kwargs):
"""
:param logger: An instance of the EventLogger class that is used to log the events.
Expand Down Expand Up @@ -590,6 +590,11 @@ class AbstractAgentObjectInteractionDetector(EventDetector, ABC):
A string that is used as a prefix for the thread ID.
"""

currently_tracked_objects: List[Object] = []
"""
A list of Object instances that represent the objects that are currently being tracked.
"""

def __init__(self, logger: EventLogger, starter_event: EventUnion, *args, **kwargs):
"""
:param logger: An instance of the EventLogger class that is used to log the events.
Expand All @@ -598,6 +603,7 @@ def __init__(self, logger: EventLogger, starter_event: EventUnion, *args, **kwar
"""
super().__init__(logger, starter_event, *args, **kwargs)
self.tracked_object = self.get_object_to_track_from_starter_event(starter_event)
self.currently_tracked_objects.append(self.tracked_object)
self.interaction_event: EventUnion = self._init_interaction_event()
self.end_timestamp: Optional[float] = None
self.run_once = True
Expand Down Expand Up @@ -676,14 +682,20 @@ def __init__(self, logger: EventLogger, starter_event: AgentContactEvent, *args,
event detector, this is a contact between the agent and the tracked_object.
"""
super().__init__(logger, starter_event, *args, **kwargs)
self.surface_detector = LossOfSurfaceDetector(logger, self.starter_event)
self.surface_detector = LossOfSurfaceDetector(logger, NewObjectEvent(self.tracked_object))
self.surface_detector.start()
self.agent = starter_event.agent
self.interaction_event.agent = self.agent

@classmethod
def get_object_to_track_from_starter_event(cls, event: AgentContactEvent) -> Object:
return select_transportable_objects_from_contact_event(event)[0]
return cls.get_new_transportable_objects(event)[0]

@classmethod
def get_new_transportable_objects(cls, event: AgentContactEvent) -> List[Object]:
transportable_objects = select_transportable_objects_from_contact_event(event)
new_transportable_objects = [obj for obj in transportable_objects if obj not in cls.currently_tracked_objects]
return new_transportable_objects

@classmethod
def start_condition_checker(cls, event: Event) -> bool:
Expand All @@ -692,7 +704,7 @@ def start_condition_checker(cls, event: Event) -> bool:
:param event: The ContactEvent instance that represents the contact event.
"""
return isinstance(event, AgentContactEvent) and any(select_transportable_objects_from_contact_event(event))
return isinstance(event, AgentContactEvent) and any(cls.get_new_transportable_objects(event))

def interaction_checks(self) -> bool:
"""
Expand Down
6 changes: 3 additions & 3 deletions src/episode_segmenter/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def get_nearest_event_of_thread(self, thread_id: str, timestamp: float) -> Optio
if thread_id not in self.timeline_per_thread:
return None
all_event_timestamps = [(event, event.timestamp) for event in self.timeline_per_thread[thread_id]]
print(all_event_timestamps)
return min(all_event_timestamps, key=lambda x: abs(x[1] - timestamp))[0]

def get_latest_event_of_thread(self, thread_id: str) -> Optional[Event]:
Expand Down Expand Up @@ -174,8 +173,9 @@ def get_next_z_offset(self):
def run(self):
while not self.kill_event.is_set():
try:
event = self.logger.annotation_queue.get(timeout=1)
event = self.logger.annotation_queue.get(block=False)
except queue.Empty:
time.sleep(0.001)
continue
self.logger.annotation_queue.task_done()
if len(self.current_annotations) >= self.max_annotations:
Expand All @@ -192,7 +192,7 @@ def run(self):
z_offset = self.get_next_z_offset()
text_ann = event.annotate([1.5, 1, z_offset])
self.current_annotations.append(text_ann)
time.sleep(0.1)
time.sleep(0.001)

def stop(self):
self.kill_event.set()
2 changes: 1 addition & 1 deletion src/episode_segmenter/neem_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ready(self):

def run(self):
self.pni.replay_motions_in_query(real_time=False,
step_time=datetime.timedelta(milliseconds=1))
step_time=datetime.timedelta(milliseconds=0))


class NEEMSegmenter(AgentBasedEpisodeSegmenter):
Expand Down

0 comments on commit 83926e1

Please sign in to comment.