update extraction to properly filter cell id list also accounting for dict lists
sophiamaedler committed Jan 19, 2024
1 parent 8052a8d commit 8ae0e37
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions src/sparcscore/pipeline/extraction.py
@@ -39,7 +39,7 @@ class HDF5CellExtraction(ProcessingStep):
     DEFAULT_SEGMENTATION_DIR = "segmentation"
     DEFAULT_SEGMENTATION_FILE = "segmentation.h5"
     DEFAULT_CLASSES_FILE = "classes.csv"
-    DEFAULT_FILTERED_CLASSES_FILE = "filtered/filtered_classes.csv"
+    DEFAULT_FILTERED_CLASSES_FILE = "filtering/filtered_classes.csv"
     DEFAULT_DATA_DIR = "data"
     CLEAN_LOG = False

@@ -61,7 +61,7 @@ def __init__(self,
         self.input_segmentation_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_SEGMENTATION_FILE)
 
         #get path to filtered classes
-        if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_filtering.txt")):
+        if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_additional_filtering.txt")):
             try:
                 self.classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_FILTERED_CLASSES_FILE)
                 self.log(f"Loading classes from filtered classes path: {self.classes_path}")
@@ -179,18 +179,19 @@ def get_classes(self, filtered_classes_path = None):

if "filtered" in path:
filtered_classes = [el[0] for el in list(cr)] #do not do int transform here as we expect a str of format "nucleus_id:cytosol_id"
filtered_classes = np.unique(filtered_classes)
else:
filtered_classes = [int(float(el[0])) for el in list(cr)]
filtered_classes = np.unique(filtered_classes) #make sure they are all unique
filtered_classes.astype(np.uint64)

self.log("Loaded {} classes".format(len(filtered_classes)))
filtered_classes = np.unique(filtered_classes) #make sure they are all unique
filtered_classes.astype(np.uint64)

self.log("After removing duplicates {} classes remain.".format(len(filtered_classes)))

class_list = list(filtered_classes)
if 0 in class_list: class_list.remove(0) #remove background if still listed
self.num_classes = len(class_list)

return(class_list)

     def generate_save_index_lookup(self, class_list):
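
The net effect of the get_classes change: deduplication now happens inside each branch, so string classes of the form "nucleus_id:cytosol_id" are never coerced to int. A self-contained sketch of the resulting behavior (file handling simplified; note that numpy's astype returns a new array rather than converting in place, so the sketch reassigns the result, which the diff above does not):

import csv
import numpy as np

def load_classes(path):
    with open(path) as f:
        rows = [el[0] for el in csv.reader(f)]

    if "filtered" in path:
        # keep as str, e.g. "17:112" (nucleus_id:cytosol_id)
        filtered_classes = np.unique(rows)
    else:
        filtered_classes = np.unique([int(float(el)) for el in rows])
        # astype returns a copy; reassign to actually get uint64 values
        filtered_classes = filtered_classes.astype(np.uint64)

    class_list = list(filtered_classes)
    if 0 in class_list:
        class_list.remove(0)  # remove background if still listed
    return class_list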
@@ -584,21 +585,38 @@ def process(self, input_segmentation_path, filtered_classes_path = None):
         px_centers, _cell_ids = self._calculate_centers(hdf_labels)
 
         #get classes to extract
-        class_list = self.get_classes(filtered_classes_path)
-        class_list = set(class_list)
+        class_list = self.get_classes(filtered_classes_path)
+
+        if type(class_list[0]) == str:
+            print(class_list)
+            lookup_dict = {int(x.split(":")[0]):int(x.split(":")[1]) for x in class_list}
+            nuclei_ids = list(lookup_dict.keys())
+            nuclei_ids = set(nuclei_ids)
+            print(nuclei_ids)
+        else:
+            nuclei_ids = set(class_list)
 
         #filter cell ids found using center into those that we actually want to extract
         _cell_ids = list(_cell_ids)
-        filter = [x in class_list for x in _cell_ids]
+        print("_cell_ids")
+        print(_cell_ids)
+
+        filter = [x in nuclei_ids for x in _cell_ids]
+
         px_centers = np.array(list(compress(px_centers, filter)))
         _cell_ids = list(compress(_cell_ids, filter))
 
-        #update number of classes
-        self.log(f"Number of classes found in filtered classes list {len(class_list)} vs number of classes for which centers were calculated {len(_cell_ids)}")
-        class_list = _cell_ids
-        del _cell_ids, filter
+        #generate new class list
+        if type(class_list[0]) == str:
+            class_list = [f"{x}:{lookup_dict[x]}" for x in _cell_ids]
+            del lookup_dict
+        else:
+            class_list = _cell_ids
+
+        self.log(f"Number of classes found in filtered classes list {len(nuclei_ids)} vs number of classes for which centers were calculated {len(class_list)}")
+        del _cell_ids, filter, nuclei_ids
 
         #update number of classes
         self.num_classes = len(class_list)
 
         # setup cache
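
The reworked process() section above is the core of the fix: when the class list arrives as "nucleus_id:cytosol_id" strings (the dict-like list the commit message refers to), filtering against the computed centers has to happen on the nucleus ids, and the string list must be rebuilt afterwards. A toy, self-contained sketch of that flow (values are illustrative; isinstance stands in for the diff's type(...) == str check, and mask for the builtin-shadowing name filter):

from itertools import compress
import numpy as np

# stand-ins for the centers and matching cell ids computed in process()
px_centers = np.array([[10, 12], [40, 8], [77, 90]])
_cell_ids = [17, 23, 42]

# class list as loaded from a filtered classes file
class_list = ["17:112", "42:256"]

if isinstance(class_list[0], str):
    # nucleus id -> cytosol id lookup parsed from the "n:c" strings
    lookup_dict = {int(x.split(":")[0]): int(x.split(":")[1]) for x in class_list}
    nuclei_ids = set(lookup_dict)
else:
    nuclei_ids = set(class_list)

# keep only centers whose nucleus id was requested; compress() selects
# elements of the first iterable where the mask element is truthy
mask = [x in nuclei_ids for x in _cell_ids]
px_centers = np.array(list(compress(px_centers, mask)))
_cell_ids = list(compress(_cell_ids, mask))

# rebuild the class list in its original format, now center-ordered
if isinstance(class_list[0], str):
    class_list = [f"{x}:{lookup_dict[x]}" for x in _cell_ids]
else:
    class_list = _cell_ids

print(class_list)  # -> ['17:112', '42:256']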
