From 98f70c36f17fb495e98cc2bbaf3bcd850fe3e6cb Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 1 Mar 2024 08:02:21 -0500 Subject: [PATCH] Adding msmt support and test --- scripts/scil_frf_mean.py | 18 ++++++++++++++---- scripts/tests/test_frf_mean.py | 25 ++++++++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/scripts/scil_frf_mean.py b/scripts/scil_frf_mean.py index 2b62f62d7..86ced4d2d 100755 --- a/scripts/scil_frf_mean.py +++ b/scripts/scil_frf_mean.py @@ -5,6 +5,10 @@ Compute the mean Fiber Response Function from a set of individually computed Response Functions. +The FRF files are obtained from scil_frf_ssst.py, scil_frf_msmt.py in the +case of multi-shell data or scil_frf_memsmt.py in the case of multi-encoding +multi-shell data. + Formerly: scil_compute_mean_frf.py """ @@ -43,14 +47,20 @@ def main(): assert_inputs_exist(parser, args.frf_files) assert_outputs_exist(parser, args, args.mean_frf) - all_frfs = np.zeros((len(args.frf_files), 4)) + frf_shape = np.loadtxt(args.frf_files[0]).shape + all_frfs = np.zeros((len(args.frf_files),) + frf_shape) for idx, frf_file in enumerate(args.frf_files): frf = np.loadtxt(frf_file) - if not frf.shape[0] == 4: - raise ValueError('FRF file {} did not contain 4 elements. Invalid ' - 'or deprecated FRF format'.format(frf_file)) + if not frf.shape[-1] == 4: + raise ValueError('FRF file {} did not contain 4 elements per ' + 'line. Invalid or deprecated FRF format.' + .format(frf_file)) + + if not frf.shape == frf_shape: + raise ValueError('FRF file {} did not match the format of ' + 'previous files.'.format(frf_file)) all_frfs[idx] = frf diff --git a/scripts/tests/test_frf_mean.py b/scripts/tests/test_frf_mean.py index 25eaaf566..38fb77923 100644 --- a/scripts/tests/test_frf_mean.py +++ b/scripts/tests/test_frf_mean.py @@ -7,7 +7,8 @@ from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict # If they already exist, this only takes 5 seconds (check md5sum) -fetch_data(get_testing_files_dict(), keys=['processing.zip']) +fetch_data(get_testing_files_dict(), keys=['processing.zip', + 'commit_amico.zip']) tmp_dir = tempfile.TemporaryDirectory() @@ -16,9 +17,23 @@ def test_help_option(script_runner): assert ret.success -def test_execution_processing(script_runner): +def test_execution_processing_ssst(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) - in_frf = os.path.join(get_home(), 'processing', - 'frf.txt') - ret = script_runner.run('scil_frf_mean.py', in_frf, 'mfrf.txt') + in_frf = os.path.join(get_home(), 'processing', 'frf.txt') + ret = script_runner.run('scil_frf_mean.py', in_frf, in_frf, 'mfrf1.txt') assert ret.success + + +def test_execution_processing_msmt(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_frf = os.path.join(get_home(), 'commit_amico', 'wm_frf.txt') + ret = script_runner.run('scil_frf_mean.py', in_frf, in_frf, 'mfrf2.txt') + assert ret.success + + +def test_execution_processing_bad_input(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_wm_frf = os.path.join(get_home(), 'commit_amico', 'wm_frf.txt') + in_frf = os.path.join(get_home(), 'processing', 'frf.txt') + ret = script_runner.run('scil_frf_mean.py', in_wm_frf, in_frf, 'mfrf3.txt') + assert not ret.success