Skip to content

Commit

Permalink
ENH: Add gradient plot method
Browse files Browse the repository at this point in the history
Add gradient plot method.
  • Loading branch information
jhlegarreta committed Apr 30, 2024
1 parent 0d893ff commit b8038ba
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 0 deletions.
205 changes: 205 additions & 0 deletions nireports/reportlets/modality/dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
"""Visualizations for diffusion MRI data."""
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
from mpl_toolkits.mplot3d import art3d


def plot_heatmap(
Expand Down Expand Up @@ -124,3 +126,206 @@ def plot_heatmap(
fig.tight_layout(rect=[0.02, 0, 1, 1])

return fig



def rotation_matrix(u, v):
r"""Calculate the rotation matrix *R* such that :math:`R \cdot \mathbf{u} = \mathbf{v}`.
Extracted from `Emmanuel Caruyer's code
<https://github.com/ecaruyer/qspace/blob/master/qspace/visu/visu_points.py>`__,
which is distributed under the revised BSD License:
Copyright (c) 2013-2015, Emmanuel Caruyer
All rights reserved.
.. admonition :: List of changes
Only minimal updates to leverage Numpy.
Parameters
----------
u : :obj:`numpy.ndarray`
A vector.
v : :obj:`numpy.ndarray`
A vector.
Returns
-------
R : :obj:`numpy.ndarray`
The rotation matrix.
"""

# the axis is given by the product u x v
u = u / np.linalg.norm(u)
v = v / np.linalg.norm(v)
w = np.asarray(
[
u[1] * v[2] - u[2] * v[1],
u[2] * v[0] - u[0] * v[2],
u[0] * v[1] - u[1] * v[0],
]
)
if (w ** 2).sum() < (np.finfo(w.dtype).eps * 10):
# The vectors u and v are collinear
return np.eye(3)

# Compute sine and cosine
c = u @ v
s = np.linalg.norm(w)

w = w / s
P = np.outer(w, w)
Q = np.asarray([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
R = P + c * (np.eye(3) - P) + s * Q
return R


def draw_circles(positions, radius, n_samples=20):
r"""Draw circular patches (lying on a sphere) at given positions.
Adapted from `Emmanuel Caruyer's code
<https://github.com/ecaruyer/qspace/blob/master/qspace/visu/visu_points.py>`__,
which is distributed under the revised BSD License:
Copyright (c) 2013-2015, Emmanuel Caruyer
All rights reserved.
.. admonition :: List of changes
Modified to take the full list of normalized bvecs and corresponding circle
radii instead of taking the list of bvecs and radii for a specific shell
(*b*-value).
Parameters
----------
positions : :obj:`numpy.ndarray`
An array :math:`N \times 3` of 3D cartesian positions.
radius : :obj:`float`
The reference radius (or, the radius in single-shell plots)
n_samples : :obj:`int`
The number of samples on the sphere.
Returns
-------
circles : :obj:`numpy.ndarray`
Circular patches.
"""

# A circle centered at [1, 0, 0] with radius r
t = np.linspace(0, 2 * np.pi, n_samples)

nb_points = positions.shape[0]
circles = np.zeros((nb_points, n_samples, 3))
for i in range(positions.shape[0]):
circle_x = np.zeros((n_samples, 3))
dots_radius = np.sqrt(radius[i]) * 0.04
circle_x[:, 1] = dots_radius * np.cos(t)
circle_x[:, 2] = dots_radius * np.sin(t)
norm = np.linalg.norm(positions[i])
point = positions[i] / norm
r1 = rotation_matrix(np.asarray([1, 0, 0]), point)
circles[i] = positions[i] + np.dot(r1, circle_x.T).T
return circles


def draw_points(gradients, ax, rad_min=0.3, rad_max=0.7, colormap="viridis"):
"""Draw the vectors on a shell.
Adapted from `Emmanuel Caruyer's code
<https://github.com/ecaruyer/qspace/blob/master/qspace/visu/visu_points.py>`__,
which is distributed under the revised BSD License:
Copyright (c) 2013-2015, Emmanuel Caruyer
All rights reserved.
.. admonition :: List of changes
* The input is a single 2D numpy array of the gradient table in RAS+B format
* The scaling of the circle radius for each bvec proportional to the inverse of
the bvals. A minimum/maximal value for the radii can be specified.
* Circles for each bvec are drawn at once instead of looping over the shells.
* Some variables have been renamed (like vects to bvecs)
Parameters
----------
gradients : :obj:`numpy.ndarray`
An (N, 4) shaped array of the gradient table in RAS+B format.
ax : :obj:`matplotlib.axes.Axis`
The matplotlib axes instance to plot in.
rad_min : :obj:`float` between 0 and 1
Minimum radius of the circle that renders a gradient direction.
rad_max : :obj:`float` between 0 and 1
Maximum radius of the circle that renders a gradient direction.
colormap : :obj:`matplotlib.pyplot.cm.ColorMap`
matplotlib colormap name.
"""

# Initialize 3D view
elev = 90
azim = 0
ax.view_init(azim=azim, elev=elev)

# Normalize to 1 the highest bvalue
bvals = np.copy(gradients[3, :])
bvals = bvals / bvals.max()

# Colormap depending on bvalue (for visualization)
cmap = cm.get_cmap(colormap)
colors = cmap(bvals)

# Relative shell radii proportional to the inverse of bvalue (for visualization)
rs = np.reciprocal(bvals)
rs = rs / rs.max()

# Readjust radius of the circle given the minimum and maximal allowed values.
rs = rs - rs.min()
rs = rs / (rs.max() - rs.min())
rs = rs * (rad_max - rad_min) + rad_min

bvecs = np.copy(
gradients[:3, :].T,
)
bvecs[bvecs[:, 2] < 0] *= -1

# Render all gradient direction of all b-values
circles = draw_circles(bvecs, rs)
ax.add_collection(art3d.Poly3DCollection(circles, facecolors=colors, linewidth=0))

max_val = 0.6
ax.set_xlim(-max_val, max_val)
ax.set_ylim(-max_val, max_val)
ax.set_zlim(-max_val, max_val)
ax.axis("off")


def plot_gradients(
gradients,
title=None,
ax=None,
spacing=0.05,
**kwargs,
):
"""Draw the vectors on a unit sphere with color code for multiple b-value.
Parameters
----------
gradients : :obj:`numpy.ndarray`
An (N, 4) shaped array of the gradient table in RAS+B format.
title : :obj:`str`
Plot title.
ax : :obj:`matplotlib.axes.Axis`
A figure's axis to plot on.
spacing : :obj:`float`
Plot spacing.
kwargs : :obj:`dict`
Extra args given to :obj:`eddymotion.viz.draw_points()`.
Returns
-------
ax : :obj:`matplotlib.axes.Axis`
The figure's axis where the data is plot.
"""

# Initialize figure
if ax is None:
figsize = kwargs.pop("figsize", (9.0, 9.0))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection="3d")
plt.subplots_adjust(bottom=spacing, top=1 - spacing, wspace=2 * spacing)

# Draw points after re-projecting all shells to the unit sphere
draw_points(gradients, ax, **kwargs)

if title:
plt.suptitle(title)

return ax
45 changes: 45 additions & 0 deletions nireports/tests/test_dwi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2023 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
from pathlib import Path

import numpy as np
from matplotlib import pyplot as plt
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table

from nireports.reportlets.modality.dwi import plot_gradients


def test_plot_gradients(tmp_path):

fbval = "./mriqc/data/testdata/hcp_bvals"
fbvec = "./mriqc/data/testdata/hcp_bvecs"
_bvals, _bvecs = read_bvals_bvecs(fbval, fbvec)
gtab = gradient_table(_bvals, _bvecs)
bvecs = gtab.bvecs[~gtab.b0s_mask]
bvals = gtab.bvals[~gtab.b0s_mask]

gradients = np.vstack([bvecs.T, bvals])
_ = plot_gradients(gradients)

plt.savefig(Path(tmp_path) / "gradients.png")

0 comments on commit b8038ba

Please sign in to comment.