Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pipeline generalization #5

Merged
merged 3 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assets/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
2. Download the MSD Lung Tumor dataset from [here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2).
3. Extract the zip file into `/assets/`.
4. Run `synthlung format --dataset msd` to adjust dataset format
5. Run `synthlung seed --dataset msd` to extract tumor seeds from the dataset
5. Run `synthlung seed --dataset` to extract tumor seeds from the dataset
6. Run `synthlung host --dataset msd` to extract lung masks from the images
9 changes: 4 additions & 5 deletions synthlung/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from lungmask import LMInferer
import json

def seed_msd():
json_file_path = "./assets/source/msd/dataset.json"
def seed():
json_file_path = "./assets/source/dataset.json"

with open(json_file_path, 'r') as json_file:
image_dict = json.load(json_file)
crop_pipeline = TumorCropPipeline()
crop_pipeline(image_dict)
formatter = MSDGenerateJSONFormatter("./assets/seeds/msd/")
formatter = MSDGenerateJSONFormatter("./assets/seeds/")
formatter.generate_json()

def format_msd():
Expand Down Expand Up @@ -55,8 +55,7 @@ def main():
if(args.dataset == "msd"):
format_msd()
elif args.action == "seed":
if(args.dataset == "msd"):
seed_msd()
seed()
elif args.action == "generate":
if(args.dataset == "msd"):
generate_randomized_tumors()
Expand Down
2 changes: 1 addition & 1 deletion synthlung/utils/dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
LABEL_NII_GZ = 'label.nii.gz'

class MSDImageSourceFormatter(ImageSourceFormatter, JSONGenerator):
def __init__(self, source_directory: str = "./assets/Task06_Lung/", target_directory: str = "./assets/source/msd/") -> None:
def __init__(self, source_directory: str = "./assets/Task06_Lung/", target_directory: str = "./assets/source/") -> None:
self.target_directory = target_directory
self.source_directory = source_directory

Expand Down
45 changes: 32 additions & 13 deletions synthlung/utils/tumor_isolation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
import numpy as np
import tqdm

class CutOutTumor(object):
class TumorSeedIsolationd(object):
def __init__(self, image_key='image', label_key='label', image_output_key='seed_image', label_output_key='seed_label') -> None:
self.image_key = image_key
self.label_key = label_key
self.image_output_key = image_output_key
self.label_output_key = label_output_key

def __call__(self, sample: dict) -> dict:
image, label = sample['image'], sample['label']
image, label = sample[self.image_key], sample[self.label_key]
bolean_mask = (label > 0).astype(np.uint8)

indeces = np.argwhere(bolean_mask)
Expand All @@ -17,15 +23,13 @@ def __call__(self, sample: dict) -> dict:
clipped_label = label[y_min:y_max+1, x_min:x_max+1, z_min:z_max+1]
clipped_image = np.where(clipped_label == 1, clipped_image, -1024)

sample['seed_image'] = clipped_image
sample['seed_image_meta_dict'] = sample['image_meta_dict']
self.__update_image_dims__(sample['seed_image_meta_dict'], clipped_image.shape)
self.__update_image_filename__(sample['seed_image_meta_dict'])
sample[self.image_output_key] = clipped_image
sample[f'{self.image_output_key}_meta_dict'] = sample[f'{self.image_key}_meta_dict']
self.__update_image_dims__(sample[f'{self.image_output_key}_meta_dict'], clipped_image.shape)

sample['seed_label'] = clipped_label
sample['seed_label_meta_dict'] = sample['label_meta_dict']
self.__update_image_dims__(sample['seed_label_meta_dict'], clipped_label.shape)
self.__update_label_filename__(sample['seed_label'])
sample[self.label_output_key] = clipped_label
sample[f'{self.label_output_key}_meta_dict'] = sample[f'{self.label_key}_meta_dict']
self.__update_image_dims__(sample[f'{self.label_output_key}_meta_dict'], clipped_label.shape)

return sample

Expand All @@ -44,15 +48,30 @@ def __update_image_filename__(self, image):
def __update_label_filename__(self, label):
label.meta['filename_or_obj'] = label.meta['filename_or_obj'].replace('source_', 'seed_')

class RenameSourceToSeed(object):
def __init__(self, meta_dict_keys=['seed_image_meta_dict'], image_object_keys=['seed_label']) -> Any:
self.meta_dict_keys= meta_dict_keys
self.image_object_keys = image_object_keys

def __call__(self, sample:dict) -> Any:
for meta_dict in self.meta_dict_keys:
sample[meta_dict]['filename_or_obj'] = sample[meta_dict]['filename_or_obj'].replace('source_', 'seed_')

for image_object in self.image_object_keys:
sample[image_object].meta['filename_or_obj'] = sample[image_object].meta['filename_or_obj'].replace('source_', 'seed_')

return sample


class TumorCropPipeline(object):
monai.config.BACKEND = "Nibabel"
def __init__(self) -> None:
self.compose = Compose([
LoadImaged(keys=['image', 'label'], image_only = False),
CutOutTumor(),
SaveImaged(keys=['seed_image'], output_dir='./assets/seeds/msd/', output_postfix='', separate_folder=False),
SaveImaged(keys=['seed_label'], output_dir='./assets/seeds/msd/', output_postfix='', separate_folder=False)
TumorSeedIsolationd(image_key='image', label_key='label', image_output_key='seed_image', label_output_key='seed_label'),
RenameSourceToSeed(meta_dict_keys=['seed_image_meta_dict', 'seed_label_meta_dict']),
SaveImaged(keys=['seed_image'], output_dir='./assets/seeds/', output_postfix='', separate_folder=False),
SaveImaged(keys=['seed_label'], output_dir='./assets/seeds/', output_postfix='', separate_folder=False)
])

def __call__(self, image_dict) -> None:
Expand Down
Loading