Skip to content

Commit

Permalink
rename ZStabCamera to ZStabilizedCamera and rename targeted_flies_id …
Browse files Browse the repository at this point in the history
…to targeted_fly_names and make it a list of str
  • Loading branch information
tkclam committed Jan 16, 2025
1 parent 10b9340 commit 67cbec9
Show file tree
Hide file tree
Showing 20 changed files with 103 additions and 77 deletions.
2 changes: 1 addition & 1 deletion flygym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .simulation import Simulation, SingleFlySimulation
from .fly import Fly
from .camera import Camera, YawOnlyCamera, ZStabCamera, GravityAlignedCamera
from .camera import Camera, YawOnlyCamera, ZStabilizedCamera, GravityAlignedCamera
from .util import get_data_path, load_config
from dm_control.rl.control import PhysicsError

Expand Down
19 changes: 8 additions & 11 deletions flygym/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


import cv2
import dm_control.mujoco
import imageio
import numpy as np
from dm_control import mjcf
Expand All @@ -17,8 +16,6 @@

from typing import Tuple, List, Dict, Any, Optional

from abc import ABC, abstractmethod

# Would like it to always draw gravity in the upper right corner
# Check if contact need to be drawn (outside of the image)
# New gravity camera
Expand All @@ -33,7 +30,7 @@ def __init__(
attachment_point: mjcf.element._AttachableElement,
camera_name: str,
attachment_name: str = None,
targeted_flies_id: int = [],
targeted_fly_names: List[str] = [],
window_size: Tuple[int, int] = (640, 480),
play_speed: float = 0.2,
fps: int = 30,
Expand Down Expand Up @@ -71,7 +68,7 @@ def __init__(
Attachment point pf the camera
attachment_name : str
Name of the attachment point
targeted_flies_id: List(int)
targeted_fly_names: List[str]
Index of the flies the camera is looking at. The first index is the focused fly that is tracked if using a
complex camera. The rest of the indices are used to draw the contact forces.
camera_name : str
Expand Down Expand Up @@ -120,7 +117,7 @@ def __init__(
video will not be saved. By default None.
"""
self.attachment_point = attachment_point
self.targeted_flies_id = targeted_flies_id
self.targeted_fly_names = targeted_fly_names

config = util.load_config()

Expand Down Expand Up @@ -160,7 +157,7 @@ def __init__(
self.contact_threshold = contact_threshold
self.perspective_arrow_length = perspective_arrow_length

if self.draw_contacts and len(self.targeted_flies_id) <= 0:
if self.draw_contacts and len(self.targeted_fly_names) <= 0:
logging.warning(
"Overriding `draw_contacts` to False because no flies are targeted."
)
Expand Down Expand Up @@ -233,7 +230,7 @@ def render(
img = physics.render(width=width, height=height, camera_id=self.camera_id)
img = img.copy()
if self.draw_contacts:
for i in range(len(self.targeted_flies_id)):
for i in range(len(self.targeted_fly_names)):
img = self._draw_contacts(img, physics, last_obs[i])
if self.draw_gravity:
img = self._draw_gravity(img, physics, last_obs[0]["pos"])
Expand Down Expand Up @@ -490,13 +487,13 @@ def clip(p_in, p_out, z_clip):
return img


class ZStabCamera(Camera):
class ZStabilizedCamera(Camera):
"""Camera that stabilizes the z-axis of the camera to the floor height."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Raise error if targeted flies are empty
if len(self.targeted_flies_id) == 0:
if len(self.targeted_fly_names) == 0:
raise ValueError(
"No flies are targeted by the camera. "
"Stabilized cameras require at least one fly to target."
Expand All @@ -523,7 +520,7 @@ def update_camera(self, physics: mjcf.Physics, floor_height: float, obs: dict):
return


class YawOnlyCamera(ZStabCamera):
class YawOnlyCamera(ZStabilizedCamera):
"""Camera that stabilizes the z-axis of the camera to the floor height and
only changes the yaw of the camera to follow the fly hereby preventing unnecessary
camera rotations.
Expand Down
8 changes: 4 additions & 4 deletions flygym/examples/head_stabilization/closed_loop_deployment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pathlib import Path
from tqdm import trange
from flygym import Camera, ZStabCamera, SingleFlySimulation
from flygym import Camera, ZStabilizedCamera, SingleFlySimulation
from flygym.vision import Retina
from flygym.arena import BaseArena, FlatTerrain, BlocksTerrain
from typing import Optional
Expand Down Expand Up @@ -56,11 +56,11 @@ def run_simulation(

birdeye_cam_params = {"pos": (0, 0, 20), "euler": (0, 0, 0), "fovy": 45}

birdeye_camera = ZStabCamera(
birdeye_camera = ZStabilizedCamera(
attachment_point=fly.model.worldbody,
attachment_name=fly.name,
camera_name="birdeye_cam",
targeted_flies_id=[0],
targeted_fly_names=[fly.name],
camera_parameters=birdeye_cam_params,
play_speed=0.2,
window_size=(600, 600),
Expand All @@ -72,7 +72,7 @@ def run_simulation(
attachment_point=fly.model.worldbody,
attachment_name=fly.name,
camera_name="camera_neck_zoomin",
targeted_flies_id=[0],
targeted_fly_names=[fly.name],
play_speed=0.2,
fps=24,
window_size=(600, 600),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def run_simulation(
attachment_point=fly.model.worldbody,
camera_name="camera_left",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/controller_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def run_all(arena: str, seed: int, pos: np.ndarray, verbose: bool = False):
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)
sim = SingleFlySimulation(
Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/cpg_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def run_cpg_simulation(nmf, cpg_network, preprogrammed_steps, run_time, pbar=Tru
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/hybrid_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def run_hybrid_simulation(sim, cpg_network, preprogrammed_steps, run_time):
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/rule_based_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def run_rule_based_simulation(sim, controller, run_time, pbar=True):
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/turning_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def __init__(self, *args, **kwargs):
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/locomotion/turning_fly.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def post_step(self, sim: "Simulation"):
attachment_point=fly.model.worldbody,
camera_name="camera_right",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
6 changes: 3 additions & 3 deletions flygym/examples/path_integration/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional

from flygym import Fly, ZStabCamera, is_rendering_skipped
from flygym import Fly, ZStabilizedCamera, is_rendering_skipped
from flygym.util import get_data_path
from flygym.preprogrammed import get_cpg_biases
from flygym.examples.path_integration.arena import (
Expand Down Expand Up @@ -129,9 +129,9 @@ def run_simulation(

cam_params = {"mode": "track", "pos": (0, 0, 150), "euler": (0, 0, 0), "fovy": 60}

cam = ZStabCamera(
cam = ZStabilizedCamera(
attachment_point=fly.model.worldbody,
targeted_flies_id=[0],
targeted_fly_names=[fly.name],
attachment_name=fly.name,
camera_name="birdeye_cam",
timestamp_text=False,
Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/vision/record_baseline_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run_simulation(
attachment_point=fly.model.worldbody,
camera_name="camera_top",
attachment_name=fly.name,
targeted_flies_id=[int(fly.name)],
targeted_fly_names=[fly.name],
play_speed=0.1,
)

Expand Down
11 changes: 5 additions & 6 deletions flygym/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,18 @@ def step(
)

def render(self):
all_flies_obs = []
for i, fly in enumerate(self.flies):
all_flies_obs = {}
for fly in self.flies:
fly.update_colors(self.physics)
all_flies_obs.append(fly.last_obs)
all_flies_obs = np.array(all_flies_obs)
all_flies_obs[fly.name] = fly.last_obs

return [
camera.render(
self.physics,
self._floor_height,
self.curr_time,
all_flies_obs[camera.targeted_flies_id]
if camera.targeted_flies_id
[all_flies_obs[name] for name in camera.targeted_fly_names]
if camera.targeted_fly_names
else [{}],
)
for camera in self.cameras
Expand Down
15 changes: 8 additions & 7 deletions notebooks/advanced_vision.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@
")\n",
"\n",
"flies = [observer_fly, target_fly]\n",
"cam = Camera(attachment_point=observer_fly.model.worldbody,\n",
" camera_name=\"camera_top_zoomout\",\n",
" attachment_name=observer_fly.name,\n",
" targeted_flies_id=[0],\n",
" play_speed=0.1,\n",
" )\n",
"cam = Camera(\n",
" attachment_point=observer_fly.model.worldbody,\n",
" camera_name=\"camera_top_zoomout\",\n",
" attachment_name=observer_fly.name,\n",
" targeted_fly_names=[observer_fly.name],\n",
" play_speed=0.1,\n",
")\n",
"\n",
"sim = Simulation(\n",
" flies=flies,\n",
Expand Down Expand Up @@ -178,7 +179,7 @@
"cam = Camera(attachment_point=observer_fly.model.worldbody,\n",
" camera_name=\"camera_top_zoomout\",\n",
" attachment_name=observer_fly.name,\n",
" targeted_flies_id=[0],\n",
" targeted_fly_names=[observer_fly.name],\n",
" play_speed=0.1,\n",
" )\n",
"\n",
Expand Down
20 changes: 15 additions & 5 deletions notebooks/cpg_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -583,13 +583,18 @@
"metadata": {},
"outputs": [],
"source": [
"from flygym import Fly, ZStabCamera, SingleFlySimulation\n",
"from flygym import Fly, ZStabilizedCamera, SingleFlySimulation\n",
"from flygym.preprogrammed import all_leg_dofs\n",
"\n",
"run_time = 1\n",
"fly = Fly(init_pose=\"stretch\", actuated_joints=all_leg_dofs, control=\"position\")\n",
"cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n",
" targeted_flies_id=[0], play_speed=0.1)\n",
"cam = ZStabilizedCamera(\n",
" attachment_point=fly.model.worldbody,\n",
" camera_name=\"camera_left\",\n",
" attachment_name=fly.name,\n",
" targeted_fly_names=[fly.name],\n",
" play_speed=0.1,\n",
")\n",
"sim = SingleFlySimulation(fly=fly, cameras=[cam], timestep=1e-4)"
]
},
Expand Down Expand Up @@ -819,8 +824,13 @@
" enable_adhesion=True,\n",
" draw_adhesion=True,\n",
")\n",
"cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n",
" targeted_flies_id=[0], play_speed=0.1)\n",
"cam = ZStabilizedCamera(\n",
" attachment_point=fly.model.worldbody,\n",
" camera_name=\"camera_left\",\n",
" attachment_name=fly.name,\n",
" targeted_fly_names=[fly.name],\n",
" play_speed=0.1,\n",
")\n",
"sim = SingleFlySimulation(fly=fly, cameras=[cam], timestep=1e-4)\n",
"cpg_network.reset()\n",
"\n",
Expand Down
11 changes: 8 additions & 3 deletions notebooks/gym_basics_and_kinematic_replay.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"from pathlib import Path\n",
"from tqdm import trange\n",
"\n",
"from flygym import Fly, ZStabCamera, SingleFlySimulation, get_data_path\n",
"from flygym import Fly, ZStabilizedCamera, SingleFlySimulation, get_data_path\n",
"from flygym.preprogrammed import all_leg_dofs"
]
},
Expand Down Expand Up @@ -228,8 +228,13 @@
"source": [
"fly = Fly(init_pose=\"stretch\", actuated_joints=actuated_joints, control=\"position\")\n",
"\n",
"cam = ZStabCamera(attachment_point=fly.model.worldbody, camera_name=\"camera_left\", attachment_name=fly.name,\n",
" targeted_flies_id=[0], play_speed=0.1)\n",
"cam = ZStabilizedCamera(\n",
" attachment_point=fly.model.worldbody,\n",
" camera_name=\"camera_left\",\n",
" attachment_name=fly.name,\n",
" targeted_fly_names=[fly.name],\n",
" play_speed=0.1,\n",
")\n",
"sim = SingleFlySimulation(\n",
" fly=fly,\n",
" cameras=[cam],\n",
Expand Down
12 changes: 8 additions & 4 deletions notebooks/head_stabilization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"from dm_control.utils import transformations\n",
"from dm_control.rl.control import PhysicsError\n",
"\n",
"from flygym import Fly, ZStabCamera\n",
"from flygym import Fly, ZStabilizedCamera\n",
"from flygym.arena import FlatTerrain, BlocksTerrain\n",
"from flygym.preprogrammed import get_cpg_biases\n",
"from flygym.examples.locomotion import HybridTurningController\n",
Expand Down Expand Up @@ -126,9 +126,13 @@
" contact_sensor_placements=contact_sensor_placements,\n",
" spawn_pos=(*spawn_xy, 0.25),\n",
" )\n",
" cam = ZStabCamera(attachment_point=fly.model.worldbody,\n",
" camera_name=\"camera_left\", attachment_name=fly.name,\n",
" targeted_flies_id=[0], play_speed=0.1)\n",
" cam = ZStabilizedCamera(\n",
" attachment_point=fly.model.worldbody,\n",
" camera_name=\"camera_left\",\n",
" attachment_name=fly.name,\n",
" targeted_fly_names=[fly.name],\n",
" play_speed=0.1,\n",
" )\n",
" sim = HybridTurningController(\n",
" arena=arena,\n",
" phase_biases=get_cpg_biases(gait),\n",
Expand Down
Loading

0 comments on commit 67cbec9

Please sign in to comment.