Skip to content

Commit

Permalink
working version, need refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
fullbat committed Jun 10, 2024
1 parent 675ba7d commit d4613ec
Show file tree
Hide file tree
Showing 6 changed files with 537 additions and 1,002 deletions.
152 changes: 77 additions & 75 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
cimport numpy as np
cimport cython

from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
from os import makedirs, remove, listdir
from os.path import exists, join as pjoin, isfile, isdir
Expand All @@ -19,13 +20,16 @@ import pickle

from pkg_resources import get_distribution


import amico.scheme
import amico.lut

from dicelib.ui import ProgressBar, setup_logger
from dicelib import ui
from dicelib.utils import format_time
from dicelib.tractogram import filter

from commit import trk2dictionary
import commit.models
import commit.solvers
from commit.operator import operator
Expand Down Expand Up @@ -82,9 +86,12 @@ cdef class Evaluation :
cdef public x
cdef public CONFIG
cdef public temp_data
cdef public confidence_array
cdef public confidence_map_img
cdef public contribution_mask
cdef public contribution_fibs
cdef public contribution_voxels
cdef public debias
cdef public verbose

def __init__( self, study_path='.', subject='.' ) :
Expand All @@ -106,9 +113,12 @@ cdef class Evaluation :
self.A = None # set by "build_operator" method
self.regularisation_params = None # set by "set_regularisation" method
self.x = None # set by "fit" method
self.confidence_array = None # set by "fit" method
self.confidence_map_img = None # set by "fit" method
self.contribution_mask = None # set by "fit" method
self.contribution_voxels = None # set by "fit" method
self.contribution_fibs = None
self.debias = False
self.verbose = 3

# store all the parameters of an evaluation with COMMIT
Expand Down Expand Up @@ -742,52 +752,19 @@ cdef class Evaluation :
logger.info( f'[ {format_time(time.time() - tic)} ]' )


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef get_y( self ):
def get_y( self ):
"""
Returns a numpy array that corresponds to the 'y' vector of the optimisation problem.
NB: this can be run only after having loaded the dictionary and the data.
"""

cdef int i = 0

if self.DICTIONARY is None :
logger.error( 'Dictionary not loaded; call "load_dictionary()" first' )
if self.niiDWI is None :
logger.error( 'Data not loaded; call "load_data()" first' )

y = self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float64)


if self.contribution_mask is not None :

# find the voxels traversed by the tracts with zero contribution
zero_fibs = np.where(self.contribution_mask == 0)[0]

# compute number of tract passing through each voxel
n_tracts = np.bincount(self.DICTIONARY['IC']['v'], minlength=self.DICTIONARY['nV'])

# extract the voxels traversed by at most one tract
vox_single = np.where(n_tracts < zero_fibs.size)[0]

contrib_voxels = y.copy()
contrib_voxels = contrib_voxels.astype(np.uint32)
contrib_voxels[contrib_voxels > 0] = 1
contrib_voxels

# iterate over vox_single and set the y values of the voxels traverse ONLY by the tracts with zero contribution to zero
with ProgressBar(total=vox_single.size, disable=self.verbose < 3, hide_on_exit=True, subinfo=True) as pbar:
for i in vox_single:
if n_tracts[i] == 1:
if self.contribution_mask[self.DICTIONARY['IC']['fiber'][self.DICTIONARY['IC']['v'] == i]] == 0:
y[i] = 0
else:
if np.all(self.contribution_mask[self.DICTIONARY['IC']['fiber'][self.DICTIONARY['IC']['v'] == i]] == 0):
y[i] = 0
pbar.update()

self.contribution_voxels = contrib_voxels[y != 0]
# y[y < 0] = 0
print(f"place of first non-zero voxel in input data: {np.where(y==1.07405806)}")
return y


Expand Down Expand Up @@ -906,7 +883,7 @@ cdef class Evaluation :
for g in range(w_group.size):
norm_group[g] = np.sqrt(np.sum(Aty[idx_group[g]]**2)) / w_group[g]
return np.max(norm_group)


regularisation = {}

Expand Down Expand Up @@ -1351,36 +1328,80 @@ cdef class Evaluation :
nV = self.DICTIONARY['nV']
self.confidence_map_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ] = np.reshape( confidence_array, (nV,-1) ).astype(np.float32)


if x0 is not None :
if x0.shape[0] != self.A.shape[1] :
logger.error( 'x0 dimension does not match the number of columns of the dictionary' )


self.CONFIG['optimization'] = {}
self.CONFIG['optimization']['tol_fun'] = tol_fun
self.CONFIG['optimization']['tol_x'] = tol_x
self.CONFIG['optimization']['max_iter'] = max_iter
self.CONFIG['optimization']['regularisation'] = self.regularisation_params

if debias:
self.debias = True
self.CONFIG['optimization']['x0'] = x0
self.confidence_array = confidence_array


# run solver
t = time.time()
with ProgressBar(disable=self.verbose!=3, hide_on_exit=True) as pb:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)
self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t

if debias:
from commit.operator import operator
mask = np.ones(self.x.size, dtype=np.uint32)
# mask[self.x[:self.DICTIONARY['IC']['nF']]<0.000000000000001] = 0
mask[:self.DICTIONARY['IC']['nF']][self.x[:self.DICTIONARY['IC']['nF']]<0.000000000000001] = 0
self.contribution_mask = mask
if self.debias:
logger.info( 'Recomputing coefficients' )
xic, _, _ = self.get_coeffs()
weights_in = pjoin( self.get_config('TRACKING_path'), 'streamline_weights.txt' )
np.savetxt(weights_in, xic)

self.DICTIONARY["IC"]["eval"] = mask[:self.DICTIONARY['IC']['nF']]
dictionary_info = load_dictionary_info( pjoin(self.get_config('TRACKING_path'), 'dictionary_info.pickle') )
tractogram = dictionary_info['filename_tractogram']
tractogram_filtered = tractogram.replace('.tck', '_filtered.tck')

filter(dictionary_info['filename_tractogram'], tractogram_filtered, minweight=0.000000000000001, weights_in=weights_in, force=True, verbose=0)

# # RE-RUN COMMIT WITH THE FILTERED TRACTOGRAM
# path_COMMIT = os.path.join(local_path, "COMMIT_master_debias")

trk2dictionary.run(
filename_tractogram = tractogram_filtered,
TCK_ref_image = dictionary_info['TCK_ref_image'],
path_out = dictionary_info['path_out'],
filename_peaks = dictionary_info['filename_peaks'],
filename_mask = dictionary_info['filename_mask'],
do_intersect = dictionary_info['do_intersect'],
fiber_shift = dictionary_info['fiber_shift'],
min_seg_len = dictionary_info['min_seg_len'],
min_fiber_len = dictionary_info['min_fiber_len'],
max_fiber_len = dictionary_info['max_fiber_len'],
vf_THR = dictionary_info['vf_THR'],
peaks_use_affine = dictionary_info['peaks_use_affine'],
flip_peaks = dictionary_info['flip_peaks'],
blur_core_extent = dictionary_info['blur_core_extent'],
blur_gauss_extent = dictionary_info['blur_gauss_extent'],
blur_gauss_min = dictionary_info['blur_gauss_min'],
blur_spacing = dictionary_info['blur_spacing'],
ndirs = dictionary_info['ndirs'],
n_threads = dictionary_info['n_threads'],
verbose = 0
)

self.load_dictionary(dictionary_info['path_out'])

self.set_threads()
self.build_operator()

self.A = operator.LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, nolut=True if hasattr(self.model, 'nolut') else False )
self.set_regularisation()
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

with ProgressBar(disable=self.verbose!=3, hide_on_exit=True) as pb:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=self.CONFIG['optimization']['tol_fun'], tol_x=self.CONFIG['optimization']['tol_x'], max_iter=self.CONFIG['optimization']['max_iter'], verbose=self.verbose, x0=self.CONFIG['optimization']['x0'], regularisation=self.regularisation_params, confidence_array=self.confidence_array)

self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t
logger.info( f'[ {format_time(self.CONFIG["optimization"]["fit_time"])} ]' )


Expand Down Expand Up @@ -1425,7 +1446,7 @@ cdef class Evaluation :
return xic, xec, xiso


def save_results( self, path_suffix=None, coeffs_format='%.5e', stat_coeffs='sum', save_est_dwi=False, do_reweighting=True ) :
def save_results( self, path_suffix=None, coeffs_format='%.5e', stat_coeffs='sum', save_est_dwi=False, do_reweighting=True, debias=False ) :
"""Save the output (coefficients, errors, maps etc).
Parameters
Expand Down Expand Up @@ -1487,19 +1508,8 @@ cdef class Evaluation :
niiMAP_hdr['descrip'] = 'Created with COMMIT %s'%self.get_config('version')
niiMAP_hdr['db_name'] = ''

if self.contribution_mask is None:
y_mea = np.reshape( self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
y_est = np.reshape( self.A.dot(self.x), (nV,-1) ).astype(np.float32)
else:
nV = self.contribution_voxels.shape[0]
y_mea = np.reshape( self.get_y()[self.contribution_voxels], (nV,-1) )
y_est = np.asarray(self.A.dot(self.x))
y_est = np.reshape( y_est[self.contribution_voxels], (nV,-1) ).astype(np.float32)
self.DICTIONARY['MASK_ix'] = self.DICTIONARY['MASK_ix'][self.contribution_voxels]
self.DICTIONARY['MASK_iy'] = self.DICTIONARY['MASK_iy'][self.contribution_voxels]
self.DICTIONARY['MASK_iz'] = self.DICTIONARY['MASK_iz'][self.contribution_voxels]


y_mea = np.reshape( self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
y_est = np.reshape( self.A.dot(self.x), (nV,-1) ).astype(np.float32)

tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )
logger.subinfo(f'RMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
Expand Down Expand Up @@ -1557,22 +1567,16 @@ cdef class Evaluation :
xv = np.bincount( self.DICTIONARY['IC']['v'], minlength=nV,
weights=tmp[ self.DICTIONARY['IC']['fiber'] ] * self.DICTIONARY['IC']['len']
).astype(np.float32)
if self.contribution_mask is not None:
niiIC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv[self.contribution_voxels]
else:
niiIC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv

niiIC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv

logger.subinfo('Extra-axonal', indent_lvl=2, indent_char='-', with_progress=True)
with ProgressBar(disable=self.verbose < 3, hide_on_exit=True, subinfo=True) as pbar:
niiEC_img = np.zeros( self.get_config('dim'), dtype=np.float32 )
if len(self.KERNELS['wmh']) > 0 :
offset = nF * self.KERNELS['wmr'].shape[0]
tmp = x[offset:offset+nE*len(self.KERNELS['wmh'])].reshape( (-1,nE) ).sum( axis=0 )
xv = np.bincount( self.DICTIONARY['EC']['v'], weights=tmp, minlength=nV ).astype(np.float32)
if self.contribution_mask is not None:
niiEC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv[self.contribution_voxels]
else:
niiEC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv
niiEC_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv

logger.subinfo('Isotropic ', indent_lvl=2, indent_char='-', with_progress=True)
with ProgressBar(disable=self.verbose < 3, hide_on_exit=True, subinfo=True) as pbar:
Expand All @@ -1582,10 +1586,7 @@ cdef class Evaluation :
offset_iso = offset + self.DICTIONARY['ISO']['nV']
tmp = x[offset:offset_iso].reshape( (-1,self.DICTIONARY['ISO']['nV']) ).sum( axis=0 )
xv = np.bincount( self.DICTIONARY['ISO']['v'], weights=tmp, minlength=nV ).astype(np.float32)
if self.contribution_mask is not None:
niiISO_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv[self.contribution_voxels]
else:
niiISO_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv
niiISO_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = xv

if self.get_config('doNormalizeMaps') :
niiIC = nibabel.Nifti1Image( niiIC_img / ( niiIC_img + niiEC_img + niiISO_img + 1e-16), affine, header=niiMAP_hdr )
Expand Down Expand Up @@ -1681,3 +1682,4 @@ cdef class Evaluation :
self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ] = y_mea

logger.info( f'[ {format_time(time.time() - tic)} ]' )

Loading

0 comments on commit d4613ec

Please sign in to comment.