Skip to content

Commit

Permalink
Merge pull request #1088 from elyz081/fix_issue_922
Browse files Browse the repository at this point in the history
Move plot residuals into viz - plot
  • Loading branch information
arnaudbore authored Dec 16, 2024
2 parents b3d22c8 + 68c11e2 commit f81231d
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 149 deletions.
77 changes: 0 additions & 77 deletions scilpy/utils/metrics_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
Expand Down Expand Up @@ -302,78 +300,3 @@ def get_bundle_metrics_mean_std_per_point(streamlines, bundle_name,
label_stats['mean'] = float(label_mean)
label_stats['std'] = float(label_std)
return stats


def plot_metrics_stats(means, stds, title=None, xlabel=None,
ylabel=None, figlabel=None, fill_color=None,
display_means=False):
"""
Plots the mean of a metric along n points with the standard deviation.
Parameters
----------
means: Numpy 1D (or 2D) array of size n
Mean of the metric along n points.
stds: Numpy 1D (or 2D) array of size n
Standard deviation of the metric along n points.
title: string
Title of the figure.
xlabel: string
Label of the X axis.
ylabel: string
Label of the Y axis (suggestion: the metric name).
figlabel: string
Label of the figure (only metadata in the figure object returned).
fill_color: string
Hexadecimal RGB color filling the region between mean ± std. The
hexadecimal RGB color should be formatted as #RRGGBB
display_means: bool
Display the subjects means as semi-transparent line
Return
------
The figure object.
"""
matplotlib.style.use('ggplot')

fig, ax = plt.subplots()

# Set optional information to the figure, if required.
if title is not None:
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if figlabel is not None:
fig.set_label(figlabel)

if means.ndim > 1:
mean = np.average(means, axis=1)
std = np.average(stds, axis=1)
alpha = 0.5
else:
mean = np.array(means).ravel()
std = np.array(stds).ravel()
alpha = 0.9

dim = np.arange(1, len(mean)+1, 1)

if len(mean) <= 20:
ax.xaxis.set_ticks(dim)

ax.set_xlim(0, len(mean)+1)

if means.ndim > 1 and display_means:
for i in range(means.shape[-1]):
ax.plot(dim, means[:, i], color="k", linewidth=1,
solid_capstyle='round', alpha=0.1)

# Plot the mean line.
ax.plot(dim, mean, color="k", linewidth=5, solid_capstyle='round')

# Plot the std
plt.fill_between(dim, mean - std, mean + std,
facecolor=fill_color, alpha=alpha)

plt.close(fig)
return fig
173 changes: 173 additions & 0 deletions scilpy/viz/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def plot_metrics_stats(means, stds, title=None, xlabel=None,
ylabel=None, figlabel=None, fill_color=None,
display_means=False):
"""
Plots the mean of a metric along n points with the standard deviation.
Parameters
----------
means: Numpy 1D (or 2D) array of size n
Mean of the metric along n points.
stds: Numpy 1D (or 2D) array of size n
Standard deviation of the metric along n points.
title: string
Title of the figure.
xlabel: string
Label of the X axis.
ylabel: string
Label of the Y axis (suggestion: the metric name).
figlabel: string
Label of the figure (only metadata in the figure object returned).
fill_color: string
Hexadecimal RGB color filling the region between mean ± std. The
hexadecimal RGB color should be formatted as #RRGGBB
display_means: bool
Display the subjects means as semi-transparent line
Return
------
The figure object.
"""
matplotlib.style.use('ggplot')

fig, ax = plt.subplots()

# Set optional information to the figure, if required.
if title is not None:
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if figlabel is not None:
fig.set_label(figlabel)

if means.ndim > 1:
mean = np.average(means, axis=1)
std = np.average(stds, axis=1)
alpha = 0.5
else:
mean = np.array(means).ravel()
std = np.array(stds).ravel()
alpha = 0.9

dim = np.arange(1, len(mean)+1, 1)

if len(mean) <= 20:
ax.xaxis.set_ticks(dim)

ax.set_xlim(0, len(mean)+1)

if means.ndim > 1 and display_means:
for i in range(means.shape[-1]):
ax.plot(dim, means[:, i], color="k", linewidth=1,
solid_capstyle='round', alpha=0.1)

# Plot the mean line.
ax.plot(dim, mean, color="k", linewidth=5, solid_capstyle='round')

# Plot the std
plt.fill_between(dim, mean - std, mean + std,
facecolor=fill_color, alpha=alpha)

plt.close(fig)
return fig


def plot_residuals(data_diff, mask, R_k, q1, q3, iqr, residual_basename):
"""
Plots residual statistics for DWI.
Parameters
----------
data_diff: np.ndarray
The 4D residuals between the DWI and predicted data.
mask : Numpy 3D array or None
Mask array indicating the region of interest for computing residuals.
If None, residuals are computed for the entire dataset.
R_k : Numpy 1D array
Mean residual values for each DWI volume.
q1 : Numpy 1D array
First quartile values for each DWI volume.
q3 : Numpy 1D array
Third quartile values for each DWI volume.
iqr : Numpy 1D array
Interquartile range (Q3 - Q1) for each DWI volume.
residual_basename : string
Basename for saving the output plot file. The file will be saved as
'<residual_basename>_residuals_stats.png'.
Returns
-------
None
The function generates a plot and saves it as a PNG file.
"""
# Showing results in graph
# Note that stats will be computed manually and plotted using bxp
# but could be computed using stats = cbook.boxplot_stats
# or pyplot.boxplot(x)

# Initializing stats as a List[dict]
stats = [dict.fromkeys(['label', 'mean', 'iqr', 'cilo', 'cihi',
'whishi', 'whislo', 'fliers', 'q1',
'med', 'q3'], [])
for _ in range(data_diff.shape[-1])]

nb_voxels = np.count_nonzero(mask)
percent_outliers = np.zeros(data_diff.shape[-1], dtype=np.float32)
for k in range(data_diff.shape[-1]):
stats[k]['med'] = (q1[k] + q3[k]) / 2
stats[k]['mean'] = R_k[k]
stats[k]['q1'] = q1[k]
stats[k]['q3'] = q3[k]
stats[k]['whislo'] = q1[k] - 1.5 * iqr[k]
stats[k]['whishi'] = q3[k] + 1.5 * iqr[k]
stats[k]['label'] = k

# Outliers are observations that fall below Q1 - 1.5(IQR) or
# above Q3 + 1.5(IQR) We check if a voxel is an outlier only if
# we have a mask, else we are biased.
if mask is not None:
x = data_diff[..., k]
outliers = (x < stats[k]['whislo']) | (x > stats[k]['whishi'])
percent_outliers[k] = np.sum(outliers) / nb_voxels * 100
# What would be our definition of too many outliers?
# Maybe mean(all_means)+-3SD?
# Or we let people choose based on the figure.
# if percent_outliers[k] > ???? :
# logger.warning(' Careful! Diffusion-Weighted Image'
# ' i=%s has %s %% outlier voxels',
# k, percent_outliers[k])

if mask is None:
fig, axe = plt.subplots(nrows=1, ncols=1, squeeze=False)
else:
fig, axe = plt.subplots(nrows=1, ncols=2, squeeze=False,
figsize=[10, 4.8])
# Default is [6.4, 4.8]. Increasing width to see better.

medianprops = dict(linestyle='-', linewidth=2.5, color='firebrick')
meanprops = dict(linestyle='-', linewidth=2.5, color='green')
axe[0, 0].bxp(stats, showmeans=True, meanline=True, showfliers=False,
medianprops=medianprops, meanprops=meanprops)
axe[0, 0].set_xlabel('DW image')
axe[0, 0].set_ylabel('Residuals per DWI volume. Red is median,\n'
'green is mean. Whiskers are 1.5*interquartile')
axe[0, 0].set_title('Residuals')
axe[0, 0].set_xticks(range(0, q1.shape[0], 5))
axe[0, 0].set_xticklabels(range(0, q1.shape[0], 5))

if mask is not None:
axe[0, 1].plot(range(data_diff.shape[-1]), percent_outliers)
axe[0, 1].set_xticks(range(0, q1.shape[0], 5))
axe[0, 1].set_xticklabels(range(0, q1.shape[0], 5))
axe[0, 1].set_xlabel('DW image')
axe[0, 1].set_ylabel('Percentage of outlier voxels')
axe[0, 1].set_title('Outliers')
plt.savefig(residual_basename + '_residuals_stats.png')
76 changes: 5 additions & 71 deletions scripts/scil_dti_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import argparse
import logging

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

Expand All @@ -54,6 +53,7 @@
is_normalized_bvecs,
normalize_bvecs)
from scilpy.utils.filenames import add_filename_suffix, split_name_with_nii
from scilpy.viz.plot import plot_residuals

logger = logging.getLogger("DTI_Metrics")
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -149,75 +149,6 @@ def _build_arg_parser():
return p


def _plot_residuals(args, data_diff, mask, R_k, q1, q3, iqr, residual_basename):
# Showing results in graph
# Note that stats will be computed manually and plotted using bxp
# but could be computed using stats = cbook.boxplot_stats
# or pyplot.boxplot(x)
if mask is None:
logging.info("Outlier detection will not be performed, since no "
"mask was provided.")

# Initializing stats as a List[dict]
stats = [dict.fromkeys(['label', 'mean', 'iqr', 'cilo', 'cihi',
'whishi', 'whislo', 'fliers', 'q1',
'med', 'q3'], [])
for _ in range(data_diff.shape[-1])]

nb_voxels = np.count_nonzero(mask)
percent_outliers = np.zeros(data_diff.shape[-1], dtype=np.float32)
for k in range(data_diff.shape[-1]):
stats[k]['med'] = (q1[k] + q3[k]) / 2
stats[k]['mean'] = R_k[k]
stats[k]['q1'] = q1[k]
stats[k]['q3'] = q3[k]
stats[k]['whislo'] = q1[k] - 1.5 * iqr[k]
stats[k]['whishi'] = q3[k] + 1.5 * iqr[k]
stats[k]['label'] = k

# Outliers are observations that fall below Q1 - 1.5(IQR) or
# above Q3 + 1.5(IQR) We check if a voxel is an outlier only if
# we have a mask, else we are biased.
if args.mask is not None:
x = data_diff[..., k]
outliers = (x < stats[k]['whislo']) | (x > stats[k]['whishi'])
percent_outliers[k] = np.sum(outliers) / nb_voxels * 100
# What would be our definition of too many outliers?
# Maybe mean(all_means)+-3SD?
# Or we let people choose based on the figure.
# if percent_outliers[k] > ???? :
# logger.warning(' Careful! Diffusion-Weighted Image'
# ' i=%s has %s %% outlier voxels',
# k, percent_outliers[k])

if args.mask is None:
fig, axe = plt.subplots(nrows=1, ncols=1, squeeze=False)
else:
fig, axe = plt.subplots(nrows=1, ncols=2, squeeze=False,
figsize=[10, 4.8])
# Default is [6.4, 4.8]. Increasing width to see better.

medianprops = dict(linestyle='-', linewidth=2.5, color='firebrick')
meanprops = dict(linestyle='-', linewidth=2.5, color='green')
axe[0, 0].bxp(stats, showmeans=True, meanline=True, showfliers=False,
medianprops=medianprops, meanprops=meanprops)
axe[0, 0].set_xlabel('DW image')
axe[0, 0].set_ylabel('Residuals per DWI volume. Red is median,\n'
'green is mean. Whiskers are 1.5*interquartile')
axe[0, 0].set_title('Residuals')
axe[0, 0].set_xticks(range(0, q1.shape[0], 5))
axe[0, 0].set_xticklabels(range(0, q1.shape[0], 5))

if args.mask is not None:
axe[0, 1].plot(range(data_diff.shape[-1]), percent_outliers)
axe[0, 1].set_xticks(range(0, q1.shape[0], 5))
axe[0, 1].set_xticklabels(range(0, q1.shape[0], 5))
axe[0, 1].set_xlabel('DW image')
axe[0, 1].set_ylabel('Percentage of outlier voxels')
axe[0, 1].set_title('Outliers')
plt.savefig(residual_basename + '_residuals_stats.png')


def main():
parser = _build_arg_parser()
args = parser.parse_args()
Expand Down Expand Up @@ -404,6 +335,9 @@ def main():
add_filename_suffix(args.pulsation, '_std_b0'))

if args.residual:
if mask is None:
logging.info("Outlier detection will not be performed, since no "
"mask was provided.")
# Mean residual image
S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
tenfit2_predict = np.zeros(data.shape, dtype=np.float32)
Expand Down Expand Up @@ -435,7 +369,7 @@ def main():
np.save(add_filename_suffix(res_stats_basename, "_std_residuals"), std)

# Plotting and saving figure
_plot_residuals(args, data_diff, mask, R_k, q1, q3, iqr,
plot_residuals(data_diff, mask, R_k, q1, q3, iqr,
residual_basename)


Expand Down
2 changes: 1 addition & 1 deletion scripts/scil_plot_stats_per_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_output_dirs_exist_and_empty,
add_verbose_arg)
from scilpy.utils.metrics_tools import plot_metrics_stats
from scilpy.viz.plot import plot_metrics_stats


def _build_arg_parser():
Expand Down

0 comments on commit f81231d

Please sign in to comment.