Skip to content

Commit

Permalink
add support for multiview in preview
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen committed Feb 20, 2025
1 parent 9a4506a commit 698e4be
Showing 1 changed file with 152 additions and 88 deletions.
240 changes: 152 additions & 88 deletions optimed/processes/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
import numpy as np
import vtk
from vtk.util import numpy_support
from typing import Union, Dict, Tuple, Any
from typing import Union, Dict, Tuple, Any, List
from optimed.wrappers.nifti import load_nifti, as_closest_canonical
from optimed.wrappers.operations import exists
import matplotlib.pyplot as plt
import math

vtk.vtkObject.GlobalWarningDisplayOff() # Disable VTK warnings

def _numpy_to_vtk_image_data(np_data: np.ndarray, spacing=(1, 1, 1)) -> vtk.vtkImageData:
"""
Converts a NumPy array (in RAS orientation, axis order (X, Y, Z))
Converts a NumPy array (in RAS orientation, axes (X, Y, Z))
to a vtkImageData object.
Parameters:
np_data (np.ndarray): The input NumPy array.
spacing (Tuple[float, float, float]): The voxel spacing.
Returns:
vtk.vtkImageData: The converted vtkImageData object.
"""
x, y, z = np_data.shape
vtk_image = vtk.vtkImageData()
Expand All @@ -42,27 +38,28 @@ def _render_segmentation_to_image(
view_direction: str = 'A',
background_color: Tuple[int, int, int] = (0, 0, 0),
window_size: Tuple[int, int] = (800, 800),
output_filename: str = "segmentation_preview.png"
) -> None:
output_filename: Union[str, None] = "segmentation_preview.png",
return_image: bool = False
) -> Union[None, np.ndarray]:
"""
Renders a 3D segmentation and saves it as a PNG file.
If a label contains "text": {"label": ..., "color": ..., "size": ...},
a caption is created whose bounding box is correctly centered relative
to the base position determined by view_direction.
Renders a 3D segmentation and either saves the result to a PNG file or returns
the image as a numpy array. If a text is provided for a label, a caption is created
with proper positioning relative to view_direction.
Parameters:
vtk_image (vtk.vtkImageData): The input vtkImageData object.
segmentation_dict (Dict[int, Dict[str, Any]]): A dictionary where keys are label values
and values are dictionaries with properties for each label.
smoothing (int): The number of smoothing iterations.
shading (int): The shading factor.
view_direction (str): The view direction (R, L, A, P, S, I).
background_color (Tuple[int, int, int]): The background color (RGB).
window_size (Tuple[int, int]): The window size.
output_filename (str): The output PNG file name.
vtk_image (vtk.vtkImageData): Input vtkImageData object.
segmentation_dict (Dict[int, Dict[str, Any]]): A dictionary with label values as keys
and corresponding properties as values.
smoothing (int): Number of smoothing iterations.
shading (int): Shading factor.
view_direction (str): View direction (R, L, A, P, S, I).
background_color (Tuple[int, int, int]): Background color (RGB).
window_size (Tuple[int, int]): Window size.
output_filename (str or None): Output PNG filename.
return_image (bool): If True, returns the image as a numpy array; otherwise saves the file.
Returns:
None
None or np.ndarray: Returns the image as a numpy array if return_image is True.
"""
renderer = vtk.vtkRenderer()
renderWindow = vtk.vtkRenderWindow()
Expand All @@ -72,7 +69,7 @@ def _render_segmentation_to_image(

camera = renderer.GetActiveCamera()

# Create actors for each label
# Creating actors for each label
for label_value, props in segmentation_dict.items():
# Create the label surface using marching cubes
contour = vtk.vtkDiscreteMarchingCubes()
Expand Down Expand Up @@ -118,14 +115,14 @@ def _render_segmentation_to_image(

renderer.AddActor(actor)

# If text is specified for the label, create a caption
# If a text is provided for the label, create a caption
text_info = props.get("text", None)
if text_info is not None:
text_label = text_info.get("label", "")
text_color = text_info.get("color", (0, 0, 0))
text_size = text_info.get("size", 12) # font size

# 1) Determine the "base" position for the text based on object bounds
# 1) Determine the base position for the text based on the actor bounds
a_bounds = actor.GetBounds() # (xmin, xmax, ymin, ymax, zmin, zmax)
a_center = (
(a_bounds[0] + a_bounds[1]) / 2.0,
Expand Down Expand Up @@ -156,19 +153,19 @@ def _render_segmentation_to_image(
else:
base_pos = a_center

# 2) Create a text source and update its geometry
# 2) Create the text source and update its geometry
text_source = vtk.vtkVectorText()
text_source.SetText(text_label)
text_source.Update()

# 3) Obtain the bounding box of the text (in local coordinates)
# 3) Get the bounding box of the text (in local coordinates)
t_bounds = text_source.GetOutput().GetBounds()
# Compute the center of the text
# Calculate the text center
tcx = (t_bounds[0] + t_bounds[1]) / 2.0
tcy = (t_bounds[2] + t_bounds[3]) / 2.0
tcz = (t_bounds[4] + t_bounds[5]) / 2.0

# 4) Create a vtkFollower, set its origin, position, scale, and color
# 4) Create a vtkFollower and set its properties
text_mapper = vtk.vtkPolyDataMapper()
text_mapper.SetInputConnection(text_source.GetOutputPort())

Expand All @@ -178,10 +175,7 @@ def _render_segmentation_to_image(
text_color = tuple(c / 255.0 for c in text_color)
text_actor.GetProperty().SetColor(*text_color)

# Scale the text (this will multiply the text coordinates)
text_actor.SetScale(text_size, text_size, text_size)
# Set the text origin equal to its center so that when setting the position the center
# matches the base position
text_actor.SetOrigin(tcx, tcy, tcz)
text_actor.SetPosition(base_pos)
text_actor.SetCamera(camera)
Expand All @@ -196,7 +190,7 @@ def _render_segmentation_to_image(
# Render the scene
renderWindow.Render()

# Adjust the camera based on the overall bounding box of all objects
# Adjust the camera based on the overall object bounds
renderer.ResetCamera()
prop_bounds = [0, 0, 0, 0, 0, 0]
renderer.ComputeVisiblePropBounds(prop_bounds)
Expand All @@ -211,7 +205,7 @@ def _render_segmentation_to_image(
dz = zmax - zmin
L = max(dx, dy, dz)

fov = camera.GetViewAngle() # Typically ~30°
fov = camera.GetViewAngle()
margin_factor = 1.2
required_distance = L / (2 * np.tan(np.deg2rad(fov / 2))) * margin_factor

Expand Down Expand Up @@ -239,106 +233,163 @@ def _render_segmentation_to_image(

renderWindow.Render()

# Save the result as a PNG file
# Capture the image from the render window
w2i = vtk.vtkWindowToImageFilter()
w2i.SetInput(renderWindow)
w2i.Update()
writer = vtk.vtkPNGWriter()
writer.SetFileName(output_filename)
writer.SetInputConnection(w2i.GetOutputPort())
writer.Write()
print(f"Image saved: {output_filename}")

if return_image:
image_data = w2i.GetOutput()
dims = image_data.GetDimensions() # (width, height, 1)
num_components = image_data.GetNumberOfScalarComponents()
vtk_array = numpy_support.vtk_to_numpy(image_data.GetPointData().GetScalars())
# Reshape array to (height, width, num_components)
image = vtk_array.reshape(dims[1], dims[0], num_components)
return image
else:
writer = vtk.vtkPNGWriter()
writer.SetFileName(output_filename)
writer.SetInputConnection(w2i.GetOutputPort())
writer.Write()
return None


def preview_3d_image(
input_path: Union[str, nib.Nifti1Image],
output: str,
segmentation_dict: Dict[int, Dict[str, Any]],
view_direction: str = 'A',
view_direction: Union[str, List[str]] = 'A',
smoothing: int = 20,
shading: int = 20,
background_color: Tuple[int, int, int] = (0, 0, 0),
window_size: Tuple[int, int] = (800, 800)
) -> None:
"""
Loads a NIfTI file, converts it to RAS, creates a 3D segmentation, and saves a PNG
with the specified orientation. If a label contains text: { 'label': str, 'color': (r, g, b), 'size': int },
a caption is generated with its center calculated without intermediate rendering.
Loads a NIfTI file, converts it to a canonical orientation, creates a 3D segmentation
and saves a PNG file for the specified view direction. If a text is provided for a label
(e.g. { 'label': str, 'color': (r, g, b), 'size': int }), a caption is generated.
If view_direction is passed as a list (e.g. ['A', 'I']), the final image is saved as subplots
(maximum 2 images per row).
Parameters:
input_path (Union[str, nib.Nifti1Image]): The input NIfTI file path or Nifti1Image object.
output (str): The output PNG file name.
segmentation_dict (Dict[int, Dict[str, Any]]): A dictionary where keys are label values
and values are dictionaries with properties for each label.
view_direction (str): The view direction (R, L, A, P, S, I).
smoothing (int): The number of smoothing iterations.
shading (int): The shading factor.
background_color (Tuple[int, int, int]): The background color (RGB).
window_size (Tuple[int, int]): The window size.
input_path (Union[str, nib.Nifti1Image]): Path to the NIfTI file or a Nifti1Image object.
output (str): Output PNG filename.
segmentation_dict (Dict[int, Dict[str, Any]]): Dictionary with parameters for each label.
view_direction (str or List[str]): View direction (R, L, A, P, S, I) or list of directions.
smoothing (int): Number of smoothing iterations.
shading (int): Shading factor.
background_color (Tuple[int, int, int]): Background color (RGB).
window_size (Tuple[int, int]): Window size.
Returns:
None
"""
# Validate input parameters
# Input validation
assert exists(input_path), "File not found."
assert segmentation_dict and len(segmentation_dict) > 0, (
"A non-empty segmentation dictionary is required.\n"
"Example: {1: {'color': (255, 255, 0), 'opacity': 0.4, 'text': {'label': 'aorta', 'color': (0, 0, 0), 'size': 5}}, "
"2: {'color': (255, 0, 255), 'opacity': 1.0}}"
)
for lbl, props in segmentation_dict.items():
assert 'color' in props, f"'color' key is missing for label {lbl}."
assert 'opacity' in props, f"'opacity' key is missing for label {lbl}."
assert 'color' in props, f"'color' is missing for label {lbl}."
assert 'opacity' in props, f"'opacity' is missing for label {lbl}."
if 'text' in props:
assert 'label' in props['text'], f"'label' key is missing in 'text' for label {lbl}."
assert 'color' in props['text'], f"'color' key is missing in 'text' for label {lbl}."
assert 'size' in props['text'], f"'size' key is missing in 'text' for label {lbl}."
assert 'label' in props['text'], f"'label' is missing in 'text' for label {lbl}."
assert 'color' in props['text'], f"'color' is missing in 'text' for label {lbl}."
assert 'size' in props['text'], f"'size' is missing in 'text' for label {lbl}."
if smoothing:
assert isinstance(smoothing, int), "Smoothing iterations must be an integer."
assert smoothing >= 0, "Smoothing iterations must be >= 0."
if shading:
assert isinstance(shading, int), "Shading value must be an integer."
assert shading >= 0, "Shading value must be >= 0."
assert view_direction in ['R', 'L', 'A', 'P', 'S', 'I'], (
"Invalid view_direction. Allowed values: R, L, A, P, S, I."
)
assert isinstance(shading, int), "Shading factor must be an integer."
assert shading >= 0, "Shading factor must be >= 0."
if isinstance(view_direction, list):
for vd in view_direction:
assert vd in ['R', 'L', 'A', 'P', 'S', 'I'], (
f"Invalid view_direction value: {vd}. Allowed values: R, L, A, P, S, I."
)
else:
assert view_direction in ['R', 'L', 'A', 'P', 'S', 'I'], (
"Invalid view_direction value. Allowed values: R, L, A, P, S, I."
)
assert background_color and len(background_color) == 3, "background_color must be an RGB tuple."
for color_channel in background_color:
assert 0 <= color_channel <= 255, "Each color channel must be between 0 and 255."
for window_dim in window_size:
assert window_dim > 0, "Window dimensions must be positive numbers."
assert output.endswith('.png'), "Output filename must end with .png."

# Load the NIfTI image (use canonical orientation)
# Load the NIfTI image (convert to canonical orientation)
if isinstance(input_path, str):
img = load_nifti(input_path, canonical=True, engine='nibabel')
elif isinstance(input_path, nib.Nifti1Image):
img = as_closest_canonical(input_path)
else:
raise ValueError("input_path must be a string (file path) or Nifti1Image.")
raise ValueError("input_path must be a file path string or a Nifti1Image.")

data = img.get_fdata()
spacing = img.header.get_zooms()
vtk_image = _numpy_to_vtk_image_data(data, spacing=spacing)

_render_segmentation_to_image(
vtk_image=vtk_image,
segmentation_dict=segmentation_dict,
smoothing=smoothing,
shading=shading,
view_direction=view_direction,
background_color=background_color,
window_size=window_size,
output_filename=output
)
vtk_img = _numpy_to_vtk_image_data(data, spacing=spacing)

# If view_direction is a list, generate an image for each direction and combine them in subplots
if isinstance(view_direction, list):
images = []
for vd in view_direction:
img_arr = _render_segmentation_to_image(
vtk_image=vtk_img,
segmentation_dict=segmentation_dict,
smoothing=smoothing,
shading=shading,
view_direction=vd,
background_color=background_color,
window_size=window_size,
output_filename=None,
return_image=True
)
images.append(img_arr)

n_images = len(images)
ncols = min(2, n_images)
nrows = math.ceil(n_images / ncols)

# Choose figure size relative to window_size
fig, axes = plt.subplots(nrows, ncols, figsize=(window_size[0]/100, window_size[1]/100))
# Flatten axes array if necessary
if n_images > 1:
axes = np.atleast_1d(axes).flatten()
else:
axes = [axes]
for ax, img in zip(axes, images):
ax.imshow(np.flipud(img))
ax.axis('off')
# Hide unused subplots
for ax in axes[len(images):]:
ax.axis('off')
plt.tight_layout()
plt.savefig(output)
plt.close()
else:
_render_segmentation_to_image(
vtk_image=vtk_img,
segmentation_dict=segmentation_dict,
smoothing=smoothing,
shading=shading,
view_direction=view_direction,
background_color=background_color,
window_size=window_size,
output_filename=output,
return_image=False
)


# ------------------ USAGE EXAMPLE -------------------
if __name__ == "__main__":
nifti_path = "/Users/eolika/Downloads/dataset_aorta/botkin_0127/aorta.nii.gz"
segmentation_dict_example = {
1: {'color': (255, 255, 0), 'opacity': 0.4, 'text': {'label': 'Aorta', 'color': (0, 0, 0), 'size': 5}},
2: {'color': (255, 0, 255), 'opacity': 1.0, 'text': {'label': 'Label2','color': (0, 0, 0), 'size': 5}},
2: {'color': (255, 0, 255), 'opacity': 1.0, 'text': {'label': 'Label2', 'color': (0, 0, 0), 'size': 5}},
3: {'color': (0, 255, 0), 'opacity': 1.0},
4: {'color': (0, 255, 255), 'opacity': 1.0},
5: {'color': (255, 0, 0), 'opacity': 1.0},
Expand All @@ -349,11 +400,24 @@ def preview_3d_image(
10: {'color': (255, 0, 255), 'opacity': 1.0},
11: {'color': (0, 255, 0), 'opacity': 1.0},
}
# Example with a single view direction (normal mode)
preview_3d_image(
input_path=nifti_path,
output="segmentation_preview_single.png",
segmentation_dict=segmentation_dict_example,
view_direction='A', # Allowed values: R, L, A, P, S, I
background_color=(255, 255, 255),
smoothing=20,
shading=20,
window_size=(1200, 1200)
)

# Example with multiple view directions (subplots)
preview_3d_image(
input_path=nifti_path,
output="segmentation_preview.png",
output="segmentation_preview_multi.png",
segmentation_dict=segmentation_dict_example,
view_direction='A', # Possible values: R, L, A, P, S, I
view_direction=['A', 'I', 'R'], # List of directions
background_color=(255, 255, 255),
smoothing=20,
shading=20,
Expand Down

0 comments on commit 698e4be

Please sign in to comment.