Skip to content

Commit

Permalink
Adding msmt support and test
Browse files Browse the repository at this point in the history
  • Loading branch information
karanphil committed Mar 1, 2024
1 parent 4c2848c commit 98f70c3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
18 changes: 14 additions & 4 deletions scripts/scil_frf_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
Compute the mean Fiber Response Function from a set of individually
computed Response Functions.
The FRF files are obtained from scil_frf_ssst.py, scil_frf_msmt.py in the
case of multi-shell data or scil_frf_memsmt.py in the case of multi-encoding
multi-shell data.
Formerly: scil_compute_mean_frf.py
"""

Expand Down Expand Up @@ -43,14 +47,20 @@ def main():
assert_inputs_exist(parser, args.frf_files)
assert_outputs_exist(parser, args, args.mean_frf)

all_frfs = np.zeros((len(args.frf_files), 4))
frf_shape = np.loadtxt(args.frf_files[0]).shape
all_frfs = np.zeros((len(args.frf_files),) + frf_shape)

for idx, frf_file in enumerate(args.frf_files):
frf = np.loadtxt(frf_file)

if not frf.shape[0] == 4:
raise ValueError('FRF file {} did not contain 4 elements. Invalid '
'or deprecated FRF format'.format(frf_file))
if not frf.shape[-1] == 4:
raise ValueError('FRF file {} did not contain 4 elements per '
'line. Invalid or deprecated FRF format.'
.format(frf_file))

if not frf.shape == frf_shape:
raise ValueError('FRF file {} did not match the format of '
'previous files.'.format(frf_file))

all_frfs[idx] = frf

Expand Down
25 changes: 20 additions & 5 deletions scripts/tests/test_frf_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['processing.zip'])
fetch_data(get_testing_files_dict(), keys=['processing.zip',
'commit_amico.zip'])
tmp_dir = tempfile.TemporaryDirectory()


Expand All @@ -16,9 +17,23 @@ def test_help_option(script_runner):
assert ret.success


def test_execution_processing(script_runner):
def test_execution_processing_ssst(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(get_home(), 'processing',
'frf.txt')
ret = script_runner.run('scil_frf_mean.py', in_frf, 'mfrf.txt')
in_frf = os.path.join(get_home(), 'processing', 'frf.txt')
ret = script_runner.run('scil_frf_mean.py', in_frf, in_frf, 'mfrf1.txt')
assert ret.success


def test_execution_processing_msmt(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(get_home(), 'commit_amico', 'wm_frf.txt')
ret = script_runner.run('scil_frf_mean.py', in_frf, in_frf, 'mfrf2.txt')
assert ret.success


def test_execution_processing_bad_input(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_wm_frf = os.path.join(get_home(), 'commit_amico', 'wm_frf.txt')
in_frf = os.path.join(get_home(), 'processing', 'frf.txt')
ret = script_runner.run('scil_frf_mean.py', in_wm_frf, in_frf, 'mfrf3.txt')
assert not ret.success

0 comments on commit 98f70c3

Please sign in to comment.