Skip to content

Commit

Permalink
Improved example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DelinteNicolas committed Apr 5, 2023
1 parent ce74a30 commit 8ebebd2
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 46 deletions.
Binary file added dist/unravel-python-1.1.10.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion unravel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
microstructural properties of neural fibers in a specified tract.
"""

__version__ = "1.1.9"
__version__ = "1.1.10"
__author__ = 'Nicolas Delinte'
3 changes: 3 additions & 0 deletions unravel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,9 @@ def plot_streamline_metrics(trk, tList: list, metric_maps: list,
if groundTruth_map is not None:
axs.plot(vList, mgtList, label='Ground truth')
axs.legend()
axs.set_ylabel('Metric')
axs.set_xlabel('Streamline segment position')
axs.set_title('Microstructure along streamline')


def plot_streamline_metrics_old(streamList: list, metric_maps: list,
Expand Down
107 changes: 75 additions & 32 deletions unravel/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,100 @@
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from unravel.core import *
from dipy.io.streamline import load_tractogram
from unravel.core import (get_fixel_weight, get_microstructure_map,
weighted_mean_dev, main_fixel_map,
plot_streamline_metrics, total_segment_length)
from unravel.utils import (peaks_to_RGB, tract_to_ROI, peaks_to_peak,
plot_streamline_trajectory)


if __name__ == '__main__':

os.chdir('..')

trk_file = 'data/sampleSubject_cc_bundle_mid_ant.trk'
MF_dir = 'data/'
Patient = 'sampleSubject'
data_dir = 'data/'
patient = 'sampleSubject'
trk_file = data_dir+patient+'_cc_bundle_mid_ant.trk'
trk = load_tractogram(trk_file, 'same')
trk.to_vox()
trk.to_corner()

# Maps and means ----------------------------------------------------------

tList = [nib.load(data_dir+patient+'_mf_peak_f0.nii.gz').get_fdata(),
nib.load(data_dir+patient+'_mf_peak_f1.nii.gz').get_fdata()]

fixelWeights, _, _, voxelStreams, _ = get_fixel_weight_MF(
trk_file, MF_dir, Patient, streamList=[0])
fixel_weights, _, _ = get_fixel_weight(trk, tList)

metricMapList = [nib.load('data/sampleSubject_mf_fvf_f0.nii.gz').get_fdata(),
nib.load('data/sampleSubject_mf_fvf_f1.nii.gz').get_fdata()]
metric_maps = [nib.load(data_dir+patient+'_mf_fvf_f0.nii.gz').get_fdata(),
nib.load(data_dir+patient+'_mf_fvf_f1.nii.gz').get_fdata()]

microMap = get_microstructure_map(fixelWeights, metricMapList)
microMap = get_microstructure_map(fixel_weights, metric_maps)

weightedMean, weightedDev, _, [Min, Max] = weighted_mean_dev(
metricMapList, [fixelWeights[:, :, :, 0], fixelWeights[:, :, :, 1]])
metric_maps, [fixel_weights[:, :, :, 0], fixel_weights[:, :, :, 1]])

# Colors ------------------------------------------------------------------

fList = [nib.load(data_dir+patient+'_mf_fvf_f0.nii.gz').get_fdata(),
nib.load(data_dir+patient+'_mf_fvf_f1.nii.gz').get_fdata()]

mask = tract_to_ROI(trk_file)
mask = np.repeat(mask[:, :, :, np.newaxis], 3, axis=3)

p = peaks_to_peak(tList, fixel_weights)
rgb = peaks_to_RGB(peaksList=[p])*mask

# Total segment length ----------------------------------------------------

tsl = total_segment_length(fixel_weights)

# Printing means ----------------------------------------------------------

print('The fiber volume fraction estimation of '+Patient+' in the middle '
print('The fiber volume fraction estimation of '+patient+' in the middle '
+ 'anterior bundle of the corpus callosum are \n'
+ 'Weighted mean : '+str(weightedMean)+'\n'
+ 'Weighted standard deviation : '+str(weightedDev)+'\n'
+ 'Min/Max : '+str(Min), str(Max)+'\n')

# Plotting results --------------------------------------------------------

slice_num = 71

background = nib.load(
'data/sampleSubject_T1_diffusionSpace.nii.gz').get_fdata()
totalSegmentLength = np.sum(fixelWeights, axis=3)
totalSegmentLengthTransparency = totalSegmentLength / \
np.max(totalSegmentLength)
tSL = totalSegmentLength.copy()
tSL[totalSegmentLength > 0] = 1

fig, axs = plt.subplots(1, 4)
axs[0].imshow(np.rot90(main_fixel_map(
fixelWeights)[:, 71, :]), cmap='gray')
axs[0].set_title('Most aligned fixel')
axs[1].imshow(np.rot90(totalSegmentLength[:, 71, :]),
cmap='inferno', clim=[0, 80])
axs[1].set_title('Total segment length')
axs[2].imshow(np.rot90(background[:, 71, :]), cmap='gray')
axs[2].imshow(np.rot90(tSL[:, 71, :]), alpha=np.rot90(
totalSegmentLengthTransparency[:, 71, :]), cmap='Wistia')
axs[2].set_title('Total segment length')
axs[3].imshow(np.rot90(microMap[:, 71, :]), cmap='gray')
axs[3].set_title('Fiber volume fraction \n (axonal density) map')

plot_streamline_metrics(voxelStreams, metricMapList)
roi = np.where(tsl > 0, .99, 0)
non_roi = np.where(tsl == 0, .99, 0)
alpha_tsl = tsl[:, slice_num, :]/np.max(tsl)*2
alpha_tsl[alpha_tsl > 1] = 1

fig, axs = plt.subplots(2, 2)
axs[0, 0].imshow(np.rot90(background[:, slice_num, :]), cmap='gray')
axs[0, 0].imshow(np.rot90(main_fixel_map(fixel_weights)[:, slice_num, :]),
cmap='Wistia', alpha=np.rot90(roi[:, slice_num, :]))
axs[0, 0].set_title('Most aligned fixel')
axs[0, 1].imshow(np.rot90(rgb[:, slice_num, :]))
axs[0, 1].imshow(np.rot90(background[:, slice_num, :]), cmap='gray',
alpha=np.rot90(non_roi[:, slice_num, :]))
axs[0, 1].set_title('Angular weighted \n direction')
axs[1, 0].imshow(np.rot90(background[:, slice_num, :]), cmap='gray')
axs[1, 0].imshow(np.rot90(roi[:, slice_num, :]), cmap='Wistia',
alpha=np.rot90(alpha_tsl))
axs[1, 0].set_title('Total segment length')
axs[1, 1].imshow(np.rot90(background[:, slice_num, :]), cmap='gray')
fvf = axs[1, 1].imshow(np.rot90(microMap[:, slice_num, :]), cmap='autumn',
alpha=np.rot90(roi[:, slice_num, :]))
fig.colorbar(fvf, ax=axs[1, 1])
axs[1, 1].set_title('Fiber volume fraction \n (axonal density) map')

# Along streamline metric --------------------------------------------------

stream_num = 500

plot_streamline_trajectory(trk, resolution_increase=3,
streamline_number=stream_num, axis=1)

plot_streamline_metrics(trk, tList, metric_maps,
method_list=['vol', 'cfo', 'ang'],
streamline_number=stream_num, fList=fList)
35 changes: 22 additions & 13 deletions unravel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def peaks_to_RGB(peaksList: list, fracList: list = None, fvfList: list = None):
-------
rgb : 4-D array of shape (x,y,z,3)
RGB map of shape (x,y,z,3) representing the main direction of
of the peaks.
of the peaks. With type float64 [0,1].
'''

Expand All @@ -76,34 +76,43 @@ def peaks_to_RGB(peaksList: list, fracList: list = None, fvfList: list = None):
peaksList = [peaksList]

K = len(peaksList)
dim = len(peaksList[0].shape[:-1])

len_ratio = np.ones(peaksList[0].shape[:-1])

for k in range(K):
peaksList[k] = np.nan_to_num(peaksList[k])
len_ratio += np.where(np.sum(peaksList[k], axis=dim) == 0, 1, 0)

if fracList is None:
fracList = []
for k in range(K):
fracList.append(np.ones(peaksList[0].shape[:3]))
fracList.append(np.ones(peaksList[0].shape[:-1]))

if fvfList is None:
fvfList = []
for k in range(K):
fvfList.append(np.ones(peaksList[0].shape[:3]))
fvfList.append(np.ones(peaksList[0].shape[:-1]))

rgb = np.zeros(peaksList[0].shape)

for xyz in np.ndindex(peaksList[0].shape[:3]):
for xyz in np.ndindex(peaksList[0].shape[:-1]):
for k in range(K):
rgb[xyz] += abs(peaksList[k][xyz])*fracList[k][xyz]*fvfList[k][xyz]

# Normalize between [0,1] and by number of peaks per voxel
rgb *= np.repeat(1+len_ratio[(slice(None),) *
dim + (np.newaxis,)]/K, 3, axis=dim)
rgb /= np.max(rgb)

return rgb


def peaks_to_peak(peaksList: list, fixel_weights, fracList: list = None,
fvfList: list = None):
'''
Fuse peaks into a single peak based on fixel weight and fvf, intensity
is then weighted with frac Mostly used for visualization purposes.
is then weighted with frac. Mostly used for visualization purposes.
Parameters
----------
Expand All @@ -114,7 +123,7 @@ def peaks_to_peak(peaksList: list, fixel_weights, fracList: list = None,
Returns
-------
None.
peak : 3-D array of shape (x,y,z,3)
'''

Expand All @@ -128,20 +137,19 @@ def peaks_to_peak(peaksList: list, fixel_weights, fracList: list = None,
if fracList is None:
fracList = []
for k in range(K):
fracList.append(np.ones(peaksList[0].shape[:3]))/k
fracList.append(np.ones(peaksList[0].shape[:-1])/(k+1))

if fvfList is None:
fvfList = []
for k in range(K):
fvfList.append(np.ones(peaksList[0].shape[:3]))
fvfList.append(np.ones(peaksList[0].shape[:-1]))

fracTot = np.zeros(peaksList[0].shape[:3])
fracTot = np.zeros(peaksList[0].shape[:-1])

for xyz in np.ndindex(peaksList[0].shape[:3]):
for xyz in np.ndindex(peaksList[0].shape[:-1]):
for k in range(K):
peak[xyz] += abs(peaksList[k][xyz]) * \
fixel_weights[xyz+(k,)]/np.sum(fixel_weights[xyz]) * \
fvfList[k][xyz]
peak[xyz] += (abs(peaksList[k][xyz])*fixel_weights[xyz+(k,)]
/ np.sum(fixel_weights[xyz])*fvfList[k][xyz])

for k in range(K):
fracTot += fracList[k]
Expand Down Expand Up @@ -331,3 +339,4 @@ def plot_streamline_trajectory(trk, resolution_increase: int = 1,
plt.imshow(density[:, :, int(sum(z)/len(z))].T,
origin='lower', cmap='gray')
plt.plot(x, y, '.-', c='#e69402ff')
plt.title('Streamline trajectory')

0 comments on commit 8ebebd2

Please sign in to comment.