Skip to content

Commit

Permalink
Merge pull request #5 from VemundFredriksen/refactor-pipeline-general…
Browse files Browse the repository at this point in the history
…ization

Refactor pipeline generalization
  • Loading branch information
sosevle authored Oct 29, 2023
2 parents 429a73e + 0b0a9e7 commit a759560
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 20 deletions.
2 changes: 1 addition & 1 deletion assets/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
2. Download the MSD Lung Tumor dataset from [here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2).
3. Extract the zip file into `/assets/`.
4. Run `synthlung format --dataset msd` to adjust dataset format
5. Run `synthlung seed --dataset msd` to extract tumor seeds from the dataset
5. Run `synthlung seed --dataset` to extract tumor seeds from the dataset
6. Run `synthlung host --dataset msd` to extract lung masks from the images
9 changes: 4 additions & 5 deletions synthlung/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from lungmask import LMInferer
import json

def seed_msd():
json_file_path = "./assets/source/msd/dataset.json"
def seed():
json_file_path = "./assets/source/dataset.json"

with open(json_file_path, 'r') as json_file:
image_dict = json.load(json_file)
crop_pipeline = TumorCropPipeline()
crop_pipeline(image_dict)
formatter = MSDGenerateJSONFormatter("./assets/seeds/msd/")
formatter = MSDGenerateJSONFormatter("./assets/seeds/")
formatter.generate_json()

def format_msd():
Expand Down Expand Up @@ -55,8 +55,7 @@ def main():
if(args.dataset == "msd"):
format_msd()
elif args.action == "seed":
if(args.dataset == "msd"):
seed_msd()
seed()
elif args.action == "generate":
if(args.dataset == "msd"):
generate_randomized_tumors()
Expand Down
2 changes: 1 addition & 1 deletion synthlung/utils/dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
LABEL_NII_GZ = 'label.nii.gz'

class MSDImageSourceFormatter(ImageSourceFormatter, JSONGenerator):
def __init__(self, source_directory: str = "./assets/Task06_Lung/", target_directory: str = "./assets/source/msd/") -> None:
def __init__(self, source_directory: str = "./assets/Task06_Lung/", target_directory: str = "./assets/source/") -> None:
self.target_directory = target_directory
self.source_directory = source_directory

Expand Down
45 changes: 32 additions & 13 deletions synthlung/utils/tumor_isolation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
import numpy as np
import tqdm

class CutOutTumor(object):
class TumorSeedIsolationd(object):
def __init__(self, image_key='image', label_key='label', image_output_key='seed_image', label_output_key='seed_label') -> None:
self.image_key = image_key
self.label_key = label_key
self.image_output_key = image_output_key
self.label_output_key = label_output_key

def __call__(self, sample: dict) -> dict:
image, label = sample['image'], sample['label']
image, label = sample[self.image_key], sample[self.label_key]
bolean_mask = (label > 0).astype(np.uint8)

indeces = np.argwhere(bolean_mask)
Expand All @@ -17,15 +23,13 @@ def __call__(self, sample: dict) -> dict:
clipped_label = label[y_min:y_max+1, x_min:x_max+1, z_min:z_max+1]
clipped_image = np.where(clipped_label == 1, clipped_image, -1024)

sample['seed_image'] = clipped_image
sample['seed_image_meta_dict'] = sample['image_meta_dict']
self.__update_image_dims__(sample['seed_image_meta_dict'], clipped_image.shape)
self.__update_image_filename__(sample['seed_image_meta_dict'])
sample[self.image_output_key] = clipped_image
sample[f'{self.image_output_key}_meta_dict'] = sample[f'{self.image_key}_meta_dict']
self.__update_image_dims__(sample[f'{self.image_output_key}_meta_dict'], clipped_image.shape)

sample['seed_label'] = clipped_label
sample['seed_label_meta_dict'] = sample['label_meta_dict']
self.__update_image_dims__(sample['seed_label_meta_dict'], clipped_label.shape)
self.__update_label_filename__(sample['seed_label'])
sample[self.label_output_key] = clipped_label
sample[f'{self.label_output_key}_meta_dict'] = sample[f'{self.label_key}_meta_dict']
self.__update_image_dims__(sample[f'{self.label_output_key}_meta_dict'], clipped_label.shape)

return sample

Expand All @@ -44,15 +48,30 @@ def __update_image_filename__(self, image):
def __update_label_filename__(self, label):
label.meta['filename_or_obj'] = label.meta['filename_or_obj'].replace('source_', 'seed_')

class RenameSourceToSeed(object):
def __init__(self, meta_dict_keys=['seed_image_meta_dict'], image_object_keys=['seed_label']) -> Any:
self.meta_dict_keys= meta_dict_keys
self.image_object_keys = image_object_keys

def __call__(self, sample:dict) -> Any:
for meta_dict in self.meta_dict_keys:
sample[meta_dict]['filename_or_obj'] = sample[meta_dict]['filename_or_obj'].replace('source_', 'seed_')

for image_object in self.image_object_keys:
sample[image_object].meta['filename_or_obj'] = sample[image_object].meta['filename_or_obj'].replace('source_', 'seed_')

return sample


class TumorCropPipeline(object):
monai.config.BACKEND = "Nibabel"
def __init__(self) -> None:
self.compose = Compose([
LoadImaged(keys=['image', 'label'], image_only = False),
CutOutTumor(),
SaveImaged(keys=['seed_image'], output_dir='./assets/seeds/msd/', output_postfix='', separate_folder=False),
SaveImaged(keys=['seed_label'], output_dir='./assets/seeds/msd/', output_postfix='', separate_folder=False)
TumorSeedIsolationd(image_key='image', label_key='label', image_output_key='seed_image', label_output_key='seed_label'),
RenameSourceToSeed(meta_dict_keys=['seed_image_meta_dict', 'seed_label_meta_dict']),
SaveImaged(keys=['seed_image'], output_dir='./assets/seeds/', output_postfix='', separate_folder=False),
SaveImaged(keys=['seed_label'], output_dir='./assets/seeds/', output_postfix='', separate_folder=False)
])

def __call__(self, image_dict) -> None:
Expand Down

0 comments on commit a759560

Please sign in to comment.