Skip to content

Commit

Permalink
enforce merge
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitas-k committed Nov 7, 2024
1 parent 15b3807 commit 5b11ea7
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 126 deletions.
117 changes: 76 additions & 41 deletions eigenstrapping/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from .geometry import (calc_surface_eigenmodes,
get_tkrvox2ras,
make_tetra)
make_tetra,
spin_single,
gen_eigensamples,
truncate_emodes)

import copy

Expand Down Expand Up @@ -311,7 +314,10 @@ def __init__(self, data, surface=None, evals=None, emodes=None, num_modes=200,
save_surface=False, seed=None, decomp_method='matrix',
medial=None, randomize=False, resample=False, n_jobs=1,
use_cholmod=False, permute=False, add_res=False,
save_rotations=False, parcellation=None, normalize=True):
truncate_modes=False, ret_fwhm=False, truncate_args=None,
save_rotations=False, parcellation=None,
normalize=True, use_nn=False, nn_method='original',
eigen_perms=None):

# initialization of variables
if surface is None:
Expand Down Expand Up @@ -347,6 +353,11 @@ def __init__(self, data, surface=None, evals=None, emodes=None, num_modes=200,
self.save_rotations = save_rotations
self.parcellation = parcellation
self.normalize = normalize
self.truncate_modes = truncate_modes
self.use_nn = use_nn
self.nn_method = nn_method
self.eigen_perms = eigen_perms
self.truncate_args = truncate_args

self._lm = LinearRegression(fit_intercept=True)
self.add_res = add_res
Expand Down Expand Up @@ -417,13 +428,23 @@ def __init__(self, data, surface=None, evals=None, emodes=None, num_modes=200,
if self.emodes.shape[1] != self.evals.shape[0]:
raise ValueError("There must be as many eigenmodes as there are eigenvalues")

# actually the indices that AREN'T medial wall
self.medial_wall = self.medial_wall.astype(np.bool_)

self.n_vertices = self.emodes.shape[0]

# actually ARE the indices of medial wall
self.medial_mask = np.logical_not(self.medial_wall)

self.n_vertices = self.emodes.shape[0]

self.data_no_mwall = self.data # deepcopy original data so it's not modified
self.data_no_mwall = self.data[self.medial_wall]

# truncate eigenmodes at nearest eigengroup
if self.truncate_modes:
if self.surface is None:
raise ValueError('in order to compute kernel density estimate for data re:\n'
'number of modes (groups) to use, surface geometry must be given.')
self.fwhm, self.emodes, self.evals = truncate_emodes(self.data_no_mwall, self.surface, self.emodes, self.evals, mask=self.medial_wall, seed=self._rs, ret_fwhm=self.ret_fwhm, **self.truncate_args)

self._emodes = copy.deepcopy(self.emodes)
self._emodes = self._emodes[self.medial_wall]

Expand Down Expand Up @@ -459,9 +480,11 @@ def __call__(self, n=1):
``self.medial_wall``.
"""
if self.save_rotations:
self.rotations = np.zeros((n, *self.emodes.shape))
rs = self._rs.randint(np.iinfo(np.int32).max, size=n)
if self.use_nn is True or self.eigen_perms is not None:
if self.eigen_perms is None:
self.eigen_perms = gen_eigensamples(self.emodes, self.evals, mask=self.medial_wall, num_modes=self.num_modes,
n_rotate=n, method=self.nn_method, n_jobs=self.n_jobs)
surrs = np.row_stack(
Parallel(self.n_jobs)(
delayed(self._call_method)(rs=i) for i in rs
Expand Down Expand Up @@ -497,33 +520,42 @@ def generate(self, output_modes=False):
mask = self.medial_wall
coeffs = copy.deepcopy(self.coeffs)
reconstructed_data = copy.deepcopy(self.reconstructed_data)
residuals = copy.deepcopy(self.residuals)

# initialize the new modes
new_modes = np.zeros_like(emodes)

# resample the hypersphere (except for groups 1 and 2)
for idx in range(len(groups)):
if idx > 100:
gen = True
group_modes = emodes[:, groups[idx]]
group_evals = evals[groups[idx]]
p = group_modes
# else, transform to spheroid and index the angles properly
if self.normalize:
p = transform_to_spheroid(group_evals, group_modes)

p_rot = self.rotate_modes(p, gen=True)

# transform back to ellipsoid
if self.normalize:
p_rot = transform_to_ellipsoid(group_evals, p_rot)
#residuals = copy.deepcopy(self.residuals)

if self.eigen_perms is None:
# initialize the new modes
new_modes = np.zeros_like(emodes)
# resample the hypersphere (except for groups 1 and 2)
for idx in range(len(groups)):
if idx > 100:
gen = True
group_modes = emodes[:, groups[idx]]
group_evals = evals[groups[idx]]
p = group_modes
# else, transform to spheroid and index the angles properly
if self.normalize:
p = transform_to_spheroid(group_evals, group_modes)

p_rot = self.rotate_modes(p, gen=True)

# transform back to ellipsoid
if self.normalize:
p_rot = transform_to_ellipsoid(group_evals, p_rot)

new_modes[:, groups[idx]] = p_rot

new_modes[:, groups[idx]] = p_rot

if output_modes:
return new_modes
if output_modes:
return new_modes

else:
new_modes = spin_single(
self.emodes,
evals,
mask=mask,
eigen_perms=self.eigen_perms[self._rs.randint(0, len(self.eigen_perms))],
num_modes=self.num_modes,
return_masked=True,
)
# matrix multiply the estimated coefficients by the new modes
surrogate = np.zeros_like(self.data)*np.nan # original data

Expand All @@ -538,13 +570,6 @@ def generate(self, output_modes=False):
# Mask the data and surrogate_data excluding the medial wall
surr_no_mwall = copy.deepcopy(surrogate)
surr_no_mwall = surr_no_mwall[mask]

# now add the residuals of the original data
if self.permute:
surr_no_mwall += self._rs.permutation(residuals)

if self.add_res:
surr_no_mwall += residuals

# if self.resample:
# # Get the rank ordered indices
Expand All @@ -558,9 +583,19 @@ def generate(self, output_modes=False):
sorted_map = np.sort(data)
ii = np.argsort(surr_no_mwall)
np.put(surr_no_mwall, ii, sorted_map)
else: # demean
surr_no_mwall = surr_no_mwall - np.nanmean(surr_no_mwall)
else: # force match the minimum and maximum
sf = (np.nanmax(data) - np.nanmin(data)) / (np.nanmax(surr_no_mwall) - np.nanmin(surr_no_mwall))
shift = np.nanmin(data) - sf * np.nanmin(surr_no_mwall)
surr_no_mwall = surr_no_mwall * sf + shift

# now add the residuals of the original data
residuals = data - surr_no_mwall
if self.permute:
surr_no_mwall += self._rs.permutation(residuals)

if self.add_res:
surr_no_mwall += residuals

# else: # force match the minima
# indices = np.nonzero(surr_no_mwall)[0] # Indices where s is non-zero

Expand Down
Loading

0 comments on commit 5b11ea7

Please sign in to comment.