Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/probcomp/bayes3d
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkoklukas committed Oct 18, 2023
2 parents 5459d9a + e1a6d98 commit 6916f9e
Showing 1 changed file with 58 additions and 11 deletions.
69 changes: 58 additions & 11 deletions bayes3d/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
import graphviz
import distinctipy
import jax.numpy as jnp
import copy

RED = np.array([1.0, 0.0, 0.0])
GREEN = np.array([0.0, 1.0, 0.0])
BLUE = np.array([0.0, 0.0, 1.0])
BLACK = np.array([0.0, 0.0, 0.0])

def load_image_from_file(filename):
"""Load an image from a file."""
return Image.open(filename)

def make_gif_from_pil_images(images, filename):
"""Save a list of PIL images as a GIF.
Expand All @@ -31,11 +36,27 @@ def make_gif_from_pil_images(images, filename):
loop=0,
)

def load_image_from_file(filename):
"""Load an image from a file."""
return Image.open(filename)
def preprocess_for_viz(img):
depth_np = np.array(img)
depth_np[depth_np >= depth_np.max()] = np.inf
return depth_np

cmap = copy.copy(plt.get_cmap('turbo'))
cmap.set_bad(alpha=0)

saveargs = dict(bbox_inches='tight', pad_inches=0)

def get_depth_image(image, min=None, max=None, cmap=None):

def add_depth_image(ax, depth):
ax.imshow(preprocess_for_viz(depth),cmap=cmap)
ax.axis('off')

def add_rgb_image(ax, rgb):
ax.imshow(rgb)
ax.axis('off')


def viz_depth_image(depth):
"""Convert a depth image to a PIL image.
Args:
Expand All @@ -46,13 +67,39 @@ def get_depth_image(image, min=None, max=None, cmap=None):
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
if cmap is None:
cmap = plt.get_cmap('turbo')
if min is None:
min = np.min(image)
if max is None:
max = np.max(image)

fig = plt.figure()
ax = fig.add_subplot(1,1,1)
add_depth_image(ax, depth)
return fig

def viz_rgb_image(image):
"""Convert an RGB image to a PIL image.
Args:
image (np.ndarray): RGB image. Shape (H, W, 3).
max (float): Maximum value for colormap.
Returns:
PIL.Image: RGB image visualized as a PIL image.
"""
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
add_rgb_image(ax, image)
return fig

def pil_image_from_matplotlib(fig):
return Image.frombytes('RGBA', fig.canvas.get_width_height(),fig.canvas.buffer_rgba())

def get_depth_image(image, min=None, max=None):
"""Convert a depth image to a PIL image.
Args:
image (np.ndarray): Depth image. Shape (H, W).
min (float): Minimum depth value for colormap.
max (float): Maximum depth value for colormap.
cmap (matplotlib.colors.Colormap): Colormap to use.
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
depth = (image - min) / (max - min + 1e-10)
depth = np.clip(depth, 0, 1)

Expand Down

0 comments on commit 6916f9e

Please sign in to comment.