Skip to content

Commit

Permalink
plot either all, only the top or a defined mode(s)
Browse files Browse the repository at this point in the history
  • Loading branch information
marlon31415 committed Sep 16, 2024
1 parent efd8219 commit a9f0bc2
Showing 1 changed file with 94 additions and 33 deletions.
127 changes: 94 additions & 33 deletions src/data/plot_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,100 @@
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def plot_motion_forecasts(batch, pred_dict, n_step_future=60, idx_t_now=50, idx_batch=1, ax_dist=5, idx_focal=None, save_path=''):
def plot_motion_forecasts(
batch,
pred_dict,
n_step_future=60,
idx_t_now=50,
idx_batch=1,
ax_dist=5,
idx_focal=None,
mode_setting="top", # "top" or "all" or "custom"
mode_idx=None,
save_path="",
):
fig = plt.figure(figsize=(15, 15), dpi=80)
ax = fig.add_subplot(projection="3d", computed_zorder=False)
ax.view_init(elev=50., azim=-75)
ax.view_init(elev=50.0, azim=-75)
ax.dist = ax_dist
trajs = pred_dict["waymo_trajs"].movedim(1, -2)

# Plot all map polylines:
for map_polyline, map_valid, map_type in zip(batch['map/pos'][idx_batch], batch['map/valid'][idx_batch], batch["map/type"][idx_batch]):
for map_polyline, map_valid, map_type in zip(
batch["map/pos"][idx_batch],
batch["map/valid"][idx_batch],
batch["map/type"][idx_batch],
):
map_polyline = map_polyline[map_valid]
# lanes black, else white
if map_type[4] or map_type[5] or map_type[6] or map_type[7] or map_type[8] or map_type[9] or map_type[10]:
if (
map_type[4]
or map_type[5]
or map_type[6]
or map_type[7]
or map_type[8]
or map_type[9]
or map_type[10]
):
plt.plot(map_polyline[:, 0], map_polyline[:, 1], "-", c="white", zorder=-10)
else:
plt.plot(map_polyline[:, 0], map_polyline[:, 1], "-", c="black", zorder=-10)

# idx_mode = 4
# TODO: plot all modes
idx_top_mode = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True)

if idx_focal is not None:
agent = trajs[idx_batch, idx_focal]
idx_mode = int(idx_top_mode[idx_batch, idx_focal])
plt.scatter(agent[idx_mode, :, 0], agent[idx_mode, :, 1], marker=".", s=100, c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), lw=10, zorder=1)
else:
for idx_agent in range(trajs.shape[1]):
agent = trajs[idx_batch, idx_agent]
idx_mode = int(idx_top_mode[idx_batch, idx_agent])
plt.scatter(agent[idx_mode, :, 0], agent[idx_mode, :, 1], marker=".", s=100, c=plt.cm.viridis(np.linspace(0, 1, n_step_future)), lw=10, zorder=1) # 80 for waymo
n_modes = pred_dict["waymo_scores"].shape[-1]
idx_mode_plot = None
if mode_setting == "top":
idx_mode_plot = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True)
elif mode_setting == "custom":
idx_mode_plot = pred_dict["waymo_scores"].argmax(dim=-1, keepdim=True)
idx_mode_plot = torch.full_like(idx_mode_plot, mode_idx)

for idx in range(n_modes):
if mode_setting in ["top", "custom"] and idx > 0:
continue
if idx_focal is not None:
agent = trajs[idx_batch, idx_focal]
idx_mode = (
idx
if mode_setting == "all"
else int(idx_mode_plot[idx_batch, idx_focal])
)
plt.scatter(
agent[idx_mode, :, 0],
agent[idx_mode, :, 1],
marker=".",
s=100,
c=plt.cm.viridis(np.linspace(0, 1, n_step_future)),
lw=10,
zorder=1,
)
else:
for idx_agent in range(trajs.shape[1]):
agent = trajs[idx_batch, idx_agent]
idx_mode = (
idx
if mode_setting == "all"
else int(idx_mode_plot[idx_batch, idx_agent])
)
plt.scatter(
agent[idx_mode, :, 0],
agent[idx_mode, :, 1],
marker=".",
s=100,
c=plt.cm.viridis(np.linspace(0, 1, n_step_future)),
lw=10,
zorder=1,
) # 80 for waymo

# Plot agents:
for idx, (agent_pos, agent_type, agent_yaw, agent_role, agent_spd) in enumerate(zip(
batch["agent/pos"][idx_batch, idx_t_now], batch["agent/type"][idx_batch], batch["agent/yaw_bbox"][idx_batch, idx_t_now], batch["agent/role"][idx_batch], batch["agent/spd"][idx_batch, idx_t_now]
)):
for idx, (agent_pos, agent_type, agent_yaw, agent_role, agent_spd) in enumerate(
zip(
batch["agent/pos"][idx_batch, idx_t_now],
batch["agent/type"][idx_batch],
batch["agent/yaw_bbox"][idx_batch, idx_t_now],
batch["agent/role"][idx_batch],
batch["agent/spd"][idx_batch, idx_t_now],
)
):
if agent_type[0]:
bbox = rotate_bbox_zaxis(car, float(agent_yaw))
bbox = shift_cuboid(float(agent_pos[0]), float(agent_pos[1]), bbox)
Expand Down Expand Up @@ -78,15 +138,15 @@ def plot_motion_forecasts(batch, pred_dict, n_step_future=60, idx_t_now=50, idx_

if save_path:
plt.savefig(save_path, dpi=150, pad_inches=0, bbox_inches="tight")

return fig


def tensor_dict_to_cpu(obj):
buffer = io.BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
obj_cpu = torch.load(buffer, map_location='cpu')
obj_cpu = torch.load(buffer, map_location="cpu")

return obj_cpu

Expand All @@ -98,18 +158,19 @@ def mplfig_to_npimage(fig):
"""
# only the Agg backend now supports the tostring_rgb function
from matplotlib.backends.backend_agg import FigureCanvasAgg

canvas = FigureCanvasAgg(fig)
canvas.draw() # update/draw the elements
canvas.draw() # update/draw the elements

# get the width and the height to resize the matrix
l,b,w,h = canvas.figure.bbox.bounds
l, b, w, h = canvas.figure.bbox.bounds
w, h = int(w), int(h)

# exports the canvas to a string buffer and then to a numpy nd.array
buf = canvas.tostring_rgb()
image = np.frombuffer(buf, dtype=np.uint8)

return image.reshape(h,w,3)
return image.reshape(h, w, 3)


def shift_cuboid(x_shift, y_shift, cuboid):
Expand Down Expand Up @@ -187,13 +248,13 @@ def add_cube(cube_definition, ax, color="b", edgecolor="k", alpha=0.2):
)

pedestrian = np.array(
[
(-0.3, -0.3, 0), # left bottom front
(-0.3, 0.3, 0), # left bottom back
(0.3, -0.3, 0), # right bottom front
(-0.3, -0.3, 2), # left top front -> height
]
)
[
(-0.3, -0.3, 0), # left bottom front
(-0.3, 0.3, 0), # left bottom back
(0.3, -0.3, 0), # right bottom front
(-0.3, -0.3, 2), # left top front -> height
]
)

cyclist = np.array(
[
Expand All @@ -202,4 +263,4 @@ def add_cube(cube_definition, ax, color="b", edgecolor="k", alpha=0.2):
(1, -0.3, 0), # right bottom front
(-1, -0.3, 2), # left top front -> height
]
)
)

0 comments on commit a9f0bc2

Please sign in to comment.