diff --git a/flygym/__init__.py b/flygym/__init__.py index 3ffc4f2f..6c658be7 100644 --- a/flygym/__init__.py +++ b/flygym/__init__.py @@ -1,6 +1,6 @@ from .simulation import Simulation, SingleFlySimulation from .fly import Fly -from .camera import Camera, NeckCamera +from .camera import Camera, YawOnlyCamera, ZStabilizedCamera, GravityAlignedCamera from .util import get_data_path, load_config from dm_control.rl.control import PhysicsError @@ -9,3 +9,17 @@ is_rendering_skipped = ( "SKIP_RENDERING" in environ and environ["SKIP_RENDERING"] == "true" ) + +__all__ = [ + "Simulation", + "SingleFlySimulation", + "Fly", + "Camera", + "YawOnlyCamera", + "ZStabilizedCamera", + "GravityAlignedCamera", + "PhysicsError", + "get_data_path", + "load_config", + "is_rendering_skipped", +] diff --git a/flygym/arena/base.py b/flygym/arena/base.py index 590aaa72..1c8f822d 100644 --- a/flygym/arena/base.py +++ b/flygym/arena/base.py @@ -188,6 +188,19 @@ def step(self, dt: float, physics: mjcf.Physics, *args, **kwargs) -> None: """ return + @abstractmethod + def _get_max_floor_height(self) -> float: + """Get the height of the floor of the arena. This is useful for + camera rendering. The camera should be placed at a height that is + slightly above the floor to avoid rendering artifacts. + + Returns + ------- + float + The height of the floor of the arena in mm. + """ + pass + class FlatTerrain(BaseArena): """Flat terrain with no obstacles. 
def _get_max_floor_height(self) -> float:
    """Return the z coordinate of the flat ground plane.

    The geom named ``"ground"`` is looked up under ``self.root_element``
    and the third component of its ``pos`` attribute is returned.

    Returns
    -------
    float
        Height of the ground plane, or 0.0 when the geom has no ``pos``
        value (``pos`` is None).
    """
    ground = self.root_element.find("geom", "ground")
    ground_pos = ground.pos
    # ``pos`` is presumably None when it was never set on the geom. Test
    # for that case explicitly instead of catching TypeError, which would
    # also silence unrelated errors such as a malformed ``pos`` value.
    if ground_pos is None:
        return 0.0
    return float(ground_pos[2])
def _get_max_floor_height(self):
    """Return the highest floor level of the mixed terrain.

    The flat floor and the tops of the gap sections sit at z=0, so the
    result is the tallest recorded block top when it lies above zero and
    0 otherwise.
    """
    tallest_block = self._max_block_height
    if tallest_block > 0:
        return tallest_block
    # No block rises above the floor and gap tops at z=0.
    return 0
def _get_max_floor_height(self) -> float:
    """Return the z coordinate of the arena's ground plane.

    The geom named ``"ground"`` is looked up under ``self.root_element``
    and the third component of its ``pos`` attribute is returned.

    Returns
    -------
    float
        Height of the ground plane, or 0.0 when the geom has no ``pos``
        value (``pos`` is None).
    """
    ground_geom = self.root_element.find("geom", "ground")
    position = ground_geom.pos
    # ``pos`` is presumably None when it was never set on the geom. An
    # explicit check replaces the former ``except TypeError``, which could
    # also mask unrelated type errors.
    if position is None:
        return 0.0
    return float(position[2])
@@ -126,3 +129,6 @@ def __init__( "joint", name="treadmill_joint", type="ball", limited="false" ) treadmill_body.add("inertial", pos=[0, 0, 0], mass=mass) + + def _get_max_floor_height(self) -> float: + raise NotImplementedError diff --git a/flygym/camera.py b/flygym/camera.py index ade19c1f..adbe3d63 100644 --- a/flygym/camera.py +++ b/flygym/camera.py @@ -1,192 +1,163 @@ import logging import sys from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union + +import flygym.util as util + import cv2 -import dm_control.mujoco import imageio import numpy as np from dm_control import mjcf -from dm_control.utils import transformations -from flygym.fly import Fly from scipy.spatial.transform import Rotation as R +# Would like it to always draw gravity in the upper right corner +# Check if contact need to be drawn (outside of the image) +# New gravity camera + _roll_eye = np.roll(np.eye(4, 3), -1) class Camera: - """Camera associated with a fly. - - Attributes - ---------- - fly : Fly - The fly to which the camera is associated. - window_size : tuple[int, int] - Size of the rendered images in pixels. - play_speed : float - Play speed of the rendered video. - fps: int - FPS of the rendered video when played at ``play_speed``. - timestamp_text : bool - If True, text indicating the current simulation time will be added - to the rendered video. - play_speed_text : bool - If True, text indicating the play speed will be added to the - rendered video. - dm_camera : dm_control.mujoco.Camera - The ``dm_control`` camera instance associated with the camera. - Only available after calling ``initialize_dm_camera(physics)``. - Useful for mapping the rendered location to the physical location - in the simulation. - draw_contacts : bool - If True, arrows will be drawn to indicate contact forces between - the legs and the ground. 
- decompose_contacts : bool - If True, the arrows visualizing contact forces will be decomposed - into x-y-z components. - force_arrow_scaling : float - Scaling factor determining the length of arrows visualizing contact - forces. - tip_length : float - Size of the arrows indicating the contact forces in pixels. - contact_threshold : float - The threshold for contact detection in mN (forces below this - magnitude will be ignored). - draw_gravity : bool - If True, an arrow will be drawn indicating the direction of - gravity. This is useful during climbing simulations. - gravity_arrow_scaling : float - Scaling factor determining the size of the arrow indicating - gravity. - align_camera_with_gravity : bool - If True, the camera will be rotated such that gravity points down. - This is - useful during climbing simulations. - camera_follows_fly_orientation : bool - If True, the camera will be rotated so that it aligns with the - fly's orientation. - decompose_colors - Colors for the x, y, and z components of the contact force arrows. - output_path : Optional[Union[str, Path]] - Path to which the rendered video should be saved. If None, the - video will not be saved. - - Parameters - ---------- - fly : Fly - The fly to which the camera is associated. - camera_id : str - The camera that will be used for rendering, by default - "Animat/camera_left". - window_size : tuple[int, int] - Size of the rendered images in pixels, by default (640, 480). - play_speed : float - Play speed of the rendered video, by default 0.2. - fps: int - FPS of the rendered video when played at ``play_speed``, by - default 30. - timestamp_text : bool - If True, text indicating the current simulation time will be - added to the rendered video. - play_speed_text : bool - If True, text indicating the play speed will be added to the - rendered video. - draw_contacts : bool - If True, arrows will be drawn to indicate contact forces - between the legs and the ground. By default False. 
- decompose_contacts : bool - If True, the arrows visualizing contact forces will be - decomposed into x-y-z components. By default True. - force_arrow_scaling : float, optional - Scaling factor determining the length of arrows visualizing - contact forces. By default 1.0 if perspective_arrow_length - is True and 10.0 otherwise. - tip_length : float - Size of the arrows indicating the contact forces in pixels. By - default 10. - contact_threshold : float - The threshold for contact detection in mN (forces below this - magnitude will be ignored). By default 0.1. - draw_gravity : bool - If True, an arrow will be drawn indicating the direction of - gravity. This is useful during climbing simulations. By default - False. - gravity_arrow_scaling : float - Scaling factor determining the size of the arrow indicating - gravity. By default 0.0001. - align_camera_with_gravity : bool - If True, the camera will be rotated such that gravity points - down. This is useful during climbing simulations. By default - False. - camera_follows_fly_orientation : bool - If True, the camera will be rotated so that it aligns with the - fly's orientation. By default False. - decompose_colors - Colors for the x, y, and z components of the contact force - arrows. By default ((255, 0, 0), (0, 255, 0), (0, 0, 255)). - output_path : str or Path, optional - Path to which the rendered video should be saved. If None, the - video will not be saved. By default None. - perspective_arrow_length : bool - If true, the length of the arrows indicating the contact forces - will be determined by the perspective. 
- """ - - dm_camera: dm_control.mujoco.Camera - def __init__( self, - fly: Fly, - camera_id: str = "Animat/camera_left", + attachment_point: mjcf.element._AttachableElement, + camera_name: str, + attachment_name: str = None, + targeted_fly_names: list[str] = [], window_size: tuple[int, int] = (640, 480), play_speed: float = 0.2, fps: int = 30, timestamp_text: bool = False, play_speed_text: bool = True, + camera_parameters: Optional[dict[str, Any]] = None, draw_contacts: bool = False, decompose_contacts: bool = True, + decompose_colors: tuple[ + tuple[int, int, int], tuple[int, int, int], tuple[int, int, int] + ] = ((0, 0, 255), (0, 255, 0), (255, 0, 0)), force_arrow_scaling: float = float("nan"), tip_length: float = 10.0, # number of pixels contact_threshold: float = 0.1, + perspective_arrow_length: bool = False, draw_gravity: bool = False, gravity_arrow_scaling: float = 1e-4, - align_camera_with_gravity: bool = False, - camera_follows_fly_orientation: bool = False, - decompose_colors: tuple[ - tuple[int, int, int], tuple[int, int, int], tuple[int, int, int] - ] = ((255, 0, 0), (0, 255, 0), (0, 0, 255)), output_path: Optional[Union[str, Path]] = None, - perspective_arrow_length=False, ): - self.fly = fly + """Initialize a Camera that can be attached to any attachable element and take any mujoco inbuilt parameters. 
+ A set of preset configurations is available in the config file: + - Simple cameras like: "camera_top", "camera_right", "camera_left", + "camera_front", "camera_back", "camera_bottom" + - Compound rotated cameras with different zoom levels: "camera_top_right", "camera_top_zoomout", + "camera_right_front", "camera_left_top_zoomout", "camera_neck_zoomin", + "camera_head_zoomin", "camera_front_zoomin", "camera_LFTarsus1_zoomin" + - "camera_LFTarsus1_zoomin": Camera looking at the left tarsus of the first leg + - "camera_back_track": 3rd person camera following the fly + + This camera can also be set with custom parameters by providing a dictionary of parameters. + + Parameters + ---------- + attachment_point: dm_control.mjcf.element._AttachableElement + Attachment point of the camera + attachment_name : str + Name of the attachment point + targeted_fly_names: list[str] + Names of the flies the camera is looking at. The first name is the focused fly that is tracked if using a + complex camera. The rest of the names are used to draw the contact forces. + camera_name : str + Name of the camera. If it matches a preset in the config file and + ``camera_parameters`` is None, the preset parameters are used. + window_size : tuple[int, int] + Size of the rendered images in pixels, by default (640, 480). + play_speed : float + Play speed of the rendered video, by default 0.2. + fps: int + FPS of the rendered video when played at ``play_speed``, by + default 30. + timestamp_text : bool + If True, text indicating the current simulation time will be + added to the rendered video. + play_speed_text : bool + If True, text indicating the play speed will be added to the + rendered video. + camera_parameters : Optional[dict[str, Any]] + Parameters of the camera to be added to the model. If None, the + preset parameters registered under ``camera_name`` in the config + file are used. + draw_contacts : bool + If True, arrows will be drawn to indicate contact forces between + the legs and the ground. + decompose_contacts : bool + If True, the arrows visualizing contact forces will be decomposed + into x-y-z components. 
+ decompose_colors + Colors for the x, y, and z components of the contact force arrows. + force_arrow_scaling : float + Scaling factor determining the length of arrows visualizing contact + forces. + tip_length : float + Size of the arrows indicating the contact forces in pixels. + contact_threshold : float + The threshold for contact detection in mN (forces below this + magnitude will be ignored). + perspective_arrow_length : bool + If true, the length of the arrows indicating the contact forces + will be determined by the perspective. + draw_gravity : bool + If True, an arrow will be drawn indicating the direction of + gravity. This is useful during climbing simulations. + gravity_arrow_scaling : float + Scaling factor determining the size of the arrow indicating + gravity. + output_path : str or Path, optional + Path to which the rendered video should be saved. If None, the + video will not be saved. By default None. + """ + self.attachment_point = attachment_point + self.targeted_fly_names = targeted_fly_names + + config = util.load_config() + + self.is_custom = True + if camera_parameters is None and camera_name in config["cameras"]: + camera_parameters = config["cameras"][camera_name] + self.is_custom = False + + assert camera_parameters is not None, ( + "Camera parameters must be provided " + "if the camera name is not a predefined camera" + ) + + camera_parameters["name"] = camera_name + + # get a first value before spawning: useful for the zstab cam + self.camera_base_offset = np.array(camera_parameters.get("pos", np.zeros(3))) + + self._cam, self.camera_id = self._add_camera( + attachment_point, camera_parameters, attachment_name + ) + self.window_size = window_size self.play_speed = play_speed self.fps = fps self.timestamp_text = timestamp_text self.play_speed_text = play_speed_text + self.draw_contacts = draw_contacts self.decompose_contacts = decompose_contacts - self.tip_length = tip_length - self.contact_threshold = contact_threshold - 
self.draw_gravity = draw_gravity - self.gravity_arrow_scaling = gravity_arrow_scaling - self.align_camera_with_gravity = align_camera_with_gravity - self.camera_follows_fly_orientation = camera_follows_fly_orientation self.decompose_colors = decompose_colors - self.camera_id = camera_id.replace("Animat", fly.name) - self.perspective_arrow_length = perspective_arrow_length - if not np.isfinite(force_arrow_scaling): self.force_arrow_scaling = 1.0 if perspective_arrow_length else 10.0 else: self.force_arrow_scaling = force_arrow_scaling + self.tip_length = tip_length + self.contact_threshold = contact_threshold + self.perspective_arrow_length = perspective_arrow_length - if output_path is not None: - self.output_path = Path(output_path) - else: - self.output_path = None + if self.draw_contacts and len(self.targeted_fly_names) <= 0: + logging.warning( + "Overriding `draw_contacts` to False because no flies are targeted." + ) + self.draw_contacts = False if self.draw_contacts and "cv2" not in sys.modules: logging.warning( @@ -195,220 +166,47 @@ def __init__( ) self.draw_contacts = False + self.draw_gravity = draw_gravity + self.gravity_arrow_scaling = gravity_arrow_scaling + if self.draw_gravity: - fly._last_fly_pos = self.fly.spawn_pos self._gravity_rgba = [1 - 213 / 255, 1 - 90 / 255, 1 - 255 / 255, 1.0] - self._arrow_offset = np.zeros(3) - if "bottom" in camera_id or "top" in camera_id: - self._arrow_offset[0] = -3 - self._arrow_offset[1] = 2 - elif "left" in camera_id or "right" in camera_id: - self._arrow_offset[2] = 2 - self._arrow_offset[0] = -3 - elif "front" in camera_id or "back" in camera_id: - self._arrow_offset[2] = 2 - self._arrow_offset[1] = 3 - - if self.align_camera_with_gravity: - self._camera_rot = np.eye(3) - - self._cam = self.fly.model.find("camera", camera_id.split("/")[-1]) - self._initialize_custom_camera_handling(camera_id) + self._grav_arrow_start = (self.window_size[0] - 100, 100) + + if output_path is not None: + self.output_path = 
Path(output_path) + else: + self.output_path = None + self._eff_render_interval = self.play_speed / self.fps self._frames: list[np.ndarray] = [] self._timestamp_per_frame: list[float] = [] - def _initialize_custom_camera_handling(self, camera_name: str): - """ - This function is called when the camera is initialized. It can be - used to customize the camera behavior. I case update_camera_pos is - True and the camera is within the animat and not a head camera, the - z position will be fixed to avoid oscillations. If - self.camera_follows_fly_orientation is True, the camera - will be rotated to follow the fly orientation (i.e. the front - camera will always be in front of the fly). - """ - - is_animat = camera_name.startswith("Animat") or camera_name.startswith( - self.fly.name - ) - is_visualization_camera = ( - "head" in camera_name - or "Tarsus" in camera_name - or "camera_front_zoomin" in camera_name - ) - - canonical_cameras = [ - "camera_front", - "camera_back", - "camera_top", - "camera_bottom", - "camera_left", - "camera_right", - "camera_neck_zoomin", - ] - if "/" not in camera_name: - is_canonical_camera = False - else: - is_canonical_camera = camera_name.split("/")[-1] in canonical_cameras - - # always add pos update if it is a head camera - if is_animat and not is_visualization_camera: - self.update_camera_pos = True - self.cam_offset = self._cam.pos - if (not is_canonical_camera) and self.camera_follows_fly_orientation: - self.camera_follows_fly_orientation = False - logging.warning( - "Overriding `camera_follows_fly_orientation` to False because" - "the rendering camera is not a simple camera from a canonical " - "angle (front, back, top, bottom, left, right, neck_zoomin)." - ) - elif self.camera_follows_fly_orientation: - # Why would that be xyz and not XYZ ? 
def _add_camera(self, attachment, camera_parameters, attachment_name):
    """Create a camera element on ``attachment`` and derive its id.

    Parameters
    ----------
    attachment
        Object whose ``add("camera", **camera_parameters)`` creates the
        camera element.
    camera_parameters : dict
        Keyword arguments forwarded to ``attachment.add``.
    attachment_name : str or None
        Prefix for the camera id. When None, the id is the camera name
        alone.

    Returns
    -------
    tuple
        ``(camera, camera_id)`` where ``camera_id`` is either
        ``camera.name`` or ``attachment_name + "/" + camera.name``.
    """
    cam_element = attachment.add("camera", **camera_parameters)
    if attachment_name is not None:
        # Namespaced id: "<attachment_name>/<camera name>".
        return cam_element, attachment_name + "/" + cam_element.name
    return cam_element, cam_element.name
""" - self.dm_camera = dm_control.mujoco.Camera( - physics, - camera_id=self.camera_id, - width=self.window_size[0], - height=self.window_size[1], - ) - - def set_gravity(self, gravity: np.ndarray, rot_mat: np.ndarray = None) -> None: - """Set the gravity of the environment. Changing the gravity vector - might be useful during climbing simulations. The change in the - camera point of view has been extensively tested for the simple - cameras (left right top bottom front back) but not for the composed - ones. - - Parameters - ---------- - gravity : np.ndarray - The gravity vector. - rot_mat : np.ndarray, optional - The rotation matrix to align the camera with the gravity vector - by default None. - """ - # Only change the angle of the camera if the new gravity vector and the camera - # angle are compatible - camera_is_compatible = False - if "left" in self.camera_id or "right" in self.camera_id: - if not gravity[1] > 0: - camera_is_compatible = True - # elif "top" in self.camera_name or "bottom" in - # self.camera_name: - elif "front" in self.camera_id or "back" in self.camera_id: - if not gravity[1] > 0: - camera_is_compatible = True - - if rot_mat is not None and self.align_camera_with_gravity: - self._camera_rot = rot_mat - elif camera_is_compatible: - normalised_gravity = (np.array(gravity) / np.linalg.norm(gravity)).reshape( - (1, 3) - ) - downward_ref = np.array([0.0, 0.0, -1.0]).reshape((1, 3)) - - if ( - not np.all(normalised_gravity == downward_ref) - and self.align_camera_with_gravity - ): - # Generate a bunch of vectors to help the optimisation algorithm - - random_vectors = np.tile(np.random.rand(10_000), (3, 1)).T - downward_refs = random_vectors + downward_ref - gravity_vectors = random_vectors + normalised_gravity - downward_refs = downward_refs - gravity_vectors = gravity_vectors - rot_mult = R.align_vectors(downward_refs, gravity_vectors)[0] - - rot_simple = R.align_vectors( - np.reshape(normalised_gravity, (1, 3)), - downward_ref.reshape((1, 3)), - 
)[0] - - diff_mult = np.linalg.norm( - np.dot(rot_mult.as_matrix(), normalised_gravity.T) - downward_ref.T - ) - diff_simple = np.linalg.norm( - np.dot(rot_simple.as_matrix(), normalised_gravity.T) - - downward_ref.T - ) - if diff_mult < diff_simple: - rot = rot_mult - else: - rot = rot_simple - - logging.info( - f"{normalised_gravity}, " - f"{rot.as_euler('xyz')}, " - f"{np.dot(rot.as_matrix(), normalised_gravity.T).T}, ", - f"{downward_ref}", - ) - - # check if rotation has effect if not remove it - euler_rot = rot.as_euler("xyz") - new_euler_rot = np.zeros(3) - last_rotated_vector = normalised_gravity - for i in range(0, 3): - new_euler_rot[: i + 1] = euler_rot[: i + 1].copy() - - rotated_vector = ( - R.from_euler("xyz", new_euler_rot).as_matrix() - @ normalised_gravity.T - ).T - logging.info( - f"{euler_rot}, " - f"{new_euler_rot}, " - f"{rotated_vector}, " - f"{last_rotated_vector}" - ) - if np.linalg.norm(rotated_vector - last_rotated_vector) < 1e-2: - logging.info("Removing component {i}") - euler_rot[i] = 0 - last_rotated_vector = rotated_vector - - logging.info(str(euler_rot)) - rot = R.from_euler("xyz", euler_rot) - rot_mat = rot.as_matrix() - - self._camera_rot = rot_mat.T + bound_cam = physics.bind(self._cam) + self.camera_base_offset = bound_cam.xpos.copy() + self.camera_base_rot = R.from_matrix(bound_cam.xmat.reshape(3, 3)) def render( - self, physics: mjcf.Physics, floor_height: float, curr_time: float + self, + physics: mjcf.Physics, + floor_height: float, + curr_time: float, + last_obs: list[dict], ) -> Union[np.ndarray, None]: """Call the ``render`` method to update the renderer. 
def reset(self):
    """Discard every frame and timestamp recorded so far."""
    # The frame list is emptied in place, whereas the timestamp list is
    # replaced by a brand-new empty list.
    del self._frames[:]
    self._timestamp_per_frame = list()
def save_video(self, path: Union[str, Path], stabilization_time=0.02):
    """Save the frames rendered since the start or the last ``reset()``.

    Parameters
    ----------
    path : str or Path
        Path to which the video should be saved. Missing parent
        directories are created.
    stabilization_time : float, optional
        Frames whose timestamp is below this value are left out of the
        video. This might be wanted because it takes a few frames for the
        position controller to move the joints to the specified angles
        from the default, all-stretched position. By default 0.02.

    Notes
    -----
    When no frame has been rendered, a warning is logged and the method
    returns without creating any directory or file, as the warning says.
    """
    if not self._frames:
        logging.warning(
            "No frames have been rendered yet; no video will be saved despite "
            "`save_video()` call. Be sure to call `.render()` in your simulation "
            "loop."
        )
        # Nothing to write: do not create directories or an empty file.
        return

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"Saving video to {path}")
    with imageio.get_writer(path, fps=self.fps) as writer:
        for frame, timestamp in zip(self._frames, self._timestamp_per_frame):
            # Skip frames recorded before the stabilization period ended.
            if timestamp >= stabilization_time:
                writer.append_data(frame)
def _compute_camera_matrices(self, physics: mjcf.Physics):
    """Build the matrices that project world coordinates into pixel space.

    The camera's field of view, orientation and position are read from
    ``physics.bind(self._cam)``, so no dm_control camera object is needed.

    Returns
    -------
    tuple of np.ndarray
        ``(image, focal, rotation, translation)``: a 3x3 image matrix, a
        3x4 focal matrix, and 4x4 rotation and translation matrices.
    """
    bound_camera = physics.bind(self._cam)
    img_width, img_height = self.window_size

    # Translation (4x4): shift the world by the negated camera position.
    translation = np.eye(4)
    translation[:3, 3] = -bound_camera.xpos

    # Rotation (4x4): transpose of the camera orientation matrix.
    rotation = np.eye(4)
    rotation[:3, :3] = bound_camera.xmat.reshape(3, 3).T

    # Focal matrix (3x4), scaled from the vertical field of view in degrees.
    half_fov = np.deg2rad(bound_camera.fovy) / 2
    focal_scaling = (1.0 / np.tan(half_fov)) * img_height / 2.0
    focal = np.diag([-focal_scaling, focal_scaling, 1.0, 0])[:3]

    # Image matrix (3x3): offset to the centre of the frame.
    image = np.eye(3)
    image[0, 2] = (img_width - 1) / 2.0
    image[1, 2] = (img_height - 1) / 2.0

    return image, focal, rotation, translation
""" - camera_matrix = self.dm_camera.matrix - last_fly_pos = self.fly.last_obs["pos"] - - if self.align_camera_with_gravity: - arrow_start = last_fly_pos + self._camera_rot @ self._arrow_offset - else: - arrow_start = last_fly_pos + self._arrow_offset - - arrow_end = arrow_start + physics.model.opt.gravity * self.gravity_arrow_scaling - xyz_global = np.array([arrow_start, arrow_end]).T + image, focal, rotation, translation = self._compute_camera_matrices(physics) + camera_matrix = image @ focal @ rotation @ translation # Camera matrices multiply homogenous [x, y, z, 1] vectors. - corners_homogeneous = np.ones((4, xyz_global.shape[1]), dtype=float) - corners_homogeneous[:3, :] = xyz_global + grav_homogeneous = np.ones((4, 2), dtype=float) + grav_homogeneous[:3, :] = np.hstack( + [ + np.expand_dims(fly_pos, -1), + np.expand_dims( + fly_pos + physics.model.opt.gravity * self.gravity_arrow_scaling, -1 + ), + ] + ) # Project world coordinates into pixel space. See: # https://en.wikipedia.org/wiki/3D_projection#Mathematical_formula - xs, ys, s = camera_matrix @ corners_homogeneous - + xs, ys, s = camera_matrix @ grav_homogeneous # x and y are in the pixel coordinate system. - x = np.rint(xs / s).astype(int) - y = np.rint(ys / s).astype(int) - - img = img.astype(np.uint8) - img = cv2.arrowedLine(img, (x[0], y[0]), (x[1], y[1]), self._gravity_rgba, 10) + x = xs / s + y = ys / s + grav_vector = np.array( + [ + x[1] - x[0], + y[1] - y[0], + ] + ).astype(int) + + # Draw the vector on the image + cv2.arrowedLine( + img, + self._grav_arrow_start, + self._grav_arrow_start + grav_vector, + self._gravity_rgba, + thickness, + cv2.LINE_AA, + ) return img def _draw_contacts( - self, img: np.ndarray, physics: mjcf.Physics, thickness=2 + self, img: np.ndarray, physics: mjcf.Physics, last_obs: dict, thickness=2 ) -> np.ndarray: """Draw contacts as arrow which length is proportional to the force magnitude. The arrow is drawn at the center of the body. 
It uses @@ -554,8 +387,8 @@ def clip(p_in, p_out, z_clip): t = (z_clip - p_out[-1]) / (p_in[-1] - p_out[-1]) return t * p_in + (1 - t) * p_out - forces = self.fly.last_obs["contact_forces"] - pos = self.fly.last_obs["contact_pos"] + forces = last_obs["contact_forces"] + pos = last_obs["contact_pos"] magnitudes = np.linalg.norm(forces, axis=1) contact_indices = np.nonzero(magnitudes > self.contact_threshold)[0] @@ -577,10 +410,10 @@ def clip(p_in, p_out, z_clip): # Convert to homogeneous coordinates Xw = np.concatenate((Xw, np.ones((1, *Xw.shape[1:])))) - mat = self.dm_camera.matrices() + im_mat, foc_mat, rot_mat, trans_mat = self._compute_camera_matrices(physics) # Project to camera space - Xc = np.tensordot(mat.rotation @ mat.translation, Xw, 1) + Xc = np.tensordot(rot_mat @ trans_mat, Xw, 1) Xc = Xc[:3, :] / Xc[-1, :] z_near = -physics.model.vis.map.znear * physics.model.stat.extent @@ -600,18 +433,26 @@ def clip(p_in, p_out, z_clip): lines[:, 0, is_in] = clip(lines[:, 1, is_in], lines[:, 0, is_in], z_near) # Project to pixel space - lines = np.tensordot((mat.image @ mat.focal)[:, :3], lines, axes=1) + lines = np.tensordot((im_mat @ foc_mat)[:, :3], lines, axes=1) lines2d = lines[:2] / lines[-1] lines2d = lines2d.T if not self.perspective_arrow_length: unit_vectors = lines2d[:, :, 1] - lines2d[:, :, 0] length = np.linalg.norm(unit_vectors, axis=-1, keepdims=True) - length[length == 0] = 1 + # avoid division by small number + length = np.clip(length, 1e-8, 1e8) unit_vectors /= length - lines2d[:, :, 1] = ( - lines2d[:, :, 0] + np.abs(contact_forces[:, :, None]) * unit_vectors - ) + if self.decompose_contacts: + lines2d[:, :, 1] = ( + lines2d[:, :, 0] + np.abs(contact_forces[:, :, None]) * unit_vectors + ) + else: + lines2d[:, :, 1] = ( + lines2d[:, :, 0] + + np.linalg.norm(contact_forces, axis=1)[:, None, None] + * unit_vectors + ) lines2d = np.rint(lines2d.reshape((-1, 2, 2))).astype(int) @@ -625,7 +466,7 @@ def clip(p_in, p_out, z_clip): continue color = 
class ZStabilizedCamera(Camera):
    """Camera whose height is held at a fixed offset above the floor.

    At every update the z coordinate of the camera is overwritten with
    ``floor_height + camera_base_offset[2]``; x and y are left untouched.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Camera`` and require a targeted fly.

        Raises
        ------
        ValueError
            If ``self.targeted_fly_names`` is empty after base initialisation.
        """
        super().__init__(*args, **kwargs)
        n_targets = len(self.targeted_fly_names)
        if n_targets == 0:
            # A stabilized camera needs at least one fly to follow.
            raise ValueError(
                "No flies are targeted by the camera. "
                "Stabilized cameras require at least one fly to target."
            )

    def init_camera_orientation(self, physics: mjcf.Physics):
        """Record the camera's base x/y offset and base rotation.

        These values are what the per-frame update methods use when the
        camera must be moved beyond its default behavior.

        Parameters
        ----------
        physics : mjcf.Physics
            Physics instance used to bind ``self._cam``.
        """
        bound_cam = physics.bind(self._cam)
        base_xy = bound_cam.xpos[:2].copy()
        base_orientation = bound_cam.xmat.reshape(3, 3)
        # Only x and y are refreshed; z is already set to the floor height.
        self.camera_base_offset[:2] = base_xy
        self.camera_base_rot = R.from_matrix(base_orientation)

    def _update_cam_pos(self, physics: mjcf.Physics, floor_height: float):
        """Pin the camera's z coordinate to the floor height plus base offset."""
        bound_cam = physics.bind(self._cam)
        stabilized_pos = bound_cam.xpos.copy()
        stabilized_pos[2] = floor_height + self.camera_base_offset[2]
        bound_cam.xpos = stabilized_pos

    def _update_camera(self, physics: mjcf.Physics, floor_height: float, obs: dict):
        """Per-frame update: only the camera height is adjusted (``obs`` unused)."""
        self._update_cam_pos(physics, floor_height)
class YawOnlyCamera(ZStabilizedCamera):
    """Floor-stabilized camera that follows only the yaw of the fly.

    The camera keeps its height relative to the floor and is rotated about
    the world z axis by the change in the fly's yaw since the first update,
    which avoids unnecessary camera rotations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.smoothing = 0.99995  # tested empirically
        # Both are filled in lazily on the first call to ``_smooth_yaw``.
        self.prev_yaw = None
        self.init_yaw = None

    def _update_camera(self, physics: mjcf.Physics, floor_height: float, obs: dict):
        """Move and rotate the camera according to the fly's current yaw.

        Parameters
        ----------
        physics : mjcf.Physics
            Physics instance used to bind ``self._cam``.
        floor_height : float
            Height of the floor, used as the z reference of the camera.
        obs : dict
            Observation; ``obs["rot"][0]`` is used as the yaw and
            ``obs["pos"]`` as the fly position.
        """
        yaw = obs["rot"][0]
        smoothed_yaw = self._smooth_yaw(yaw)
        # Rotation about z by the yaw accumulated since the first update.
        yaw_delta = smoothed_yaw - self.init_yaw
        correction = R.from_euler("xyz", [0, 0, yaw_delta])
        self._update_cam_pos(physics, floor_height, correction, obs["pos"])
        self._update_cam_rot(physics, correction)

        # NOTE(review): the raw yaw (not the smoothed one) is stored, so each
        # output blends only the previous and current raw samples instead of
        # accumulating a running average -- confirm this is intended.
        self.prev_yaw = yaw

    def _smooth_yaw(self, yaw: float):
        """Blend ``yaw`` with the previous yaw on the unit circle.

        The sines and cosines of the two angles are mixed with weights
        ``smoothing`` (previous) and ``1 - smoothing`` (current) and turned
        back into an angle with ``arctan2``. On the first call the previous
        and initial yaw are both set to ``yaw``.
        """
        if self.prev_yaw is None:
            self.prev_yaw = yaw
            self.init_yaw = yaw
        blended_sin = (
            self.smoothing * np.sin(self.prev_yaw) + (1 - self.smoothing) * np.sin(yaw)
        )
        blended_cos = (
            self.smoothing * np.cos(self.prev_yaw) + (1 - self.smoothing) * np.cos(yaw)
        )
        return np.arctan2(blended_sin, blended_cos)

    def _update_cam_rot(self, physics: mjcf.Physics, yaw_correction: R):
        """Set the camera orientation to the base rotation turned by the yaw."""
        corrected_rot = yaw_correction * self.camera_base_rot
        physics.bind(self._cam).xmat = corrected_rot.as_matrix().flatten()

    def _update_cam_pos(
        self,
        physics: mjcf.Physics,
        floor_height: float,
        yaw_correction: R,
        fly_pos: list[float],
    ):
        """Place the camera relative to the fly using the yaw-rotated offset.

        Note that this override takes two more arguments than
        ``ZStabilizedCamera._update_cam_pos``.
        """
        # Anchor at the fly's x/y and the floor height; the z part of the base
        # offset is added through the rotated offset below.
        anchor = np.hstack([fly_pos[:2], floor_height])
        rotated_offset = (yaw_correction.as_matrix() @ self.camera_base_offset).flatten()
        physics.bind(self._cam).xpos = anchor + rotated_offset
class GravityAlignedCamera(Camera):
    """Camera that follows the fly while staying aligned with gravity.

    Whenever the gravity vector of the simulation changes, a rotation that
    maps the default down direction ``[0, 0, -1]`` onto the new gravity
    direction is computed and applied to both the base camera rotation and
    the base camera offset.

    NOTE(review): ``camera_base_rot`` and ``camera_base_offset`` are read
    here but not set in this class; presumably the base class provides
    them -- confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Gravity vector for which ``grav_rot`` was last computed.
        self.gravity = None
        self.grav_rot = None
        # NOTE(review): initialised as a 3x3 array, but ``_update_cam_rot``
        # later stores a scipy ``Rotation`` object under this name.
        self.cam_matrix = np.zeros((3, 3))

    def _update_gravrot(self, gravity):
        """Recompute the rotation aligning the default down axis with gravity.

        Parameters
        ----------
        gravity : np.ndarray
            Current gravity vector; it is stored in ``self.gravity``.
        """
        self.gravity = gravity
        magnitude = np.linalg.norm(gravity)
        if magnitude == 0:
            # Without gravity there is no direction to align with, and
            # normalising would produce NaNs; keep the base orientation.
            self.grav_rot = R.identity()
            return
        gravity_direction = gravity / magnitude
        self.grav_rot = R.align_vectors(gravity_direction, [0, 0, -1])[0]

    def _update_camera(self, physics: mjcf.Physics, floor_height: float, obs: dict):
        """Per-frame update of camera orientation and position.

        Parameters
        ----------
        physics : mjcf.Physics
            Physics instance providing ``model.opt.gravity`` and the camera.
        floor_height : float
            Unused by this camera.
        obs : dict
            Observation; ``obs["pos"]`` is used as the fly position and is
            left unmodified.
        """
        # Local import keeps this block self-contained.
        import logging

        current_gravity = physics.model.opt.gravity
        if self.gravity is None or np.any(current_gravity != self.gravity):
            # Logged instead of printed so library users control verbosity.
            logging.getLogger(__name__).debug("updating gravity")
            self._update_gravrot(current_gravity.copy())
        self._update_cam_rot(physics)
        self._update_cam_pos(physics, obs["pos"])

    def _update_cam_rot(self, physics: mjcf.Physics):
        """Apply the gravity rotation on top of the base camera rotation."""
        self.cam_matrix = self.grav_rot * self.camera_base_rot
        physics.bind(self._cam).xmat = self.cam_matrix.as_matrix().flatten()

    def _update_cam_pos(self, physics: mjcf.Physics, fly_pos: list[float]):
        """Place the camera at the gravity-rotated base offset from the fly.

        The z component of the fly position is ignored (treated as 0).
        """
        # Work on a copy: ``fly_pos`` is the caller's observation array and
        # must not be overwritten when its z component is zeroed.
        anchor = np.array(fly_pos, dtype=float)
        anchor[2] = 0
        rotated_offset = (self.grav_rot.as_matrix() @ self.camera_base_offset).flatten()
        self.cam_pos = anchor + rotated_offset
        physics.bind(self._cam).xpos = self.cam_pos
mode: track + ipd: 0.068 + pos: [-12, 0, 6] + euler: [1.2, 0.0, -1.57] + camera_front: + class: nmf + mode: fixed + ipd: 0.068 + pos: [8, 0, 1.0] + euler: [1.57, 0.0, 1.57] + camera_left: + class: nmf + mode: track + ipd: 0.068 + pos: [0, 8, 1.0] + euler: [-1.57, 3.14, 0] + camera_right: + class: nmf + mode: track + ipd: 0.068 + pos: [0, -8, 1.0] + euler: [1.57, 0, 0] + camera_top_right: + class: nmf + mode: fixed + ipd: 0.068 + pos: [0, -8, 5] + euler: [1.1, 0, 0] + camera_top: + class: nmf + mode: track + ipd: 0.068 + pos: [0, 0, 8] + euler: [0, 0, 0] + camera_top_zoomout: + class: nmf + mode: track + ipd: 0.068 + pos: [0, 0, 40] + euler: [0, 0, 0] + camera_bottom: + class: nmf + mode: track + ipd: 0.068 + pos: [0, 0, -8] + euler: [0, 3.14, 0] + camera_right_front: + class: nmf + mode: track + ipd: 0.068 + pos: [4, -6, 1.0] + euler: [1.57, 0, 0.588] + camera_left_top_zoomout: + class: nmf + mode: track + ipd: 0.068 + pos: [0, 8, 5] + euler: [-2.129, 3.14, 0] + camera_right_top_zoomout: + class: nmf + mode: track + ipd: 0.068 + pos: [0, -8, 5] + euler: [1.011, 0, 0] + camera_head_zoomin: + class: nmf + mode: track + ipd: 0.068 + pos: [3, -3, 1] + euler: [1.57, 0, 0.72] + fovy: 30 + camera_front_zoomin: + class: nmf + mode: track + ipd: 0.068 + pos: [8, 0, 1] + euler: [1.57, 0, 1.57] + fovy: 15 + camera_neck_zoomin: + class: nmf + mode: fixed + ipd: 0.068 + pos: [0.5, 2, 1.2] + euler: [-1.57, 3.14, 0] diff --git a/flygym/data/mjcf/neuromechfly_deepfly3d_kinorder_ryp.xml b/flygym/data/mjcf/neuromechfly_deepfly3d_kinorder_ryp.xml index 470aa323..c84274f4 100644 --- a/flygym/data/mjcf/neuromechfly_deepfly3d_kinorder_ryp.xml +++ b/flygym/data/mjcf/neuromechfly_deepfly3d_kinorder_ryp.xml @@ -1,6 +1,6 @@ - +