Skip to content

Commit

Permalink
Revert but add assert_headers_compatible everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Mar 12, 2024
1 parent 2635c98 commit 36b97ca
Show file tree
Hide file tree
Showing 44 changed files with 249 additions and 222 deletions.
5 changes: 2 additions & 3 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def compute_snr(dwi, bval, bvec, b0_thr, mask,
"""
data = dwi.get_fdata(dtype=np.float32)
affine = dwi.affine
mask = get_data_as_mask(mask, dtype=bool, ref_shape=data.shape)
mask = get_data_as_mask(mask, dtype=bool)

if split_shells:
centroids, shell_indices = identify_shells(bval, tol=40.0,
Expand Down Expand Up @@ -343,8 +343,7 @@ def compute_snr(dwi, bval, bvec, b0_thr, mask,
nib.save(nib.Nifti1Image(noise_mask, affine),
basename + '_noise_mask.nii.gz')
elif noise_mask:
noise_mask = get_data_as_mask(noise_mask, dtype=bool,
ref_shape=data.shape).squeeze()
noise_mask = get_data_as_mask(noise_mask, dtype=bool).squeeze()
elif noise_map:
data_noisemap = noise_map.get_fdata(dtype=np.float32)

Expand Down
17 changes: 1 addition & 16 deletions scilpy/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def assert_same_resolution(images):
raise Exception("Images are not of the same resolution/affine")


def get_data_as_mask(mask_img, dtype=np.uint8, ref_img=None, ref_shape=None):
def get_data_as_mask(mask_img, dtype=np.uint8):
"""
Get data as mask (force type np.uint8 or bool), check data type before
casting.
Expand All @@ -99,11 +99,6 @@ def get_data_as_mask(mask_img, dtype=np.uint8, ref_img=None, ref_shape=None):
Mask image.
dtype: type or str
Data type for the output data (default: uint8)
ref_img: nibabel.nitfi1.Nifti1Image
Reference image. If given, mask must be compatible.
ref_shape: shape
Alternative to ref_image. The shape of the associated data. If given,
verifies that the mask shape fits with the ref_shape.
Return
------
Expand All @@ -116,16 +111,6 @@ def get_data_as_mask(mask_img, dtype=np.uint8, ref_img=None, ref_shape=None):
raise IOError('Output data type must be uint8 or bool. '
'Current data type is {}.'.format(dtype))

# Verify that shape is ok
if ref_img is not None:
if not is_header_compatible(mask_img, ref_img):
raise IOError("Mask is not of the same resolution/affine as data.")
elif ref_shape is not None:
if not np.array_equal(mask_img.shape, ref_shape[0:3]):
raise IOError("Mask is not the same shape as data. Got {}, and "
"data is of shape {}"
.format(mask_img.shape, ref_shape))

# Verify that loaded datatype is ok
curr_type = mask_img.get_data_dtype().type
basename = os.path.basename(mask_img.get_filename())
Expand Down
26 changes: 20 additions & 6 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,8 @@ def assert_roi_radii_format(parser):
return roi_radii


def is_header_compatible_multiple_files(parser, list_files,
verbose_all_compatible=False,
reference=None):
def assert_headers_compatible(parser, required, optional=None,
verbose_all_compatible=False, reference=None):
"""
Verifies the compatibility between the first item in list_files
and the remaining files in list.
Expand All @@ -713,17 +712,32 @@ def is_header_compatible_multiple_files(parser, list_files,
---------
parser: argument parser
Will raise an error if a file is not compatible.
list_files: List[str]
required: List[str]
List of files to test
optional: List[str or None]
List of files. May contain None, they will be discarted.
verbose_all_compatible: bool
If true will print a message when everything is okay
reference: str
Reference for any .tck passed in `list_files`
"""
all_valid = True

# Gather "headers" for all files to compare against
# eachother later
# Format required and optional to lists if a single filename was sent.
if isinstance(required, str):
required = [required]
if optional is None:
optional = []
elif isinstance(optional, str):
optional = [optional]
else:
optional = [f for f in optional if f is not None]
list_files = required + optional

if len(list_files) <= 1:
return

# Gather "headers" for all files to compare against each other later
headers = []
for filepath in list_files:
_, in_extension = split_name_with_nii(filepath)
Expand Down
2 changes: 1 addition & 1 deletion scilpy/reconst/mti.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def threshold_map(computed_map, in_mask,
# Load and apply sum of T1 probability maps on myelin maps
if in_mask is not None:
mask_image = nib.load(in_mask)
mask_data = get_data_as_mask(mask_image, ref_shape=computed_map.shape)
mask_data = get_data_as_mask(mask_image)
computed_map[np.where(mask_data == 0)] = 0

# Apply threshold based on combination of specific contrast maps
Expand Down
4 changes: 2 additions & 2 deletions scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _extract_vb_one_bundle(
mask_1_img = nib.load(head_filename)
mask_2_img = nib.load(tail_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img, ref_img=mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)

if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
Expand Down Expand Up @@ -501,7 +501,7 @@ def _extract_ib_one_bundle(sft, mask_1_filename, mask_2_filename,
mask_1_img = nib.load(mask_1_filename)
mask_2_img = nib.load(mask_2_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img, ref_img=mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)

if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
Expand Down
6 changes: 4 additions & 2 deletions scripts/scil_aodf_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
add_sh_basis_args,
add_verbose_arg,
assert_inputs_exist,
assert_headers_compatible,
assert_outputs_exist,
add_overwrite_arg,
parse_sh_basis_arg)
Expand Down Expand Up @@ -142,6 +143,8 @@ def main():

assert_inputs_exist(parser, inputs)
assert_outputs_exist(parser, args, arglist)
if args.mask:
assert_headers_compatible(parser, inputs)

# Loading
sh_img = nib.load(args.in_sh)
Expand All @@ -155,8 +158,7 @@ def main():
parser.error('Invalid SH image. A full SH basis is expected.')

if args.mask:
mask = get_data_as_mask(nib.load(args.mask), dtype=bool,
ref_img=sh_img)
mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
else:
mask = np.sum(np.abs(sh), axis=-1) > 0

Expand Down
10 changes: 6 additions & 4 deletions scripts/scil_bingham_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import time
import argparse
import logging
from scilpy.io.image import get_data_as_mask

from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_overwrite_arg, add_processes_arg,
add_verbose_arg, assert_inputs_exist,
assert_outputs_exist, validate_nbr_processes)
assert_outputs_exist, validate_nbr_processes,
assert_headers_compatible)
from scilpy.reconst.bingham import (compute_fiber_density,
compute_fiber_spread,
compute_fiber_fraction)
Expand Down Expand Up @@ -94,11 +95,12 @@ def main():
outputs = [args.out_fd, args.out_fs, args.out_ff]
assert_inputs_exist(parser, args.in_bingham, args.mask)
assert_outputs_exist(parser, args, [], optional=outputs)
assert_headers_compatible(parser, args.in_bingham, args.mask)

bingham_im = nib.load(args.in_bingham)
bingham = bingham_im.get_fdata()
mask = get_data_as_mask(nib.load(args.mask), dtype=bool,
ref_img=bingham_im) if args.mask else None
mask = get_data_as_mask(nib.load(args.mask),
dtype=bool) if args.mask else None

nbr_processes = validate_nbr_processes(parser, args)

Expand Down
11 changes: 6 additions & 5 deletions scripts/scil_btensor_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist, add_processes_arg,
add_verbose_arg, add_skip_b0_check_arg,
add_tolerance_arg)
add_tolerance_arg,
assert_headers_compatible)
from scilpy.reconst.divide import fit_gamma, gamma_fit2metrics


Expand Down Expand Up @@ -158,9 +159,10 @@ def main():
parser.error('When using --not_all, you need to specify at least '
'one file to output.')

assert_inputs_exist(parser, [],
optional=args.in_dwis + args.in_bvals + args.in_bvecs)
assert_inputs_exist(parser, args.in_dwis + args.in_bvals + args.in_bvecs,
args.mask)
assert_outputs_exist(parser, args, arglist)
assert_headers_compatible(parser, args.in_dwis, args.mask)

if args.op and not args.fa:
parser.error('Computation of the OP requires a precomputed '
Expand Down Expand Up @@ -199,8 +201,7 @@ def main():
'No mask provided. The fit might not converge due to noise. '
'Please provide a mask if it is the case.')
else:
mask = get_data_as_mask(nib.load(args.mask), dtype=bool,
ref_shape=data.shape)
mask = get_data_as_mask(nib.load(args.mask), dtype=bool)

if args.fa is not None:
vol = nib.load(args.fa)
Expand Down
6 changes: 4 additions & 2 deletions scripts/scil_bundle_generate_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist,
parse_sh_basis_arg)
parse_sh_basis_arg,
assert_headers_compatible)
from scilpy.reconst.utils import find_order_from_nb_coeff
from scilpy.tractanalysis.todi import TrackOrientationDensityImaging

Expand Down Expand Up @@ -78,6 +79,7 @@ def main():

required = [args.in_bundle, args.in_fodf, args.in_mask]
assert_inputs_exist(parser, required)
assert_headers_compatible(parser, required)

out_efod = os.path.join(args.out_dir,
'{0}efod.nii.gz'.format(args.out_prefix))
Expand All @@ -100,7 +102,7 @@ def main():
sh_order = find_order_from_nb_coeff(sh_shape)
sh_basis, is_legacy = parse_sh_basis_arg(args)
img_mask = nib.load(args.in_mask)
mask_data = get_data_as_mask(img_mask, ref_img=img_sh)
mask_data = get_data_as_mask(img_mask)

sft = load_tractogram_with_reference(parser, args, args.in_bundle)
sft.to_vox()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def main():
gs_binary_3d = get_data_as_mask(nib.load(args.voxels_measures[0]))
gs_binary_3d[gs_binary_3d > 0] = 1
tracking_mask_data = get_data_as_mask(
nib.load(args.voxels_measures[1]), ref_shape=gs_binary_3d.shape)
nib.load(args.voxels_measures[1]))

if nbr_cpu == 1:
voxels_dict = []
Expand Down
8 changes: 5 additions & 3 deletions scripts/scil_denoising_nlmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
add_overwrite_arg,
add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist)
assert_outputs_exist,
assert_headers_compatible)


def _build_arg_parser():
Expand Down Expand Up @@ -77,8 +78,9 @@ def main():
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

assert_inputs_exist(parser, args.in_image)
assert_inputs_exist(parser, args.in_image, args.mask)
assert_outputs_exist(parser, args, args.out_image, args.logfile)
assert_headers_compatible(parser, args.in_image, args.mask)

if args.logfile is not None:
logging.getLogger().addHandler(logging.FileHandler(args.logfile,
Expand All @@ -93,7 +95,7 @@ def main():
else:
mask[data > 0] = 1
else:
mask = get_data_as_mask(nib.load(args.mask), dtype=bool, ref_img=vol)
mask = get_data_as_mask(nib.load(args.mask), dtype=bool)

sigma = args.sigma

Expand Down
8 changes: 5 additions & 3 deletions scripts/scil_dki_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_overwrite_arg, add_skip_b0_check_arg,
add_verbose_arg, assert_inputs_exist,
assert_outputs_exist, add_tolerance_arg, )
assert_outputs_exist, add_tolerance_arg,
assert_headers_compatible, )
from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
is_normalized_bvecs,
identify_shells,
Expand Down Expand Up @@ -174,13 +175,14 @@ def main():
assert_inputs_exist(
parser, [args.in_dwi, args.in_bval, args.in_bvec], args.mask)
assert_outputs_exist(parser, args, outputs)
assert_headers_compatible(parser, args.in_dwi, args.mask)

# Loading
img = nib.load(args.in_dwi)
data = img.get_fdata(dtype=np.float32)
affine = img.affine
mask = get_data_as_mask(nib.load(args.mask), dtype=bool,
ref_img=img) if args.mask else None
mask = get_data_as_mask(nib.load(args.mask),
dtype=bool) if args.mask else None

bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)
if not is_normalized_bvecs(bvecs):
Expand Down
9 changes: 6 additions & 3 deletions scripts/scil_dti_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from dipy.core.gradients import gradient_table
import dipy.denoise.noise_estimate as ne
from dipy.io.gradients import read_bvals_bvecs
from dipy.io.utils import is_header_compatible
from dipy.reconst.dti import (TensorModel, color_fa, fractional_anisotropy,
geodesic_anisotropy, mean_diffusivity,
axial_diffusivity, norm,
Expand All @@ -46,7 +47,8 @@
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg,
add_skip_b0_check_arg, add_verbose_arg,
assert_inputs_exist, assert_outputs_exist)
assert_inputs_exist, assert_outputs_exist,
assert_headers_compatible)
from scilpy.io.tensor import convert_tensor_from_dipy_format, \
supported_tensor_formats, tensor_format_description
from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
Expand Down Expand Up @@ -250,13 +252,14 @@ def main():
assert_inputs_exist(
parser, [args.in_dwi, args.in_bval, args.in_bvec], args.mask)
assert_outputs_exist(parser, args, outputs)
assert_headers_compatible(parser, args.in_dwi, args.mask)

# Loading
img = nib.load(args.in_dwi)
data = img.get_fdata(dtype=np.float32)
affine = img.affine
mask = get_data_as_mask(nib.load(args.mask), dtype=bool,
ref_img=img) if args.mask else None
mask = get_data_as_mask(nib.load(args.mask),
dtype=bool) if args.mask else None

logging.info('Tensor estimation with the {} method...'.format(args.method))
bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)
Expand Down
6 changes: 4 additions & 2 deletions scripts/scil_dwi_apply_bias_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from scilpy.io.utils import (add_overwrite_arg,
add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist)
assert_outputs_exist, assert_headers_compatible)


def _build_arg_parser():
Expand Down Expand Up @@ -53,6 +53,8 @@ def main():

assert_inputs_exist(parser, [args.in_dwi, args.in_bias_field], args.mask)
assert_outputs_exist(parser, args, args.out_name)
assert_headers_compatible(parser, [args.in_dwi, args.in_bias_field],
args.mask)

dwi_img = nib.load(args.in_dwi)
dwi_data = dwi_img.get_fdata(dtype=np.float32)
Expand All @@ -61,7 +63,7 @@ def main():
bias_field_data = bias_field_img.get_fdata(dtype=np.float32)

if args.mask:
mask_data = get_data_as_mask(nib.load(args.mask), ref_img=dwi_img)
mask_data = get_data_as_mask(nib.load(args.mask))
else:
mask_data = np.average(dwi_data, axis=-1) != 0

Expand Down
Loading

0 comments on commit 36b97ca

Please sign in to comment.