From 8ae0e37b690cb5327d6098bcd10f52760618a731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:00:51 +0100 Subject: [PATCH] update extraction to properly filter cell id list also accounting for dict lists --- src/sparcscore/pipeline/extraction.py | 42 +++++++++++++++++++-------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/sparcscore/pipeline/extraction.py b/src/sparcscore/pipeline/extraction.py index 1784ddb..6d98e9f 100644 --- a/src/sparcscore/pipeline/extraction.py +++ b/src/sparcscore/pipeline/extraction.py @@ -39,7 +39,7 @@ class HDF5CellExtraction(ProcessingStep): DEFAULT_SEGMENTATION_DIR = "segmentation" DEFAULT_SEGMENTATION_FILE = "segmentation.h5" DEFAULT_CLASSES_FILE = "classes.csv" - DEFAULT_FILTERED_CLASSES_FILE = "filtered/filtered_classes.csv" + DEFAULT_FILTERED_CLASSES_FILE = "filtering/filtered_classes.csv" DEFAULT_DATA_DIR = "data" CLEAN_LOG = False @@ -61,7 +61,7 @@ def __init__(self, self.input_segmentation_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_SEGMENTATION_FILE) #get path to filtered classes - if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_filtering.txt")): + if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_additional_filtering.txt")): try: self.classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_FILTERED_CLASSES_FILE) self.log(f"Loading classes from filtered classes path: {self.classes_path}") @@ -179,18 +179,19 @@ def get_classes(self, filtered_classes_path = None): if "filtered" in path: filtered_classes = [el[0] for el in list(cr)] #do not do int transform here as we expect a str of format "nucleus_id:cytosol_id" + filtered_classes = np.unique(filtered_classes) else: filtered_classes = [int(float(el[0])) for el in list(cr)] + filtered_classes = np.unique(filtered_classes) #make sure they are all unique + filtered_classes.astype(np.uint64) self.log("Loaded {} classes".format(len(filtered_classes))) - filtered_classes = np.unique(filtered_classes) #make sure they are all unique - filtered_classes.astype(np.uint64) + self.log("After removing duplicates {} classes remain.".format(len(filtered_classes))) class_list = list(filtered_classes) if 0 in class_list: class_list.remove(0) #remove background if still listed self.num_classes = len(class_list) - return(class_list) def generate_save_index_lookup(self, class_list): @@ -584,21 +585,38 @@ def process(self, input_segmentation_path, filtered_classes_path = None): px_centers, _cell_ids = self._calculate_centers(hdf_labels) #get classes to extract - class_list = self.get_classes(filtered_classes_path) - class_list = set(class_list) + class_list = self.get_classes(filtered_classes_path) + + if type(class_list[0]) == str: + print(class_list) + lookup_dict = {int(x.split(":")[0]):int(x.split(":")[1]) for x in class_list} + nuclei_ids = list(lookup_dict.keys()) + nuclei_ids = set(nuclei_ids) + print(nuclei_ids) + else: + nuclei_ids = set(class_list) #filter cell ids found using center into those that we actually want to extract _cell_ids = list(_cell_ids) - filter = [x in class_list for x in _cell_ids] + print("_cell_ids") + print(_cell_ids) + + filter = [x in nuclei_ids for x in _cell_ids] px_centers = np.array(list(compress(px_centers, filter))) _cell_ids = list(compress(_cell_ids, filter)) - #update number of classes - self.log(f"Number of classes found in filtered classes list {len(class_list)} vs number of classes for which centers were calculated {len(_cell_ids)}") - class_list = _cell_ids - del _cell_ids, filter + #generate new class list + if type(class_list[0]) == str: + class_list = [f"{x}:{lookup_dict[x]}" for x in _cell_ids] + del lookup_dict + else: + class_list = _cell_ids + + self.log(f"Number of classes found in filtered classes list {len(nuclei_ids)} vs number of classes for which centers were calculated {len(class_list)}") + del _cell_ids, filter, nuclei_ids + #update number of classes self.num_classes = len(class_list) # setup cache