From a9f0bc289e4d61561efa52395729b2b753359f5b Mon Sep 17 00:00:00 2001 From: Marlon Steiner Date: Mon, 16 Sep 2024 14:01:18 +0200 Subject: [PATCH] plot either all, only the top or a defined mode(s) --- src/data/plot_3d.py | 127 ++++++++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 33 deletions(-) diff --git a/src/data/plot_3d.py b/src/data/plot_3d.py index 5dfceda..8778584 100644 --- a/src/data/plot_3d.py +++ b/src/data/plot_3d.py @@ -6,40 +6,100 @@ from mpl_toolkits.mplot3d.art3d import Poly3DCollection -def plot_motion_forecasts(batch, pred_dict, n_step_future=60, idx_t_now=50, idx_batch=1, ax_dist=5, idx_focal=None, save_path=''): +def plot_motion_forecasts( + batch, + pred_dict, + n_step_future=60, + idx_t_now=50, + idx_batch=1, + ax_dist=5, + idx_focal=None, + mode_setting="top", # "top" or "all" or "custom" + mode_idx=None, + save_path="", +): fig = plt.figure(figsize=(15, 15), dpi=80) ax = fig.add_subplot(projection="3d", computed_zorder=False) - ax.view_init(elev=50., azim=-75) + ax.view_init(elev=50.0, azim=-75) ax.dist = ax_dist trajs = pred_dict["waymo_trajs"].movedim(1, -2) # Plot all map polylines: - for map_polyline, map_valid, map_type in zip(batch['map/pos'][idx_batch], batch['map/valid'][idx_batch], batch["map/type"][idx_batch]): + for map_polyline, map_valid, map_type in zip( + batch["map/pos"][idx_batch], + batch["map/valid"][idx_batch], + batch["map/type"][idx_batch], + ): map_polyline = map_polyline[map_valid] # lanes black, else white - if map_type[4] or map_type[5] or map_type[6] or map_type[7] or map_type[8] or map_type[9] or map_type[10]: + if ( + map_type[4] + or map_type[5] + or map_type[6] + or map_type[7] + or map_type[8] + or map_type[9] + or map_type[10] + ): plt.plot(map_polyline[:, 0], map_polyline[:, 1], "-", c="white", zorder=-10) else: plt.plot(map_polyline[:, 0], map_polyline[:, 1], "-", c="black", zorder=-10) - # idx_mode = 4 - # TODO: plot all modes - idx_top_mode = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True) - - if idx_focal is not None: - agent = trajs[idx_batch, idx_focal] - idx_mode = int(idx_top_mode[idx_batch, idx_focal]) - plt.scatter(agent[idx_mode, :, 0], agent[idx_mode, :, 1], marker=".", s=100, c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), lw=10, zorder=1) - else: - for idx_agent in range(trajs.shape[1]): - agent = trajs[idx_batch, idx_agent] - idx_mode = int(idx_top_mode[idx_batch, idx_agent]) - plt.scatter(agent[idx_mode, :, 0], agent[idx_mode, :, 1], marker=".", s=100, c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), lw=10, zorder=1) # 80 for waymo + n_modes = pred_dict["waymo_scores"].shape[-1] + idx_mode_plot = None + if mode_setting == "top": + idx_mode_plot = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True) + elif mode_setting == "custom": + idx_mode_plot = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True) + idx_mode_plot = torch.full_like(idx_mode_plot, mode_idx) + + for idx in range(n_modes): + if mode_setting in ["top", "custom"] and idx > 0: + continue + if idx_focal is not None: + agent = trajs[idx_batch, idx_focal] + idx_mode = ( + idx + if mode_setting == "all" + else int(idx_mode_plot[idx_batch, idx_focal]) + ) + plt.scatter( + agent[idx_mode, :, 0], + agent[idx_mode, :, 1], + marker=".", + s=100, + c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), + lw=10, + zorder=1, + ) + else: + for idx_agent in range(trajs.shape[1]): + agent = trajs[idx_batch, idx_agent] + idx_mode = ( + idx + if mode_setting == "all" + else int(idx_mode_plot[idx_batch, idx_agent]) + ) + plt.scatter( + agent[idx_mode, :, 0], + agent[idx_mode, :, 1], + marker=".", + s=100, + c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), + lw=10, + zorder=1, + ) # 80 for waymo # Plot agents: - for idx, (agent_pos, agent_type, agent_yaw, agent_role, agent_spd) in enumerate(zip( - batch["agent/pos"][idx_batch, idx_t_now], batch["agent/type"][idx_batch], batch["agent/yaw_bbox"][idx_batch, idx_t_now], batch["agent/role"][idx_batch], batch["agent/spd"][idx_batch, idx_t_now] - )): + for idx, (agent_pos, agent_type, agent_yaw, agent_role, agent_spd) in enumerate( + zip( + batch["agent/pos"][idx_batch, idx_t_now], + batch["agent/type"][idx_batch], + batch["agent/yaw_bbox"][idx_batch, idx_t_now], + batch["agent/role"][idx_batch], + batch["agent/spd"][idx_batch, idx_t_now], + ) + ): if agent_type[0]: bbox = rotate_bbox_zaxis(car, float(agent_yaw)) bbox = shift_cuboid(float(agent_pos[0]), float(agent_pos[1]), bbox) @@ -78,7 +138,7 @@ def plot_motion_forecasts(batch, pred_dict, n_step_future=60, idx_t_now=50, idx_ if save_path: plt.savefig(save_path, dpi=150, pad_inches=0, bbox_inches="tight") - + return fig @@ -86,7 +146,7 @@ def tensor_dict_to_cpu(obj): buffer = io.BytesIO() torch.save(obj, buffer) buffer.seek(0) - obj_cpu = torch.load(buffer, map_location='cpu') + obj_cpu = torch.load(buffer, map_location="cpu") return obj_cpu @@ -98,18 +158,19 @@ def mplfig_to_npimage(fig): """ # only the Agg backend now supports the tostring_rgb function from matplotlib.backends.backend_agg import FigureCanvasAgg + canvas = FigureCanvasAgg(fig) - canvas.draw() # update/draw the elements + canvas.draw() # update/draw the elements # get the width and the height to resize the matrix - l,b,w,h = canvas.figure.bbox.bounds + l, b, w, h = canvas.figure.bbox.bounds w, h = int(w), int(h) # exports the canvas to a string buffer and then to a numpy nd.array buf = canvas.tostring_rgb() image = np.frombuffer(buf, dtype=np.uint8) - return image.reshape(h,w,3) + return image.reshape(h, w, 3) def shift_cuboid(x_shift, y_shift, cuboid): @@ -187,13 +248,13 @@ def add_cube(cube_definition, ax, color="b", edgecolor="k", alpha=0.2): ) pedestrian = np.array( - [ - (-0.3, -0.3, 0), # left bottom front - (-0.3, 0.3, 0), # left bottom back - (0.3, -0.3, 0), # right bottom front - (-0.3, -0.3, 2), # left top front -> height - ] - ) + [ + (-0.3, -0.3, 0), # left bottom front + (-0.3, 0.3, 0), # left bottom back + (0.3, -0.3, 0), # right bottom front + (-0.3, -0.3, 2), # left top front -> height + ] +) cyclist = np.array( [ @@ -202,4 +263,4 @@ def add_cube(cube_definition, ax, color="b", edgecolor="k", alpha=0.2): (1, -0.3, 0), # right bottom front (-1, -0.3, 2), # left top front -> height ] -) \ No newline at end of file +)