Skip to content

Commit

Permalink
Add pipeline that masks the lungs and saves lung labels (#4)
Browse files Browse the repository at this point in the history
* Add pipeline that masks the lungs and saves lung labels

* Use collections.Counter() to compare lists in tests

- Solves problem with different file ordering on the host machine and the GitHub runner
  • Loading branch information
VemundFredriksen authored Oct 27, 2023
1 parent a26c1d4 commit 429a73e
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 18 deletions.
1 change: 1 addition & 0 deletions assets/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
3. Extract the zip file into `/assets/`.
4. Run `synthlung format --dataset msd` to adjust dataset format
5. Run `synthlung seed --dataset msd` to extract tumor seeds from the dataset
6. Run `synthlung host --dataset msd` to extract lung masks from the images
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ numpy
torch
monai
tqdm
lungmask
pytest
pytest-cov
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="synthlung",
version="0.0.1",
version="0.1.0",
author="Vemund Fredriksen and Svein Ole Matheson Sevle",
author_email="vemund.fredriksen@hotmailcom",
description="Package for generating synthetic lung tumors",
Expand All @@ -13,7 +13,8 @@
"numpy",
"torch",
"tqdm",
"monai"
"monai",
"lungmask"
],
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
18 changes: 17 additions & 1 deletion synthlung/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from synthlung.utils.tumor_isolation_pipeline import TumorCropPipeline
from synthlung.utils.dataset_formatter import MSDImageSourceFormatter, MSDGenerateJSONFormatter
from synthlung.utils.tumor_insertion_pipeline import InsertTumorPipeline
from synthlung.utils.lung_segmentation_pipeline import LungMaskPipeline, HostJsonGenerator
from lungmask import LMInferer
import json

def seed_msd():
Expand Down Expand Up @@ -31,10 +33,21 @@ def generate_randomized_tumors():

tumor_inserter(image_dict, seeds_dict)

def mask_hosts():
    """Create lung masks for the MSD source images and index them.

    Reads the sample list from ``./assets/source/msd/dataset.json``, runs the
    lung-masking pipeline on it, and then writes a ``dataset.json`` describing
    the resulting host labels into ``./assets/hosts/msd/``.
    """
    json_file_path = "./assets/source/msd/dataset.json"
    # Read the dataset first so a missing/invalid file fails before the
    # segmentation model is constructed.
    with open(json_file_path, 'r') as json_file:
        image_dict = json.load(json_file)

    lung_masker = LMInferer()
    host_masker = LungMaskPipeline(lung_masker)
    # The masking step must actually run: it is what produces the label files
    # that the JSON generator below enumerates.
    host_masker(image_dict)

    json_generator = HostJsonGenerator('./assets/hosts/msd/')
    json_generator.generate_json()

def main():
parser = argparse.ArgumentParser(description="Create your synthetic lung tumors!")

parser.add_argument("action", choices=["format", "seed", "generate"], help="Action to perform")
parser.add_argument("action", choices=["format", "seed", "host", "generate"], help="Action to perform")
parser.add_argument("--dataset", help="Dataset to format", choices=["msd"])
args = parser.parse_args()

Expand All @@ -47,5 +60,8 @@ def main():
elif args.action == "generate":
if(args.dataset == "msd"):
generate_randomized_tumors()
elif args.action == "host":
if(args.dataset == "msd"):
mask_hosts()
else:
print("Action not recognized")
15 changes: 2 additions & 13 deletions synthlung/utils/dataset_formatter.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
import os
import shutil
import json
from abc import ABC, abstractmethod
from synthlung.utils.json_generator import JSONGenerator
from synthlung.utils.image_source_formatter import ImageSourceFormatter

# File-name endings shared by the formatter code in this module.
NII_GZ_EXTENSION = '.nii.gz'  # extension of gzip-compressed NIfTI files
IMAGE_NII_GZ = 'image.nii.gz'  # ending of image volume file names
LABEL_NII_GZ = 'label.nii.gz'  # ending of label volume file names

class ImageSourceFormatter(ABC):
    """Abstract base class for image-source formatters.

    Concrete subclasses must provide :meth:`format`.
    """

    @abstractmethod
    def format(self) -> None:
        """Run the formatting step; the base implementation does nothing."""

class JSONGenerator(ABC):
    """Abstract base class for objects that generate a JSON description.

    Concrete subclasses must provide :meth:`generate_json`.
    """

    @abstractmethod
    def generate_json(self) -> None:
        """Produce the JSON output; the base implementation does nothing."""

class MSDImageSourceFormatter(ImageSourceFormatter, JSONGenerator):
def __init__(self, source_directory: str = "./assets/Task06_Lung/", target_directory: str = "./assets/source/msd/") -> None:
self.target_directory = target_directory
Expand Down
7 changes: 7 additions & 0 deletions synthlung/utils/image_source_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import abstractmethod, ABC

class ImageSourceFormatter(ABC):
    """Interface implemented by every image-source formatter.

    The class cannot be instantiated directly; subclasses override
    :meth:`format`.
    """

    @abstractmethod
    def format(self) -> None:
        """Carry out the formatting. No-op in the abstract base."""
7 changes: 7 additions & 0 deletions synthlung/utils/json_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import abstractmethod, ABC

class JSONGenerator(ABC):
    """Interface implemented by every JSON generator.

    The class cannot be instantiated directly; subclasses override
    :meth:`generate_json`.
    """

    @abstractmethod
    def generate_json(self) -> None:
        """Write/produce the JSON output. No-op in the abstract base."""
72 changes: 72 additions & 0 deletions synthlung/utils/lung_segmentation_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any
from monai.transforms import (Compose, LoadImaged, SaveImaged, ToMetaTensord)
from lungmask import LMInferer
from synthlung.utils.json_generator import JSONGenerator
import tqdm
import os
import json

# File-name endings used when working with the host label / source image files.
NII_GZ_EXTENSION = '.nii.gz'  # extension of gzip-compressed NIfTI files
IMAGE_NII_GZ = 'image.nii.gz'  # ending of image volume file names
LABEL_NII_GZ = 'label.nii.gz'  # ending of label volume file names

class MaskLungs(object):
    """Dictionary transform that adds a lung mask for ``sample['image']``.

    The mask is produced by the injected inferer's ``apply`` method and stored
    under ``sample['mask']``. A separate ``sample['mask_meta_dict']`` is
    derived from ``sample['image_meta_dict']`` with ``filename_or_obj``
    rewritten from the source-image name to the host-label name.
    """

    def __init__(self, lungmask_inferer: "LMInferer") -> None:
        """Store the inferer used to compute masks.

        Args:
            lungmask_inferer: object exposing ``apply(array)`` that returns
                the mask array for the given image array.
        """
        self.lungmask_inferer = lungmask_inferer

    def __call__(self, sample) -> Any:
        """Compute the mask for one sample and return the updated sample.

        Args:
            sample: mapping with an ``'image'`` tensor (must support
                ``.numpy()``) and an ``'image_meta_dict'`` mapping that holds
                a ``'filename_or_obj'`` string.

        Returns:
            The same ``sample`` object with ``'mask'`` and
            ``'mask_meta_dict'`` added.
        """
        image = sample['image'].numpy()

        # The axes are reversed before inference and reversed again on the
        # result, so the mask has the same axis order as the input image.
        # NOTE(review): presumably lungmask expects the reversed axis order
        # compared to the loaded volume - confirm against the lungmask docs.
        mask = self.lungmask_inferer.apply(self.__transpose_for_lungmask__(image))
        sample['mask'] = self.__transpose_for_lungmask__(mask)

        # Copy the metadata: renaming the file on a shared dict would also
        # change the file name recorded for the image itself.
        sample['mask_meta_dict'] = dict(sample['image_meta_dict'])
        self.__adjust_filename__(sample['mask_meta_dict'])

        return sample

    def __transpose_for_lungmask__(self, image):
        """Return ``image`` with its three axes reversed (assumes a 3-D array)."""
        transpose_indeces = (2, 1, 0)
        return image.transpose(transpose_indeces)

    def __adjust_filename__(self, mask_metadict):
        """Rewrite ``filename_or_obj`` in place: ``source_`` -> ``host_``, ``_image`` -> ``_label``."""
        mask_metadict['filename_or_obj'] = mask_metadict['filename_or_obj'].replace('source_', 'host_').replace('_image', '_label')

class LungMaskPipeline(object):
    """Load an image, compute its lung mask and save the mask to disk.

    The three steps are chained in a MONAI ``Compose``; the save step writes
    the ``'mask'`` entry into ``./assets/hosts/msd/``.
    """

    def __init__(self, lungmask_inferer: LMInferer) -> None:
        """Build the transform chain around the given lungmask inferer."""
        self.inferer = lungmask_inferer

        load_step = LoadImaged(keys=['image'], image_only = False)
        mask_step = MaskLungs(lungmask_inferer=self.inferer)
        save_step = SaveImaged(keys=['mask'], output_dir='./assets/hosts/msd/', output_postfix='', separate_folder=False)
        self.compose = Compose([load_step, mask_step, save_step])

    def __call__(self, image_dict) -> Any:
        """Run the pipeline on one sample, or on each sample of a list.

        Args:
            image_dict: a single sample, or a list of samples; a list is
                processed item by item with a progress bar.
        """
        # Single sample: run once and finish.
        if not isinstance(image_dict, list):
            self.compose(image_dict)
            return None

        print(f"Lung masking for {len(image_dict)} images starting...")
        for entry in tqdm.tqdm(image_dict):
            self.compose(entry)
        return None

class HostJsonGenerator(JSONGenerator):
    """Write a ``dataset.json`` pairing each host label with its source image.

    Every file in ``path`` whose name ends with ``label.nii.gz`` yields one
    entry with two keys: ``host_label`` (the label file inside ``path``) and
    ``host_image`` (the corresponding image under ``./assets/source/msd/``,
    obtained by swapping the label ending for the image ending and
    ``host_`` for ``source_``).
    """

    def __init__(self, path) -> None:
        """Remember the directory that holds the host label files.

        Args:
            path: directory to scan; a trailing slash is optional.
        """
        self.path = path

    def generate_json(self) -> None:
        """Scan ``self.path`` and write ``dataset.json`` into it.

        Raises:
            OSError: if ``self.path`` cannot be listed or the JSON file
                cannot be written.
        """
        dataset_json = []
        # os.listdir order is platform dependent; sort for reproducible output.
        for filename in sorted(os.listdir(self.path)):
            # Only label volumes describe a host; skipping anything else also
            # avoids failing on unrelated .nii.gz files in the directory.
            if not filename.endswith(LABEL_NII_GZ):
                continue

            stem = filename[:-len(LABEL_NII_GZ)]
            source_name = (stem + IMAGE_NII_GZ).replace('host_', 'source_')
            dataset_json.append({
                "host_image": "./assets/source/msd/" + source_name,
                # os.path.join gives a valid path with or without a trailing
                # slash on self.path.
                "host_label": os.path.join(self.path, filename),
            })

        with open(os.path.join(self.path, "dataset.json"), 'w') as json_file:
            json.dump(dataset_json, json_file, indent=4)
7 changes: 5 additions & 2 deletions tests/integration_tests/test_dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
from pathlib import Path
import json
import collections

def test_msd_formatter_format():
# Arrange
Expand All @@ -28,7 +29,7 @@ def test_msd_formatter_format():
expected_files = ["source_msd_lung_001_image.nii.gz", "source_msd_lung_001_label.nii.gz", "source_msd_lung_002_image.nii.gz", "source_msd_lung_002_label.nii.gz", "source_msd_lung_003_image.nii.gz", "source_msd_lung_003_label.nii.gz"]
actual_files = sorted(shutil.os.listdir(target_dir))

assert expected_files == actual_files
assert collections.Counter(expected_files) == collections.Counter(actual_files)

# Cleanup
shutil.rmtree(source_dir)
Expand Down Expand Up @@ -66,7 +67,9 @@ def test_msd_formatter_generate_json():
with open(f"{target_dir}/dataset.json", 'r') as json_file:
actual_json = json.load(json_file)

assert expected_json == actual_json
expected_json = [json.dumps(d, sort_keys=True) for d in expected_json]
actual_json = [json.dumps(d, sort_keys=True) for d in actual_json]
assert collections.Counter(actual_json) == collections.Counter(expected_json)

# Cleanup
shutil.rmtree(target_dir)

0 comments on commit 429a73e

Please sign in to comment.