diff --git a/coralnet_toolbox/AutoDistill/QtDeployModel.py b/coralnet_toolbox/AutoDistill/QtDeployModel.py
index 555a89cc..defab920 100644
--- a/coralnet_toolbox/AutoDistill/QtDeployModel.py
+++ b/coralnet_toolbox/AutoDistill/QtDeployModel.py
@@ -382,18 +382,6 @@ def get_ontology_mapping(self):
                 ontology_mapping[text_input.text()] = label_dropdown.currentText()
         return ontology_mapping
 
-    def get_uncertainty_threshold(self):
-        """
-        Get the uncertainty threshold, limiting it to a maximum of 0.10.
-
-        Returns:
-            Adjusted uncertainty threshold value.
-        """
-        if self.main_window.get_uncertainty_thresh() < 0.10:
-            return self.main_window.get_uncertainty_thresh()
-        else:
-            return 0.10  # Arbitrary value to prevent too many detections
-
     def load_new_model(self, model_name, uncertainty_thresh):
         """
         Load a new model based on the selected model name.
@@ -406,8 +394,8 @@ def load_new_model(self, model_name, uncertainty_thresh):
             from autodistill_grounding_dino import GroundingDINO
             self.model_name = model_name
             self.loaded_model = GroundingDINO(ontology=self.ontology,
-                                              box_threshold=uncertainty_thresh,
-                                              text_threshold=uncertainty_thresh)
+                                              box_threshold=0.025,
+                                              text_threshold=0.025)
 
     def predict(self, image_paths=None):
         """
@@ -417,7 +405,6 @@ def predict(self, image_paths=None):
             image_paths: List of image paths to process. If None, uses the current image.
         """
         if not self.loaded_model:
-            QMessageBox.critical(self, "Error", "No model loaded")
             return
 
         if not image_paths:
@@ -440,7 +427,7 @@ def predict(self, image_paths=None):
         # Create a results processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.get_uncertainty_threshold(),
+                                             uncertainty_thresh=self.uncertainty_thresh,
                                              iou_thresh=self.iou_thresh,
                                              min_area_thresh=self.area_thresh_min,
                                              max_area_thresh=self.area_thresh_max)
diff --git a/coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py b/coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py
index acd38eb1..f4ae2b45 100644
--- a/coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py
+++ b/coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py
@@ -101,27 +101,21 @@ def preprocess_patch_annotations(self):
         progress_bar.show()
         progress_bar.start_progress(len(self.image_paths))
 
-        def crop(image_path, image_annotations):
-            # Crop the image based on the annotations
-            return self.annotation_window.crop_these_image_annotations(image_path, image_annotations)
-
         # Group annotations by image path
-        groups = groupby(sorted(self.annotations, key=attrgetter('image_path')), key=attrgetter('image_path'))
-
-        with ThreadPoolExecutor() as executor:
-            future_to_image = {}
-            for path, group in groups:
-                future = executor.submit(crop, path, list(group))
-                future_to_image[future] = path
-
-            for future in as_completed(future_to_image):
-                image_path = future_to_image[future]
-                try:
-                    self.prepared_patches.extend(future.result())
-                except Exception as exc:
-                    print(f'{image_path} generated an exception: {exc}')
-                finally:
-                    progress_bar.update_progress()
+        grouped_annotations = groupby(sorted(self.annotations, key=attrgetter('image_path')),
+                                      key=attrgetter('image_path'))
+
+        for image_path, group in grouped_annotations:
+            try:
+                # Process image annotations
+                image_annotations = list(group)
+                image_annotations = self.annotation_window.crop_these_image_annotations(image_path, image_annotations)
+                self.prepared_patches.extend(image_annotations)
+
+            except Exception as exc:
+                print(f'{image_path} generated an exception: {exc}')
+            finally:
+                progress_bar.update_progress()
 
         progress_bar.stop_progress()
         progress_bar.close()
@@ -147,4 +141,8 @@ def batch_inference(self):
             progress_bar.update_progress()
 
         progress_bar.stop_progress()
-        progress_bar.close()
\ No newline at end of file
+        progress_bar.close()
+
+        # Clear the list of annotations
+        self.annotations = []
+        self.prepared_patches = []
\ No newline at end of file
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py b/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
index 20292590..c804c12b 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
@@ -339,10 +339,10 @@ def create_generic_labels(self):
 
     def get_uncertainty_threshold(self):
         """
-        Get the confidence threshold for predictions
+        Get the confidence threshold for classification predictions
         """
         threshold = self.main_window.get_uncertainty_thresh()
-        return threshold if threshold < 0.10 else 0.10
+        return threshold if threshold > 0.10 else 0.10
 
     def predict(self, inputs):
         """
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py b/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
index 73647207..4aa0b8b1 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
@@ -110,6 +110,9 @@ def predict(self, inputs=None):
         if not inputs:
             # If no annotations are selected, predict all annotations in the image
             inputs = self.annotation_window.get_image_review_annotations()
+            if not inputs:
+                # If no annotations are available, return
+                return
 
         images_np = []
         for annotation in inputs:
@@ -117,12 +120,13 @@ def predict(self, inputs=None):
         # Predict the classification results
         results = self.loaded_model(images_np,
+                                    conf=self.uncertainty_thresh,
                                     device=self.main_window.device,
                                     stream=True)
 
         # Create a result processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.get_uncertainty_threshold())
+                                             uncertainty_thresh=self.uncertainty_thresh)
 
         # Process the classification results
         results_processor.process_classification_results(results, inputs)
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py b/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
index f9ec4e10..d0ef1f47 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
@@ -143,7 +143,7 @@ def predict(self, inputs=None):
         # Predict the detection results
         results = self.loaded_model(inputs,
                                     agnostic_nms=True,
-                                    conf=self.get_uncertainty_threshold(),
+                                    conf=self.uncertainty_thresh,
                                     iou=self.iou_thresh,
                                     device=self.main_window.device,
                                     stream=True)
@@ -151,7 +151,7 @@ def predict(self, inputs=None):
         # Create a result processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.get_uncertainty_threshold(),
+                                             uncertainty_thresh=self.uncertainty_thresh,
                                              iou_thresh=self.iou_thresh,
                                              min_area_thresh=self.area_thresh_min,
                                              max_area_thresh=self.area_thresh_max)
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py b/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
index 6d773bd4..99207070 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
@@ -143,7 +143,7 @@ def predict(self, inputs=None):
         # Predict the segmentation results
         results = self.loaded_model(inputs,
                                     agnostic_nms=True,
-                                    conf=self.get_uncertainty_threshold(),
+                                    conf=self.uncertainty_thresh,
                                     iou=self.iou_thresh,
                                     device=self.main_window.device,
                                     stream=True)
@@ -151,7 +151,7 @@ def predict(self, inputs=None):
         # Create a result processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.get_uncertainty_threshold(),
+                                             uncertainty_thresh=self.uncertainty_thresh,
                                              iou_thresh=self.iou_thresh,
                                              min_area_thresh=self.area_thresh_min,
                                              max_area_thresh=self.area_thresh_max)
diff --git a/coralnet_toolbox/QtMainWindow.py b/coralnet_toolbox/QtMainWindow.py
index 87e8e0bb..b82945f9 100644
--- a/coralnet_toolbox/QtMainWindow.py
+++ b/coralnet_toolbox/QtMainWindow.py
@@ -1145,25 +1145,6 @@ def open_segment_deploy_model_dialog(self):
             self.segment_deploy_model_dialog.exec_()
         except Exception as e:
             QMessageBox.critical(self, "Critical Error", f"{e}")
-
-    def open_batch_inference_dialog(self):
-        if not self.image_window.image_paths:
-            QMessageBox.warning(self,
-                                "Batch Inference",
-                                "No images are present in the project.")
-            return
-
-        if not any(list(self.deploy_model_dialog.loaded_models.values())):
-            QMessageBox.warning(self,
-                                "Batch Inference",
-                                "Please deploy a model before running batch inference.")
-            return
-
-        try:
-            self.untoggle_all_tools()
-            self.batch_inference_dialog.exec_()
-        except Exception as e:
-            QMessageBox.critical(self, "Critical Error", f"{e}")
 
     def open_classify_batch_inference_dialog(self):
         if not self.image_window.image_paths:
@@ -1177,6 +1158,12 @@ def open_classify_batch_inference_dialog(self):
                                 "Batch Inference",
                                 "Please deploy a model before running batch inference.")
             return
+
+        if not self.annotation_window.annotations_dict:
+            QMessageBox.warning(self,
+                                "Batch Inference",
+                                "No annotations are present in the project.")
+            return
 
         try:
             self.untoggle_all_tools()
diff --git a/coralnet_toolbox/ResultsProcessor.py b/coralnet_toolbox/ResultsProcessor.py
index 2d08600c..321c493a 100644
--- a/coralnet_toolbox/ResultsProcessor.py
+++ b/coralnet_toolbox/ResultsProcessor.py
@@ -3,6 +3,7 @@
 
 from PyQt5.QtCore import QPointF
 
+from torchvision.ops import nms
 from ultralytics.engine.results import Results
 from ultralytics.models.sam.amg import batched_mask_to_box
 from ultralytics.utils import ops
@@ -23,9 +24,9 @@ def __init__(self,
                  main_window,
                  class_mapping,
                  uncertainty_thresh=0.3,
-                 iou_thresh=0.7,
-                 min_area_thresh=0.01,
-                 max_area_thresh=0.5):
+                 iou_thresh=0.2,
+                 min_area_thresh=0.00,
+                 max_area_thresh=0.40):
         self.main_window = main_window
         self.label_window = main_window.label_window
         self.annotation_window = main_window.annotation_window
@@ -37,39 +38,30 @@ def __init__(self,
         self.min_area_thresh = min_area_thresh
         self.max_area_thresh = max_area_thresh
 
-    def filter_by_uncertainty(self, result):
+    def filter_by_uncertainty(self, results):
         """
         Filter the results based on the uncertainty threshold.
         """
-        # Get the confidence score
-        conf = float(result.boxes.conf.cpu().numpy()[0])
-        # Check if the confidence is within the threshold
-        if conf < self.uncertainty_thresh:
-            return False
-        return True
+        return results[results.boxes.conf > self.uncertainty_thresh]
 
-    def filter_by_area(self, result):
+    def filter_by_iou(self, results):
+        """Filter the results based on the IoU threshold."""
+        return results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)]
+
+    def filter_by_area(self, results):
         """
         Filter the results based on the area threshold.
         """
-        # Get the normalized bounding box coordinates
-        x_norm, y_norm, w_norm, h_norm = map(float, result.boxes.xywhn.cpu().numpy()[0])
-        # Calculate the normalized area
+        x_norm, y_norm, w_norm, h_norm = results.boxes.xywhn.T
         area_norm = w_norm * h_norm
-        # Check if the area is within the threshold
-        if area_norm < self.min_area_thresh:
-            return False
-        if area_norm > self.max_area_thresh:
-            return False
-        return True
+        return results[(area_norm > self.min_area_thresh) & (area_norm < self.max_area_thresh)]
 
-    def passed_filters(self, result):
+    def apply_filters(self, results):
         """Check if the results passed all filters."""
-        if not self.filter_by_uncertainty(result):
-            return False
-        if not self.filter_by_area(result):
-            return False
-        return True
+        results = self.filter_by_uncertainty(results)
+        results = self.filter_by_iou(results)
+        results = self.filter_by_area(results)
+        return results
 
     def extract_classification_result(self, result):
         """
@@ -141,10 +133,7 @@ def extract_detection_result(self, result):
     def process_single_detection_result(self, result):
         """
         Process a single detection result.
-        """
-        if not self.passed_filters(result):
-            return
-
+        """
         # Get image path, class, class name, confidence, and bounding box coordinates
         image_path, cls, cls_name, conf, x_min, y_min, x_max, y_max = self.extract_detection_result(result)
         # Get the short label given the class name and confidence
@@ -166,6 +155,8 @@ def process_detection_results(self, results_generator):
         progress_bar.show()
 
         for results in results_generator:
+            # Apply filtering to the results
+            results = self.apply_filters(results)
             for result in results:
                 if result:
                     self.process_single_detection_result(result)
@@ -193,9 +184,6 @@ def process_single_segmentation_result(self, result):
         """
         Process a single segmentation result.
         """
-        if not self.passed_filters(result):
-            return
-
         # Get image path, class, class name, confidence, and polygon points
         image_path, cls, cls_name, conf, points = self.extract_segmentation_result(result)
         # Get the short label given the class name and confidence
@@ -217,6 +205,8 @@ def process_segmentation_results(self, results_generator):
         progress_bar.show()
 
         for results in results_generator:
+            # Apply filtering to the results
+            results = self.apply_filters(results)
             for result in results:
                 if result:
                     self.process_single_segmentation_result(result)
diff --git a/coralnet_toolbox/SAM/QtDeployGenerator.py b/coralnet_toolbox/SAM/QtDeployGenerator.py
index eec5705b..d08076b1 100644
--- a/coralnet_toolbox/SAM/QtDeployGenerator.py
+++ b/coralnet_toolbox/SAM/QtDeployGenerator.py
@@ -318,11 +318,12 @@ def load_model(self):
                          task=self.task,
                          mode='predict',
                          save=False,
-                         max_det=1000,
+                         max_det=500,
                          imgsz=self.get_imgsz(),
-                         conf=self.get_uncertainty_threshold(),
-                         iou=self.iou_thresh,
+                         conf=0.05,
+                         iou=1.0,
                          device=self.main_window.device)
+
         # Load the model
         self.loaded_model = FastSAMPredictor(overrides=overrides)
 
@@ -349,18 +350,6 @@ def get_imgsz(self):
         self.imgsz = self.imgsz_spinbox.value()
         return self.imgsz
 
-    def get_uncertainty_threshold(self):
-        """
-        Get the uncertainty threshold, limiting it to a maximum of 0.10.
-
-        Returns:
-            Adjusted uncertainty threshold value.
-        """
-        if self.main_window.get_uncertainty_thresh() < 0.10:
-            return self.main_window.get_uncertainty_thresh()
-        else:
-            return 0.10  # Arbitrary value to prevent too many detections
-
     def predict(self, image_paths=None):
         """
         Make predictions on the given image paths using the loaded model.
@@ -402,7 +391,7 @@ def predict(self, image_paths=None):
         # Create a results processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.get_uncertainty_threshold(),
+                                             uncertainty_thresh=self.uncertainty_thresh,
                                              iou_thresh=self.iou_thresh,
                                              min_area_thresh=self.area_thresh_min,
                                              max_area_thresh=self.area_thresh_max)
@@ -410,7 +399,7 @@ def predict(self, image_paths=None):
             # Update the progress bar
             progress_bar.update_progress()
 
-        if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
+        if self.task.lower() == 'segment' or self.use_sam_dropdown.currentText() == "True":
             results_processor.process_segmentation_results(results)
         else:
             results_processor.process_detection_results(results)
diff --git a/coralnet_toolbox/__init__.py b/coralnet_toolbox/__init__.py
index fa37c0d5..616c4660 100644
--- a/coralnet_toolbox/__init__.py
+++ b/coralnet_toolbox/__init__.py
@@ -2,7 +2,7 @@
 
 from coralnet_toolbox.main import run
 
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 __author__ = "Jordan Pierce"
 __email__ = "jordan.pierce@noaa.gov"
 __credits__ = "National Center for Coastal and Ocean Sciences (NCCOS)"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index d746937b..cc051bd8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "coralnet-toolbox"
-version = "0.0.3"
+version = "0.0.4"
 dynamic = [
     "dependencies",
 ]
@@ -48,7 +48,7 @@ universal = true
 
 
 [tool.bumpversion]
-current_version = "0.0.3"
+current_version = "0.0.4"
 commit = true
 tag = true