Skip to content

Commit

Permalink
demo images
Browse files Browse the repository at this point in the history
  • Loading branch information
Nishad Gothoskar committed Oct 18, 2023
1 parent 6916f9e commit 10bea99
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 45 deletions.
89 changes: 47 additions & 42 deletions bayes3d/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,53 @@ def preprocess_for_viz(img):
return depth_np

cmap = copy.copy(plt.get_cmap('turbo'))
cmap.set_bad(alpha=0)
cmap.set_bad(color=(1.0, 1.0, 1.0, 1.0))

def get_depth_image(image):
"""Convert a depth image to a PIL image.
Args:
image (np.ndarray): Depth image. Shape (H, W).
min (float): Minimum depth value for colormap.
max (float): Maximum depth value for colormap.
cmap (matplotlib.colors.Colormap): Colormap to use.
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
depth = np.array(image)
mask = depth < depth.max()
depth[np.logical_not(mask)] = np.nan
vmin = depth[mask].min()
vmax = depth[mask].max()
depth = (depth - vmin) / (vmax - vmin)

img = Image.fromarray(
np.rint(cmap(depth) * 255.0).astype(np.int8), mode="RGBA"
)
return img

def get_rgb_image(image, max=255.0):
"""Convert an RGB image to a PIL image.
Args:
image (np.ndarray): RGB image. Shape (H, W, 3).
max (float): Maximum value for colormap.
Returns:
PIL.Image: RGB image visualized as a PIL image.
"""
image = np.clip(image, 0.0, max)
if image.shape[-1] == 3:
image_type = "RGB"
else:
image_type = "RGBA"

img = Image.fromarray(
np.rint(
image / max * 255.0
).astype(np.int8),
mode=image_type,
).convert("RGB")
return img

saveargs = dict(bbox_inches='tight', pad_inches=0)

Expand Down Expand Up @@ -89,47 +135,6 @@ def viz_rgb_image(image):
def pil_image_from_matplotlib(fig):
return Image.frombytes('RGBA', fig.canvas.get_width_height(),fig.canvas.buffer_rgba())

def get_depth_image(image, min=None, max=None):
"""Convert a depth image to a PIL image.
Args:
image (np.ndarray): Depth image. Shape (H, W).
min (float): Minimum depth value for colormap.
max (float): Maximum depth value for colormap.
cmap (matplotlib.colors.Colormap): Colormap to use.
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
depth = (image - min) / (max - min + 1e-10)
depth = np.clip(depth, 0, 1)

img = Image.fromarray(
np.rint(cmap(depth) * 255.0).astype(np.int8), mode="RGBA"
)
return img.convert("RGB")

def get_rgb_image(image, max=255.0):
"""Convert an RGB image to a PIL image.
Args:
image (np.ndarray): RGB image. Shape (H, W, 3).
max (float): Maximum value for colormap.
Returns:
PIL.Image: RGB image visualized as a PIL image.
"""
image = np.clip(image, 0.0, max)
if image.shape[-1] == 3:
image_type = "RGB"
else:
image_type = "RGBA"

img = Image.fromarray(
np.rint(
image / max * 255.0
).astype(np.int8),
mode=image_type,
).convert("RGB")
return img

def add_rgba_dimension(image):
"""Add an alpha channel to a particle image if it doesn't already have one.
Expand Down
12 changes: 9 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,21 @@ def update_pose_estimate(pose_estimate, gt_image):
print ("Time elapsed:", end - start)
print ("FPS:", poses.shape[0] / (end - start))


max_depth = 10.0
rerendered_images = b.RENDERER.render_many(pose_estimates_over_time[:, None, ...], jnp.array([0]))

viz_images = [
b.viz.multi_panel(
[b.viz.get_depth_image(d[:,:,2]), b.viz.get_depth_image(r[:,:,2])],
[
b.viz.scale_image(b.viz.get_depth_image(d[:,:,2]), 3),
b.viz.scale_image(b.viz.get_depth_image(r[:,:,2]), 3)
],
labels=["Observed", "Rerendered"],
label_fontsize=20
)
for (r, d) in zip(rerendered_images, observed_images)
]
b.make_gif_from_pil_images(viz_images, "assets/demo.gif")



from IPython import embed; embed()

0 comments on commit 10bea99

Please sign in to comment.