Skip to content

Commit

Permalink
make nifti multilabel return value also contain postprocessing and co…
Browse files Browse the repository at this point in the history
…rrect header
  • Loading branch information
wasserth committed Feb 9, 2024
1 parent 3eec586 commit 4a206ad
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 27 deletions.
8 changes: 4 additions & 4 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def test_prediction_liver_roi_subset(self):
images_equal = dice > 0.99
self.assertTrue(images_equal, f"roi subset prediction not correct (dice: {dice:.6f})")

def test_preview(self):
preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png")
self.assertTrue(preview_exists, "Preview was not generated")

def test_prediction_fast(self):
for roi in ["liver", "vertebrae_L1"]:
img_ref = nib.load(f"tests/reference_files/example_seg_fast/{roi}.nii.gz").get_fdata()
Expand All @@ -60,10 +64,6 @@ def test_prediction_fast(self):
images_equal = dice > 0.99
self.assertTrue(images_equal, f"{roi} fast prediction not correct (dice: {dice:.6f})")

def test_preview(self):
preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png")
self.assertTrue(preview_exists, "Preview was not generated")

def test_prediction_multilabel_fast(self):
img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata()
Expand Down
10 changes: 9 additions & 1 deletion tests/tests_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@ def run_tests_and_exit_on_failure():
shutil.rmtree("tests/unittest_prediction_fast")
if r != 0: sys.exit("Test failed: test_prediction_fast")

# Test python api - nifti input
# Test python api 1 - nifti input, filepath output
input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz")
totalsegmentator(input_img, "tests/unittest_prediction_fast", fast=True, device="cpu")
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
shutil.rmtree("tests/unittest_prediction_fast")
if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input")

# Test python api 2 - nifti input, nifti output
input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz")
output_img = totalsegmentator(input_img, None, fast=True, device="cpu")
nib.save(output_img, "tests/unittest_prediction_fast.nii.gz")
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"])
os.remove("tests/unittest_prediction_fast.nii.gz")
if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output")

# Test terminal
# Test organ predictions - fast - multilabel
# makes correct path for windows and linux. Only required for terminal call. Within python
Expand Down
44 changes: 24 additions & 20 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,27 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
if save_binary:
img_data = (img_data > 0).astype(np.uint8)

# Reorder labels if needed
if v1_order and task_name == "total":
img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"])
label_map = class_map["total_v1"]
else:
label_map = class_map[task_name]

# Keep only voxel values corresponding to the roi_subset
if roi_subset is not None:
label_map = {k: v for k, v in label_map.items() if v in roi_subset}
img_data *= np.isin(img_data, list(label_map.keys()))

# Prepare output nifti
# Copy header to make output header exactly the same as input. But change dtype otherwise it will be
# float or int and therefore the masks will need a lot more space.
# (infos on header: https://nipy.org/nibabel/nifti_images.html)
new_header = img_in_orig.header.copy()
new_header.set_data_dtype(np.uint8)
img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header)
img_out = add_label_map_to_nifti(img_out, class_map[task_name])

if file_out is not None and skip_saving is False:
if not quiet: print("Saving segmentations...")

Expand All @@ -527,29 +548,13 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
file_out.mkdir(exist_ok=True, parents=True)
save_mask_as_rtstruct(img_data, selected_classes, file_in_dcm, file_out / "segmentations.dcm")
else:
# Copy header to make output header exactly the same as input. But change dtype otherwise it will be
# float or int and therefore the masks will need a lot more space.
# (infos on header: https://nipy.org/nibabel/nifti_images.html)
new_header = img_in_orig.header.copy()
new_header.set_data_dtype(np.uint8)

st = time.time()
if multilabel_image:
file_out.parent.mkdir(exist_ok=True, parents=True)
else:
file_out.mkdir(exist_ok=True, parents=True)
if multilabel_image:
if v1_order and task_name == "total":
img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"])
label_map = class_map["total_v1"]
else:
label_map = class_map[task_name]
# Keep only voxel values corresponding to the roi_subset
if roi_subset is not None:
label_map = {k: v for k, v in label_map.items() if v in roi_subset}
img_data *= np.isin(img_data, list(label_map.keys()))
img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header)
save_multilabel_nifti(img_out, file_out, label_map)
nib.save(img_out, file_out)
if nora_tag != "None":
subprocess.call(f"/opt/nora/src/node/nora -p {nora_tag} --add {file_out} --addtag atlas", shell=True)
else: # save each class as a separate binary image
Expand Down Expand Up @@ -609,6 +614,5 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
skin = extract_skin(img_in_orig, nib.load(file_out / "body.nii.gz"))
nib.save(skin, file_out / "skin.nii.gz")

seg_img = nib.Nifti1Image(img_data, img_pred.affine)
seg_img = add_label_map_to_nifti(seg_img, class_map[task_name])
return seg_img, img_in_orig

return img_out, img_in_orig
11 changes: 9 additions & 2 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def show_license_info():
sys.exit(1)


def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path], ml=False, nr_thr_resamp=1, nr_thr_saving=6,
def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path, None], ml=False, nr_thr_resamp=1, nr_thr_saving=6,
fast=False, nora_tag="None", preview=False, task="total", roi_subset=None,
statistics=False, radiomics=False, crop_path=None, body_seg=False,
force_split=False, output_type="nifti", quiet=False, verbose=False, test=0,
Expand All @@ -52,10 +52,17 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
For explanation of the arguments see description of command line
arguments in bin/TotalSegmentator.
Return: multilabel Nifti1Image
"""
if not isinstance(input, Nifti1Image):
input = Path(input)
output = Path(output)

if output is not None:
output = Path(output)
else:
if statistics or radiomics:
raise ValueError("Output path is required for statistics and radiomics.")

nora_tag = "None" if nora_tag is None else nora_tag

Expand Down

0 comments on commit 4a206ad

Please sign in to comment.