From 4a206ad725d6a8e76e1f454feb94850b5c9b6279 Mon Sep 17 00:00:00 2001 From: wasserth Date: Fri, 9 Feb 2024 17:21:00 +0100 Subject: [PATCH] make nifti multilabel return value also contain postprocessing and correct header --- tests/test_end_to_end.py | 8 +++---- tests/tests_os.py | 10 +++++++- totalsegmentator/nnunet.py | 44 ++++++++++++++++++---------------- totalsegmentator/python_api.py | 11 +++++++-- 4 files changed, 46 insertions(+), 27 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index f56cb7647..a04a95f47 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -52,6 +52,10 @@ def test_prediction_liver_roi_subset(self): images_equal = dice > 0.99 self.assertTrue(images_equal, f"roi subset prediction not correct (dice: {dice:.6f})") + def test_preview(self): + preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png") + self.assertTrue(preview_exists, "Preview was not generated") + def test_prediction_fast(self): for roi in ["liver", "vertebrae_L1"]: img_ref = nib.load(f"tests/reference_files/example_seg_fast/{roi}.nii.gz").get_fdata() @@ -60,10 +64,6 @@ def test_prediction_fast(self): images_equal = dice > 0.99 self.assertTrue(images_equal, f"{roi} fast prediction not correct (dice: {dice:.6f})") - def test_preview(self): - preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png") - self.assertTrue(preview_exists, "Preview was not generated") - def test_prediction_multilabel_fast(self): img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata() diff --git a/tests/tests_os.py b/tests/tests_os.py index b691af01e..605759dd5 100755 --- a/tests/tests_os.py +++ b/tests/tests_os.py @@ -19,13 +19,21 @@ def run_tests_and_exit_on_failure(): shutil.rmtree("tests/unittest_prediction_fast") if r != 0: sys.exit("Test failed: test_prediction_fast") - # Test python api - nifti input + # Test python api 1 - nifti input, filepath output input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz") totalsegmentator(input_img, "tests/unittest_prediction_fast", fast=True, device="cpu") r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"]) shutil.rmtree("tests/unittest_prediction_fast") if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input") + # Test python api 2 - nifti input, nifti output + input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz") + output_img = totalsegmentator(input_img, None, fast=True, device="cpu") + nib.save(output_img, "tests/unittest_prediction_fast.nii.gz") + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"]) + os.remove("tests/unittest_prediction_fast.nii.gz") + if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output") + # Test terminal # Test organ predictions - fast - multilabel # makes correct path for windows and linux. Only required for terminal call. Within python diff --git a/totalsegmentator/nnunet.py b/totalsegmentator/nnunet.py index 8e138efe0..74683cbd1 100644 --- a/totalsegmentator/nnunet.py +++ b/totalsegmentator/nnunet.py @@ -515,6 +515,27 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ if save_binary: img_data = (img_data > 0).astype(np.uint8) + # Reorder labels if needed + if v1_order and task_name == "total": + img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"]) + label_map = class_map["total_v1"] + else: + label_map = class_map[task_name] + + # Keep only voxel values corresponding to the roi_subset + if roi_subset is not None: + label_map = {k: v for k, v in label_map.items() if v in roi_subset} + img_data *= np.isin(img_data, list(label_map.keys())) + + # Prepare output nifti + # Copy header to make output header exactly the same as input. But change dtype otherwise it will be + # float or int and therefore the masks will need a lot more space. + # (infos on header: https://nipy.org/nibabel/nifti_images.html) + new_header = img_in_orig.header.copy() + new_header.set_data_dtype(np.uint8) + img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header) + img_out = add_label_map_to_nifti(img_out, class_map[task_name]) + if file_out is not None and skip_saving is False: if not quiet: print("Saving segmentations...") @@ -527,29 +548,13 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ file_out.mkdir(exist_ok=True, parents=True) save_mask_as_rtstruct(img_data, selected_classes, file_in_dcm, file_out / "segmentations.dcm") else: - # Copy header to make output header exactly the same as input. But change dtype otherwise it will be - # float or int and therefore the masks will need a lot more space. - # (infos on header: https://nipy.org/nibabel/nifti_images.html) - new_header = img_in_orig.header.copy() - new_header.set_data_dtype(np.uint8) - st = time.time() if multilabel_image: file_out.parent.mkdir(exist_ok=True, parents=True) else: file_out.mkdir(exist_ok=True, parents=True) if multilabel_image: - if v1_order and task_name == "total": - img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"]) - label_map = class_map["total_v1"] - else: - label_map = class_map[task_name] - # Keep only voxel values corresponding to the roi_subset - if roi_subset is not None: - label_map = {k: v for k, v in label_map.items() if v in roi_subset} - img_data *= np.isin(img_data, list(label_map.keys())) - img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header) - save_multilabel_nifti(img_out, file_out, label_map) + nib.save(img_out, file_out) if nora_tag != "None": subprocess.call(f"/opt/nora/src/node/nora -p {nora_tag} --add {file_out} --addtag atlas", shell=True) else: # save each class as a separate binary image @@ -609,6 +614,5 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ skin = extract_skin(img_in_orig, nib.load(file_out / "body.nii.gz")) nib.save(skin, file_out / "skin.nii.gz") - seg_img = nib.Nifti1Image(img_data, img_pred.affine) - seg_img = add_label_map_to_nifti(seg_img, class_map[task_name]) - return seg_img, img_in_orig + + return img_out, img_in_orig diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index f2f6f0f19..2369b25d1 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -40,7 +40,7 @@ def show_license_info(): sys.exit(1) -def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path], ml=False, nr_thr_resamp=1, nr_thr_saving=6, +def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path, None], ml=False, nr_thr_resamp=1, nr_thr_saving=6, fast=False, nora_tag="None", preview=False, task="total", roi_subset=None, statistics=False, radiomics=False, crop_path=None, body_seg=False, force_split=False, output_type="nifti", quiet=False, verbose=False, test=0, @@ -52,10 +52,17 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa For explanation of the arguments see description of command line arguments in bin/TotalSegmentator. + + Return: multilabel Nifti1Image """ if not isinstance(input, Nifti1Image): input = Path(input) - output = Path(output) + + if output is not None: + output = Path(output) + else: + if statistics or radiomics: + raise ValueError("Output path is required for statistics and radiomics.") nora_tag = "None" if nora_tag is None else nora_tag