From ce2f0d541904c08846c53c76134b2162bada3746 Mon Sep 17 00:00:00 2001 From: Wang Boyu Date: Thu, 31 Oct 2024 13:59:32 -0400 Subject: [PATCH] implement VideoViz to record model runs in a video --- .gitignore | 3 + mesa/examples/basic/schelling/video.py | 40 ++++++ mesa/visualization/video_viz.py | 184 +++++++++++++++++++++++++ 3 files changed, 227 insertions(+) create mode 100644 mesa/examples/basic/schelling/video.py create mode 100644 mesa/visualization/video_viz.py diff --git a/.gitignore b/.gitignore index a33dd9b7b41..3aa5fd30c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,6 @@ dmypy.json # JS dependencies mesa/visualization/templates/external/ mesa/visualization/templates/js/external/ + +# Video +**/*.mp4 diff --git a/mesa/examples/basic/schelling/video.py b/mesa/examples/basic/schelling/video.py new file mode 100644 index 00000000000..6bfb3dfd212 --- /dev/null +++ b/mesa/examples/basic/schelling/video.py @@ -0,0 +1,40 @@ +"""Example of using VideoViz with the Schelling model.""" + +from mesa.examples.basic.schelling.model import Schelling +from mesa.visualization.video_viz import ( + VideoViz, + make_measure_component, + make_space_component, +) + +# Create model +model = Schelling(10, 10) + + +def agent_portrayal(agent): + """Portray agents based on their type.""" + if agent is None: + return {} + + portrayal = { + "color": "red" if agent.type == 0 else "blue", + "size": 25, + "marker": "s", # square marker + } + return portrayal + + +# Create visualization with space and some metrics +viz = VideoViz( + model, + [ + make_space_component(agent_portrayal=agent_portrayal, save_format="svg"), + make_measure_component("happy", save_format="svg"), + ], + title="Schelling's Segregation Model", +) + +# Record simulation +if __name__ == "__main__": + video_path = viz.record(steps=50, filepath="schelling.mp4") + print(f"Video saved to: {video_path}") diff --git a/mesa/visualization/video_viz.py b/mesa/visualization/video_viz.py new file mode 100644 index 00000000000..a492725e2ca --- /dev/null +++ b/mesa/visualization/video_viz.py @@ -0,0 +1,184 @@ +"""Video recording components for Mesa model visualization.""" + +import shutil +from collections.abc import Callable, Sequence +from pathlib import Path + +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np + +import mesa +from mesa.visualization.matplotlib_renderer import ( + MatplotlibRenderer, + MeasureRendererMatplotlib, + SpaceRenderMatplotlib, +) + + +def make_space_component( + agent_portrayal: Callable | None = None, + propertylayer_portrayal: dict | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, +): + """Create a Matplotlib-based space visualization component. + + Args: + agent_portrayal: Function to portray agents. + propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications + post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) + backend: The backend to use for rendering the space. Can be "matplotlib" or "altair". + space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See + the functions for drawing the various spaces for further details. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + + Returns: + SpaceRenderMatplotlib: A component for rendering the space. + """ + if agent_portrayal is None: + + def agent_portrayal(a): + return {} + + return SpaceRenderMatplotlib( + agent_portrayal, + propertylayer_portrayal, + post_process=post_process, + **space_drawing_kwargs, + ) + + +def make_measure_component( + measure: Callable, + **kwargs, +): + """Create a plotting function for a specified measure. + + Args: + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor. + + Returns: + MeasureRendererMatplotlib: A component for rendering the measure. + """ + return MeasureRendererMatplotlib( + measure, + **kwargs, + ) + + +class VideoViz: + """Create high-quality video recordings of model simulations.""" + + def __init__( + self, + model: mesa.Model, + components: Sequence[MatplotlibRenderer], + *, + title: str | None = None, + figsize: tuple[float, float] | None = None, + grid: tuple[int, int] | None = None, + ): + """Initialize video visualization configuration. + + Args: + model: The model to simulate and record + components: Sequence of component objects defining what to visualize + title: Optional title for the video + figsize: Optional figure size in inches (width, height) + grid: Optional (rows, cols) for custom layout. Auto-calculated if None. + """ + # Check if FFmpeg is available + if not shutil.which("ffmpeg"): + raise RuntimeError( + "FFmpeg not found. Please install FFmpeg to save animations:\n" + " - macOS: brew install ffmpeg\n" + " - Linux: sudo apt-get install ffmpeg\n" + " - Windows: download from https://ffmpeg.org/download.html" + ) + self.model = model + self.components = components + self.title = title + self.figsize = figsize + self.grid = grid or self._calculate_grid(len(components)) + + # Setup figure and axes + self.fig, self.axes = self._setup_figure() + + def record( + self, + *, + steps: int, + filepath: str | Path, + dpi: int = 100, + fps: int = 10, + codec: str = "h264", + bitrate: int = 2000, + ) -> Path: + """Record model simulation to video file. + + Args: + steps: Number of simulation steps to record + filepath: Where to save the video file + dpi: Resolution of the output video + fps: Frames per second in the output video + codec: Video codec to use + bitrate: Video bitrate in kbps (default: 2000) + + Returns: + Path to the saved video file + + Raises: + RuntimeError: If FFmpeg is not installed + """ + filepath = Path(filepath) + + def update(frame_num): + # Update model state + self.model.step() + + # Render all visualization frames + for component, ax in zip(self.components, self.axes): + ax.clear() + component.draw(self.model, ax) + return self.axes + + # Create and save animation + anim = animation.FuncAnimation( + self.fig, update, frames=steps, interval=1000 / fps, blit=False + ) + + writer = animation.FFMpegWriter( + fps=fps, + codec=codec, + bitrate=bitrate, # Now passing as integer + ) + + anim.save(filepath, writer=writer, dpi=dpi) + return filepath + + def _calculate_grid(self, n_frames: int) -> tuple[int, int]: + """Calculate optimal grid layout for given number of frames.""" + cols = min(3, n_frames) # Max 3 columns + rows = int(np.ceil(n_frames / cols)) + return (rows, cols) + + def _setup_figure(self): + """Setup matplotlib figure and axes.""" + if not self.figsize: + self.figsize = (5 * self.grid[1], 5 * self.grid[0]) + fig = plt.figure(figsize=self.figsize) + axes = [] + + for i in range(len(self.components)): + ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1) + axes.append(ax) + + if self.title: + fig.suptitle(self.title, fontsize=16) + fig.tight_layout() + return fig, axes