From 67cbec96f09b5609d446a6196cbec24483680f95 Mon Sep 17 00:00:00 2001 From: tlam Date: Thu, 16 Jan 2025 18:15:07 +0100 Subject: [PATCH] rename ZStabCamera to ZStabilizedCamera and rename targeted_flies_id to targeted_fly_names and make it a list of str --- flygym/__init__.py | 2 +- flygym/camera.py | 19 +++++------ .../closed_loop_deployment.py | 8 ++--- .../collect_training_data.py | 2 +- .../locomotion/controller_comparison.py | 2 +- flygym/examples/locomotion/cpg_controller.py | 2 +- .../examples/locomotion/hybrid_controller.py | 2 +- .../locomotion/rule_based_controller.py | 2 +- .../examples/locomotion/turning_controller.py | 2 +- flygym/examples/locomotion/turning_fly.py | 2 +- .../examples/path_integration/exploration.py | 6 ++-- .../vision/record_baseline_response.py | 2 +- flygym/simulation.py | 11 +++--- notebooks/advanced_vision.ipynb | 15 ++++---- notebooks/cpg_controller.ipynb | 20 ++++++++--- .../gym_basics_and_kinematic_replay.ipynb | 11 ++++-- notebooks/head_stabilization.ipynb | 12 ++++--- notebooks/hybrid_controller.ipynb | 34 +++++++++++-------- notebooks/rule_based_controller.ipynb | 14 ++++---- notebooks/turning.ipynb | 12 ++++--- 20 files changed, 103 insertions(+), 77 deletions(-) diff --git a/flygym/__init__.py b/flygym/__init__.py index a1ba8f6f..f6e3d40c 100644 --- a/flygym/__init__.py +++ b/flygym/__init__.py @@ -1,6 +1,6 @@ from .simulation import Simulation, SingleFlySimulation from .fly import Fly -from .camera import Camera, YawOnlyCamera, ZStabCamera, GravityAlignedCamera +from .camera import Camera, YawOnlyCamera, ZStabilizedCamera, GravityAlignedCamera from .util import get_data_path, load_config from dm_control.rl.control import PhysicsError diff --git a/flygym/camera.py b/flygym/camera.py index 0600b636..2135bc3a 100644 --- a/flygym/camera.py +++ b/flygym/camera.py @@ -7,7 +7,6 @@ import cv2 -import dm_control.mujoco import imageio import numpy as np from dm_control import mjcf @@ -17,8 +16,6 @@ from typing import Tuple, List, Dict, Any, Optional -from abc import ABC, abstractmethod - # Would like it to always draw gravity in the upper right corner # Check if contact need to be drawn (outside of the image) # New gravity camera @@ -33,7 +30,7 @@ def __init__( attachment_point: mjcf.element._AttachableElement, camera_name: str, attachment_name: str = None, - targeted_flies_id: int = [], + targeted_fly_names: List[str] = [], window_size: Tuple[int, int] = (640, 480), play_speed: float = 0.2, fps: int = 30, @@ -71,7 +68,7 @@ def __init__( Attachment point pf the camera attachment_name : str Name of the attachment point - targeted_flies_id: List(int) + targeted_fly_names: List[str] Index of the flies the camera is looking at. The first index is the focused fly that is tracked if using a complex camera. The rest of the indices are used to draw the contact forces. camera_name : str @@ -120,7 +117,7 @@ def __init__( video will not be saved. By default None. """ self.attachment_point = attachment_point - self.targeted_flies_id = targeted_flies_id + self.targeted_fly_names = targeted_fly_names config = util.load_config() @@ -160,7 +157,7 @@ def __init__( self.contact_threshold = contact_threshold self.perspective_arrow_length = perspective_arrow_length - if self.draw_contacts and len(self.targeted_flies_id) <= 0: + if self.draw_contacts and len(self.targeted_fly_names) <= 0: logging.warning( "Overriding `draw_contacts` to False because no flies are targeted." ) @@ -233,7 +230,7 @@ def render( img = physics.render(width=width, height=height, camera_id=self.camera_id) img = img.copy() if self.draw_contacts: - for i in range(len(self.targeted_flies_id)): + for i in range(len(self.targeted_fly_names)): img = self._draw_contacts(img, physics, last_obs[i]) if self.draw_gravity: img = self._draw_gravity(img, physics, last_obs[0]["pos"]) @@ -490,13 +487,13 @@ def clip(p_in, p_out, z_clip): return img -class ZStabCamera(Camera): +class ZStabilizedCamera(Camera): """Camera that stabilizes the z-axis of the camera to the floor height.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Raise error if targeted flies are empty - if len(self.targeted_flies_id) == 0: + if len(self.targeted_fly_names) == 0: raise ValueError( "No flies are targeted by the camera. " "Stabilized cameras require at least one fly to target." @@ -523,7 +520,7 @@ def update_camera(self, physics: mjcf.Physics, floor_height: float, obs: dict): return -class YawOnlyCamera(ZStabCamera): +class YawOnlyCamera(ZStabilizedCamera): """Camera that stabilizes the z-axis of the camera to the floor height and only changes the yaw of the camera to follow the fly hereby preventing unnecessary camera rotations. diff --git a/flygym/examples/head_stabilization/closed_loop_deployment.py b/flygym/examples/head_stabilization/closed_loop_deployment.py index c5973995..2df0c538 100644 --- a/flygym/examples/head_stabilization/closed_loop_deployment.py +++ b/flygym/examples/head_stabilization/closed_loop_deployment.py @@ -1,7 +1,7 @@ import numpy as np from pathlib import Path from tqdm import trange -from flygym import Camera, ZStabCamera, SingleFlySimulation +from flygym import Camera, ZStabilizedCamera, SingleFlySimulation from flygym.vision import Retina from flygym.arena import BaseArena, FlatTerrain, BlocksTerrain from typing import Optional @@ -56,11 +56,11 @@ def run_simulation( birdeye_cam_params = {"pos": (0, 0, 20), "euler": (0, 0, 0), "fovy": 45} - birdeye_camera = ZStabCamera( + birdeye_camera = ZStabilizedCamera( attachment_point=fly.model.worldbody, attachment_name=fly.name, camera_name="birdeye_cam", - targeted_flies_id=[0], + targeted_fly_names=[fly.name], camera_parameters=birdeye_cam_params, play_speed=0.2, window_size=(600, 600), @@ -72,7 +72,7 @@ def run_simulation( attachment_point=fly.model.worldbody, attachment_name=fly.name, camera_name="camera_neck_zoomin", - targeted_flies_id=[0], + targeted_fly_names=[fly.name], play_speed=0.2, fps=24, window_size=(600, 600), diff --git a/flygym/examples/head_stabilization/collect_training_data.py b/flygym/examples/head_stabilization/collect_training_data.py index ff1655aa..041762d7 100644 --- a/flygym/examples/head_stabilization/collect_training_data.py +++ b/flygym/examples/head_stabilization/collect_training_data.py @@ -86,7 +86,7 @@ def run_simulation( attachment_point=fly.model.worldbody, camera_name="camera_left", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) else: diff --git a/flygym/examples/locomotion/controller_comparison.py b/flygym/examples/locomotion/controller_comparison.py index 7970e878..fe79e93d 100644 --- a/flygym/examples/locomotion/controller_comparison.py +++ b/flygym/examples/locomotion/controller_comparison.py @@ -398,7 +398,7 @@ def run_all(arena: str, seed: int, pos: np.ndarray, verbose: bool = False): attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) sim = SingleFlySimulation( diff --git a/flygym/examples/locomotion/cpg_controller.py b/flygym/examples/locomotion/cpg_controller.py index 72e57be8..42820bfc 100644 --- a/flygym/examples/locomotion/cpg_controller.py +++ b/flygym/examples/locomotion/cpg_controller.py @@ -176,7 +176,7 @@ def run_cpg_simulation(nmf, cpg_network, preprogrammed_steps, run_time, pbar=Tru attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/examples/locomotion/hybrid_controller.py b/flygym/examples/locomotion/hybrid_controller.py index 7ecb83af..dae08eae 100644 --- a/flygym/examples/locomotion/hybrid_controller.py +++ b/flygym/examples/locomotion/hybrid_controller.py @@ -219,7 +219,7 @@ def run_hybrid_simulation(sim, cpg_network, preprogrammed_steps, run_time): attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/examples/locomotion/rule_based_controller.py b/flygym/examples/locomotion/rule_based_controller.py index 21adfabf..1be913c0 100644 --- a/flygym/examples/locomotion/rule_based_controller.py +++ b/flygym/examples/locomotion/rule_based_controller.py @@ -296,7 +296,7 @@ def run_rule_based_simulation(sim, controller, run_time, pbar=True): attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/examples/locomotion/turning_controller.py b/flygym/examples/locomotion/turning_controller.py index 80f87776..20b1cee7 100644 --- a/flygym/examples/locomotion/turning_controller.py +++ b/flygym/examples/locomotion/turning_controller.py @@ -494,7 +494,7 @@ def __init__(self, *args, **kwargs): attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/examples/locomotion/turning_fly.py b/flygym/examples/locomotion/turning_fly.py index 7327b78d..702cb4bd 100644 --- a/flygym/examples/locomotion/turning_fly.py +++ b/flygym/examples/locomotion/turning_fly.py @@ -498,7 +498,7 @@ def post_step(self, sim: "Simulation"): attachment_point=fly.model.worldbody, camera_name="camera_right", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/examples/path_integration/exploration.py b/flygym/examples/path_integration/exploration.py index a52cd09a..83eb2219 100644 --- a/flygym/examples/path_integration/exploration.py +++ b/flygym/examples/path_integration/exploration.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional -from flygym import Fly, ZStabCamera, is_rendering_skipped +from flygym import Fly, ZStabilizedCamera, is_rendering_skipped from flygym.util import get_data_path from flygym.preprogrammed import get_cpg_biases from flygym.examples.path_integration.arena import ( @@ -129,9 +129,9 @@ def run_simulation( cam_params = {"mode": "track", "pos": (0, 0, 150), "euler": (0, 0, 0), "fovy": 60} - cam = ZStabCamera( + cam = ZStabilizedCamera( attachment_point=fly.model.worldbody, - targeted_flies_id=[0], + targeted_fly_names=[fly.name], attachment_name=fly.name, camera_name="birdeye_cam", timestamp_text=False, diff --git a/flygym/examples/vision/record_baseline_response.py b/flygym/examples/vision/record_baseline_response.py index 468e7f56..5c4e52be 100644 --- a/flygym/examples/vision/record_baseline_response.py +++ b/flygym/examples/vision/record_baseline_response.py @@ -69,7 +69,7 @@ def run_simulation( attachment_point=fly.model.worldbody, camera_name="camera_top", attachment_name=fly.name, - targeted_flies_id=[int(fly.name)], + targeted_fly_names=[fly.name], play_speed=0.1, ) diff --git a/flygym/simulation.py b/flygym/simulation.py index 4909c3c3..b3a706f8 100644 --- a/flygym/simulation.py +++ b/flygym/simulation.py @@ -238,19 +238,18 @@ def step( ) def render(self): - all_flies_obs = [] - for i, fly in enumerate(self.flies): + all_flies_obs = {} + for fly in self.flies: fly.update_colors(self.physics) - all_flies_obs.append(fly.last_obs) - all_flies_obs = np.array(all_flies_obs) + all_flies_obs[fly.name] = fly.last_obs return [ camera.render( self.physics, self._floor_height, self.curr_time, - all_flies_obs[camera.targeted_flies_id] - if camera.targeted_flies_id + [all_flies_obs[name] for name in camera.targeted_fly_names] + if camera.targeted_fly_names else [{}], ) for camera in self.cameras diff --git a/notebooks/advanced_vision.ipynb b/notebooks/advanced_vision.ipynb index 65528a54..be9440a6 100644 --- a/notebooks/advanced_vision.ipynb +++ b/notebooks/advanced_vision.ipynb @@ -83,12 +83,13 @@ ")\n", "\n", "flies = [observer_fly, target_fly]\n", - "cam = Camera(attachment_point=observer_fly.model.worldbody,\n", - " camera_name=\"camera_top_zoomout\",\n", - " attachment_name=observer_fly.name,\n", - " targeted_flies_id=[0],\n", - " play_speed=0.1,\n", - " )\n", + "cam = Camera(\n", + " attachment_point=observer_fly.model.worldbody,\n", + " camera_name=\"camera_top_zoomout\",\n", + " attachment_name=observer_fly.name,\n", + " targeted_fly_names=[observer_fly.name],\n", + " play_speed=0.1,\n", + ")\n", "\n", "sim = Simulation(\n", " flies=flies,\n", @@ -178,7 +179,7 @@ "cam = Camera(attachment_point=observer_fly.model.worldbody,\n", " camera_name=\"camera_top_zoomout\",\n", " attachment_name=observer_fly.name,\n", - " targeted_flies_id=[0],\n", + " targeted_fly_names=[observer_fly.name],\n", " play_speed=0.1,\n", " )\n", "\n", diff --git a/notebooks/cpg_controller.ipynb b/notebooks/cpg_controller.ipynb index a06b1926..abe9bf64 100644 --- a/notebooks/cpg_controller.ipynb +++ b/notebooks/cpg_controller.ipynb @@ -583,13 +583,18 @@ "metadata": {}, "outputs": [], "source": [ - "from flygym import Fly, ZStabCamera, SingleFlySimulation\n", + "from flygym import Fly, ZStabilizedCamera, SingleFlySimulation\n", "from flygym.preprogrammed import all_leg_dofs\n", "\n", "run_time = 1\n", "fly = Fly(init_pose=\"stretch\", actuated_joints=all_leg_dofs, control=\"position\")\n", - "cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1)\n", + "cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "sim = SingleFlySimulation(fly=fly, cameras=[cam], timestep=1e-4)" ] }, @@ -819,8 +824,13 @@ " enable_adhesion=True,\n", " draw_adhesion=True,\n", ")\n", - "cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1)\n", + "cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "sim = SingleFlySimulation(fly=fly, cameras=[cam], timestep=1e-4)\n", "cpg_network.reset()\n", "\n", diff --git a/notebooks/gym_basics_and_kinematic_replay.ipynb b/notebooks/gym_basics_and_kinematic_replay.ipynb index 30cd9a0f..0461a802 100644 --- a/notebooks/gym_basics_and_kinematic_replay.ipynb +++ b/notebooks/gym_basics_and_kinematic_replay.ipynb @@ -92,7 +92,7 @@ "from pathlib import Path\n", "from tqdm import trange\n", "\n", - "from flygym import Fly, ZStabCamera, SingleFlySimulation, get_data_path\n", + "from flygym import Fly, ZStabilizedCamera, SingleFlySimulation, get_data_path\n", "from flygym.preprogrammed import all_leg_dofs" ] }, @@ -228,8 +228,13 @@ "source": [ "fly = Fly(init_pose=\"stretch\", actuated_joints=actuated_joints, control=\"position\")\n", "\n", - "cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1)\n", + "cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "sim = SingleFlySimulation(\n", " fly=fly,\n", " cameras=[cam],\n", diff --git a/notebooks/head_stabilization.ipynb b/notebooks/head_stabilization.ipynb index 825d418f..65d2a0f0 100644 --- a/notebooks/head_stabilization.ipynb +++ b/notebooks/head_stabilization.ipynb @@ -52,7 +52,7 @@ "from dm_control.utils import transformations\n", "from dm_control.rl.control import PhysicsError\n", "\n", - "from flygym import Fly, ZStabCamera\n", + "from flygym import Fly, ZStabilizedCamera\n", "from flygym.arena import FlatTerrain, BlocksTerrain\n", "from flygym.preprogrammed import get_cpg_biases\n", "from flygym.examples.locomotion import HybridTurningController\n", @@ -126,9 +126,13 @@ " contact_sensor_placements=contact_sensor_placements,\n", " spawn_pos=(*spawn_xy, 0.25),\n", " )\n", - " cam = ZStabCamera(attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1)\n", + " cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + " )\n", " sim = HybridTurningController(\n", " arena=arena,\n", " phase_biases=get_cpg_biases(gait),\n", diff --git a/notebooks/hybrid_controller.ipynb b/notebooks/hybrid_controller.ipynb index a0c35fed..3120cd94 100644 --- a/notebooks/hybrid_controller.ipynb +++ b/notebooks/hybrid_controller.ipynb @@ -82,7 +82,7 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm, trange\n", - "from flygym import Fly, ZStabCamera, SingleFlySimulation\n", + "from flygym import Fly, ZStabilizedCamera, SingleFlySimulation\n", "from flygym.examples.locomotion import PreprogrammedSteps\n", "from pathlib import Path\n", "\n", @@ -103,11 +103,13 @@ " init_pose=\"stretch\",\n", " control=\"position\",\n", " )\n", - " cam = ZStabCamera(\n", + " cam = ZStabilizedCamera(\n", " attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1\n", - " )\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + " )\n", " sim = SingleFlySimulation(\n", " fly=fly,\n", " cameras=[cam],\n", @@ -298,11 +300,13 @@ " init_pose=\"stretch\",\n", " control=\"position\",\n", " )\n", - " cam = ZStabCamera(\n", + " cam = ZStabilizedCamera(\n", " attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1\n", - " )\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + " )\n", " sim = SingleFlySimulation(\n", " fly=fly,\n", " cameras=[cam],\n", @@ -648,11 +652,13 @@ " control=\"position\",\n", " contact_sensor_placements=contact_sensor_placements,\n", ")\n", - "cam = ZStabCamera(\n", - " attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1\n", - " )\n", + "cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "arena = MixedTerrain()\n", "sim = SingleFlySimulation(\n", " fly=fly,\n", diff --git a/notebooks/rule_based_controller.ipynb b/notebooks/rule_based_controller.ipynb index ef36842d..1d967ad9 100644 --- a/notebooks/rule_based_controller.ipynb +++ b/notebooks/rule_based_controller.ipynb @@ -590,7 +590,7 @@ } ], "source": [ - "from flygym import Fly, ZStabCamera, SingleFlySimulation\n", + "from flygym import Fly, ZStabilizedCamera, SingleFlySimulation\n", "from flygym.preprogrammed import all_leg_dofs\n", "from tqdm import trange\n", "\n", @@ -610,11 +610,13 @@ " draw_adhesion=True,\n", ")\n", "\n", - "cam = ZStabCamera(\n", - " attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_left\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1\n", - " )\n", + "cam = ZStabilizedCamera(\n", + " attachment_point=fly.model.worldbody,\n", + " camera_name=\"camera_left\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "\n", "sim = SingleFlySimulation(\n", " fly=fly,\n", diff --git a/notebooks/turning.ipynb b/notebooks/turning.ipynb index d5d30356..a52f5629 100644 --- a/notebooks/turning.ipynb +++ b/notebooks/turning.ipynb @@ -89,7 +89,7 @@ "from gymnasium import spaces\n", "from gymnasium.utils.env_checker import check_env\n", "\n", - "from flygym import Fly, ZStabCamera, SingleFlySimulation\n", + "from flygym import Fly, ZStabilizedCamera, SingleFlySimulation\n", "from flygym.examples.locomotion import PreprogrammedSteps\n", "from flygym.examples.locomotion.cpg_controller import CPGNetwork\n", "\n", @@ -583,11 +583,13 @@ " spawn_pos=(0, 0, 0.2),\n", ")\n", "\n", - "cam = ZStabCamera(\n", + "cam = ZStabilizedCamera(\n", " attachment_point=fly.model.worldbody,\n", - " camera_name=\"camera_top\", attachment_name=fly.name,\n", - " targeted_flies_id=[0], play_speed=0.1\n", - " )\n", + " camera_name=\"camera_top\",\n", + " attachment_name=fly.name,\n", + " targeted_fly_names=[fly.name],\n", + " play_speed=0.1,\n", + ")\n", "\n", "nmf = HybridTurningController(\n", " fly=fly,\n",