From 5ab2d2f364b10190c8776cab3f190fb499ff0522 Mon Sep 17 00:00:00 2001 From: Jordan Pierce <115024024+Jordan-Pierce@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:54:01 -0500 Subject: [PATCH 1/2] Add GroundingDINO model and refactor uncertainty threshold handling --- README.md | 22 ++-- .../AutoDistill/Models/GroundingDINOModel.py | 120 ++++++++++++++++++ .../AutoDistill/Models/__init__.py | 5 + coralnet_toolbox/AutoDistill/QtDeployModel.py | 27 ++-- .../MachineLearning/DeployModel/QtBase.py | 7 - .../MachineLearning/DeployModel/QtClassify.py | 4 +- .../MachineLearning/DeployModel/QtDetect.py | 12 +- .../MachineLearning/DeployModel/QtSegment.py | 13 +- coralnet_toolbox/QtLabelWindow.py | 1 - coralnet_toolbox/QtMainWindow.py | 8 ++ coralnet_toolbox/ResultsProcessor.py | 31 ++++- coralnet_toolbox/SAM/QtDeployGenerator.py | 17 ++- coralnet_toolbox/SAM/QtDeployPredictor.py | 30 +++-- coralnet_toolbox/Tools/QtSAMTool.py | 2 +- docs/index.md | 10 +- docs/installation.md | 7 +- 16 files changed, 235 insertions(+), 81 deletions(-) create mode 100644 coralnet_toolbox/AutoDistill/Models/GroundingDINOModel.py create mode 100644 coralnet_toolbox/AutoDistill/Models/__init__.py diff --git a/README.md b/README.md index ee4b6ae7..a20b72e8 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ CoralNet-Toolbox

-[![image](https://img.shields.io/pypi/v/CoralNet-Toolbox.svg)](https://pypi.python.org/pypi/CoralNet-Toolbox)
+[![version](https://img.shields.io/pypi/v/CoralNet-Toolbox.svg)](https://pypi.python.org/pypi/CoralNet-Toolbox)
+[![python-version](https://img.shields.io/pypi/pyversions/CoralNet-Toolbox.svg)](https://pypi.org/project/CoralNet-Toolbox)
+
 
 ## Quick Start
 
@@ -43,13 +45,13 @@ computer vision and deep learning built in `PyTorch`. For more information on th
 [here](https://github.com/ultralytics/ultralytics/blob/main/LICENSE).
 
 The `toolbox` also uses the following to create rectangle and polygon annotations:
-- [`Fast-SAM`]()
-- [`RepViT-SAM`]()
+- [`Fast-SAM`](https://github.com/CASIA-IVA-Lab/FastSAM)
+- [`RepViT-SAM`](https://github.com/THU-MIG/RepViT)
 - [`EdgeSAM`](https://github.com/chongzhou96/EdgeSAM)
 - [`MobileSAM`](https://github.com/ChaoningZhang/MobileSAM)
 - [`SAM`](https://github.com/facebookresearch/segment-anything)
 - [`AutoDistill`](https://github.com/autodistill)
-  - [`GroundingDino`]()
+  - [`GroundingDINO`](https://github.com/IDEA-Research/GroundingDINO)
 
 ## Tools
 
@@ -61,8 +63,8 @@ Enhance your CoralNet experience with these tools:
 - 🧩 Patches: Create patches (points)
 - 🔳 Rectangles: Create rectangles (bounding boxes)
 - 🟣 Polygons: Create polygons (instance masks)
-- 🦾 SAM: Use [`FastSAM`](), [`RepViT-SAM`](), [`EdgeSAM`](), [`MobileSAM`](), and [`SAM`]() to create polygons
-- 🧪 AutoDistill: Use [`AutoDistill`](https://github.com/autodistill) to access `GroundingDINO` for creating rectangles
+- 🦾 SAM: Use `FastSAM`, `RepViT-SAM`, `EdgeSAM`, `MobileSAM`, and `SAM` to create polygons
+- 🧪 AutoDistill: Use `AutoDistill` to access `GroundingDINO` for creating rectangles
 - 🧠 Train: Build local patch-based classifiers, object detection, and instance segmentation models
 - 🔮 Deploy: Use trained models for predictions
 - 📊 Evaluation: Evaluate model performance
@@ -121,9 +123,11 @@ If `CUDA` is installed on your computer, and `torch` was built with it properly,
 `toolbox` instead of a `🐢`; if you have multiple `CUDA` devices available, you should see a `🚀` icon, and if you're
 using a Mac with `Metal`, you should see an `🍎` icon (click on the icon to see the device information).
 
-See here for more details on [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc),
-[`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit), and [`torch`](https://pytorch.org/get-started/locally/)
-versions.
+See here for more details on versions for the following:
+- [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc)
+- [`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit)
+- [`torch`](https://pytorch.org/get-started/locally/)
+
 ### Run
 
 Finally, you can run the `toolbox` from the command line:
diff --git a/coralnet_toolbox/AutoDistill/Models/GroundingDINOModel.py b/coralnet_toolbox/AutoDistill/Models/GroundingDINOModel.py
new file mode 100644
index 00000000..03d43d22
--- /dev/null
+++ b/coralnet_toolbox/AutoDistill/Models/GroundingDINOModel.py
@@ -0,0 +1,120 @@
+import os
+import urllib.request
+from dataclasses import dataclass
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import torch
+
+torch.use_deterministic_algorithms(False)
+
+import supervision as sv
+from autodistill.detection import CaptionOntology, DetectionBaseModel
+from autodistill.helpers import load_image
+from groundingdino.util.inference import Model
+
+from autodistill_grounding_dino.helpers import combine_detections
+
+HOME = os.path.expanduser("~")
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Functions
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+def load_grounding_dino(model="SwinT"):
+    """Load a GroundingDINO model ("SwinT" or "SwinB"), downloading weights on first use."""
+    # Define the paths
+    AUTODISTILL_CACHE_DIR = os.path.expanduser("~/.cache/autodistill")
+    GROUNDING_DINO_CACHE_DIR = os.path.join(AUTODISTILL_CACHE_DIR, "groundingdino")
+
+    if model == "SwinT":
+        GROUNDING_DINO_CONFIG_PATH = os.path.join(GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py")
+        GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(GROUNDING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth")
+    else:
+        GROUNDING_DINO_CONFIG_PATH = os.path.join(GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinB_OGC.py")
+        GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(GROUNDING_DINO_CACHE_DIR, "groundingdino_swinb_cogcoor.pth")
+
+    try:
+        print("Trying to load GroundingDINO directly")
+        grounding_dino_model = Model(
+            model_config_path=GROUNDING_DINO_CONFIG_PATH,
+            model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
+            device=DEVICE,
+        )
+        return grounding_dino_model
+
+    except Exception:
+        print("Downloading GroundingDINO model weights")
+        if not os.path.exists(GROUNDING_DINO_CACHE_DIR):
+            os.makedirs(GROUNDING_DINO_CACHE_DIR)
+
+        if model == "SwinT":
+            if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
+                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
+                urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)
+
+            if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
+                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.py"
+                urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)
+        else:
+            if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
+                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
+                urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)
+
+            if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
+                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB_cfg.py"
+                urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)
+
+        grounding_dino_model = Model(
+            model_config_path=GROUNDING_DINO_CONFIG_PATH,
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + device=DEVICE, + ) + + return grounding_dino_model + + +# ---------------------------------------------------------------------------------------------------------------------- +# Classes +# ---------------------------------------------------------------------------------------------------------------------- + + +@dataclass +class GroundingDINO(DetectionBaseModel): + ontology: CaptionOntology + grounding_dino_model: Model + box_threshold: float + text_threshold: float + + def __init__( + self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25, model="SwinB", + ): + self.ontology = ontology + self.grounding_dino_model = load_grounding_dino(model) + self.box_threshold = box_threshold + self.text_threshold = text_threshold + + def predict(self, input: str) -> sv.Detections: + image = load_image(input, return_format="cv2") + + detections_list = [] + + for _, description in enumerate(self.ontology.prompts()): + detections = self.grounding_dino_model.predict_with_classes( + image=image, + classes=[description], + box_threshold=self.box_threshold, + text_threshold=self.text_threshold, + ) + + detections_list.append(detections) + + detections = combine_detections( + detections_list, overwrite_class_ids=range(len(detections_list)) + ) + + return detections \ No newline at end of file diff --git a/coralnet_toolbox/AutoDistill/Models/__init__.py b/coralnet_toolbox/AutoDistill/Models/__init__.py new file mode 100644 index 00000000..1b59fda2 --- /dev/null +++ b/coralnet_toolbox/AutoDistill/Models/__init__.py @@ -0,0 +1,5 @@ +# coralnet_toolbox/AutoDistill/Models/__init__.py + +from .GroundingDINOModel import GroundingDINO + +__all__ = ["GroundingDINO"] \ No newline at end of file diff --git a/coralnet_toolbox/AutoDistill/QtDeployModel.py b/coralnet_toolbox/AutoDistill/QtDeployModel.py index defab920..9cc826de 100644 --- a/coralnet_toolbox/AutoDistill/QtDeployModel.py +++ b/coralnet_toolbox/AutoDistill/QtDeployModel.py @@ -12,6 +12,7 @@ from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout, QHBoxLayout, QLabel, QLineEdit, QMessageBox, QPushButton, QSlider, QVBoxLayout, QGroupBox) + from torch.cuda import empty_cache from coralnet_toolbox.QtProgressBar import ProgressBar @@ -112,7 +113,7 @@ def setup_models_layout(self): layout = QVBoxLayout() self.model_dropdown = QComboBox() - self.model_dropdown.addItems(["GroundingDINO"]) + self.model_dropdown.addItems(["GroundingDINO-SwinT", "GroundingDINO-SwinB"]) layout.addWidget(self.model_dropdown) group_box.setLayout(layout) @@ -343,14 +344,11 @@ def load_model(self): # Set the class mapping self.class_mapping = {k: v for k, v in enumerate(self.ontology.classes())} - # Threshold for confidence - uncertainty_thresh = self.get_uncertainty_threshold() - # Get the name of the model to load model_name = self.model_dropdown.currentText() if model_name != self.model_name: - self.load_new_model(model_name, uncertainty_thresh) + self.load_new_model(model_name) else: # Update the model with the new ontology self.loaded_model.ontology = self.ontology @@ -382,7 +380,7 @@ def get_ontology_mapping(self): ontology_mapping[text_input.text()] = label_dropdown.currentText() return ontology_mapping - def load_new_model(self, model_name, uncertainty_thresh): + def load_new_model(self, model_name): """ Load a new model based on the selected model name. @@ -390,12 +388,15 @@ def load_new_model(self, model_name, uncertainty_thresh): model_name: Name of the model to load. 
-            uncertainty_thresh: Threshold for uncertainty.
+
         """
-        if model_name == "GroundingDINO":
-            from autodistill_grounding_dino import GroundingDINO
+        if "GroundingDINO" in model_name:
+            from coralnet_toolbox.AutoDistill.Models.GroundingDINOModel import GroundingDINO
+
+            model = model_name.split("-")[1].strip()
 
             self.model_name = model_name
             self.loaded_model = GroundingDINO(ontology=self.ontology,
                                               box_threshold=0.025,
-                                              text_threshold=0.025)
+                                              text_threshold=0.025,
+                                              model=model)
 
     def predict(self, image_paths=None):
         """
@@ -427,10 +428,10 @@ def predict(self, image_paths=None):
 
         # Create a results processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.uncertainty_thresh,
-                                             iou_thresh=self.iou_thresh,
-                                             min_area_thresh=self.area_thresh_min,
-                                             max_area_thresh=self.area_thresh_max)
+                                             uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
+                                             iou_thresh=self.main_window.get_iou_thresh(),
+                                             min_area_thresh=self.main_window.get_area_thresh_min(),
+                                             max_area_thresh=self.main_window.get_area_thresh_max())
 
         results = results_processor.from_supervision(results, image, image_path, self.class_mapping)
 
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py b/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
index c804c12b..85a99710 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
@@ -336,13 +336,6 @@ def create_generic_labels(self):
             )
             label = self.label_window.get_label_by_short_code(class_name)
             self.class_mapping[class_name] = label.to_dict()
-
-    def get_uncertainty_threshold(self):
-        """
-        Get the confidence threshold for classification predictions
-        """
-        threshold = self.main_window.get_uncertainty_thresh()
-        return threshold if threshold > 0.10 else 0.10
 
     def predict(self, inputs):
         """
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py b/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
index 4aa0b8b1..bf7dca67 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
@@ -120,13 +120,13 @@ def predict(self, inputs=None):
 
         # Predict the classification results
         results = self.loaded_model(images_np,
-                                    conf=self.uncertainty_thresh,
+                                    conf=self.main_window.get_uncertainty_thresh(),
                                     device=self.main_window.device,
                                     stream=True)
 
         # Create a result processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.uncertainty_thresh)
+                                             uncertainty_thresh=self.main_window.get_uncertainty_thresh())
 
         # Process the classification results
         results_processor.process_classification_results(results, inputs)
diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py b/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
index d0ef1f47..d0ef1733 100644
--- a/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
+++ b/coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
@@ -143,18 +143,18 @@ def predict(self, inputs=None):
 
         # Predict the detection results
         results = self.loaded_model(inputs,
                                     agnostic_nms=True,
-                                    conf=self.uncertainty_thresh,
-                                    iou=self.iou_thresh,
+                                    conf=self.main_window.get_uncertainty_thresh(),
+                                    iou=self.main_window.get_iou_thresh(),
                                     device=self.main_window.device,
                                     stream=True)
 
         # Create a result processor
         results_processor = ResultsProcessor(self.main_window,
                                              self.class_mapping,
-                                             uncertainty_thresh=self.uncertainty_thresh,
-                                             iou_thresh=self.iou_thresh,
-
min_area_thresh=self.area_thresh_min, - max_area_thresh=self.area_thresh_max) + uncertainty_thresh=self.main_window.get_uncertainty_thresh(), + iou_thresh=self.main_window.get_iou_thresh(), + min_area_thresh=self.main_window.get_area_thresh_min(), + max_area_thresh=self.main_window.get_area_thresh_max()) # Check if SAM model is deployed if self.use_sam_dropdown.currentText() == "True": diff --git a/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py b/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py index 99207070..b643f0dc 100644 --- a/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +++ b/coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py @@ -143,18 +143,19 @@ def predict(self, inputs=None): # Predict the segmentation results results = self.loaded_model(inputs, agnostic_nms=True, - conf=self.uncertainty_thresh, - iou=self.iou_thresh, + conf=self.main_window.get_uncertainty_thresh(), + iou=self.main_window.get_iou_thresh(), device=self.main_window.device, stream=True) # Create a result processor results_processor = ResultsProcessor(self.main_window, self.class_mapping, - uncertainty_thresh=self.uncertainty_thresh, - iou_thresh=self.iou_thresh, - min_area_thresh=self.area_thresh_min, - max_area_thresh=self.area_thresh_max) + uncertainty_thresh=self.main_window.get_uncertainty_thresh(), + iou_thresh=self.main_window.get_iou_thresh(), + min_area_thresh=self.main_window.get_area_thresh_min(), + max_area_thresh=self.main_window.get_area_thresh_max()) + # Check if SAM model is deployed if self.use_sam_dropdown.currentText() == "True": # Apply SAM to the segmentation results diff --git a/coralnet_toolbox/QtLabelWindow.py b/coralnet_toolbox/QtLabelWindow.py index cbfb8de4..fefffa69 100644 --- a/coralnet_toolbox/QtLabelWindow.py +++ b/coralnet_toolbox/QtLabelWindow.py @@ -309,7 +309,6 @@ def set_label_transparency(self, transparency): # Update the active label's transparency self.active_label.update_transparency(transparency) # Update the transparency of all annotations with the active label - scene = self.annotation_window.scene for annotation in self.annotation_window.annotations_dict.values(): if annotation.label.id == self.active_label.id: annotation.update_transparency(transparency) diff --git a/coralnet_toolbox/QtMainWindow.py b/coralnet_toolbox/QtMainWindow.py index b82945f9..33dcefbe 100644 --- a/coralnet_toolbox/QtMainWindow.py +++ b/coralnet_toolbox/QtMainWindow.py @@ -947,6 +947,14 @@ def get_area_thresh(self): """Get the current area threshold values""" return self.area_thresh_min, self.area_thresh_max + def get_area_thresh_min(self): + """Get the current minimum area threshold value""" + return self.area_thresh_min + + def get_area_thresh_max(self): + """Get the current maximum area threshold value""" + return self.area_thresh_max + def update_area_thresh(self, min_val, max_val): """Update the area threshold values""" if self.area_thresh_min != min_val or self.area_thresh_max != max_val: diff --git a/coralnet_toolbox/ResultsProcessor.py b/coralnet_toolbox/ResultsProcessor.py index 321c493a..adb059fb 100644 --- a/coralnet_toolbox/ResultsProcessor.py +++ b/coralnet_toolbox/ResultsProcessor.py @@ -42,19 +42,40 @@ def filter_by_uncertainty(self, results): """ Filter the results based on the uncertainty threshold. 
""" - return results[results.boxes.conf > self.uncertainty_thresh] + try: + if isinstance(results, list): + results = results[0] + results = results[results.boxes.conf > self.uncertainty_thresh] + except Exception as e: + print(f"Warning: Failed to filter results by uncertainty\n{e}") + + return results def filter_by_iou(self, results): """Filter the results based on the IoU threshold.""" - return results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)] + try: + if isinstance(results, list): + results = results[0] + results = results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)] + except Exception as e: + print(f"Warning: Failed to filter results by IoU\n{e}") + + return results def filter_by_area(self, results): """ Filter the results based on the area threshold. """ - x_norm, y_norm, w_norm, h_norm = results.boxes.xywhn.T - area_norm = w_norm * h_norm - return results[(area_norm > self.min_area_thresh) & (area_norm < self.max_area_thresh)] + try: + if isinstance(results, list): + results = results[0] + x_norm, y_norm, w_norm, h_norm = results.boxes.xywhn.T + area_norm = w_norm * h_norm + results = results[(area_norm > self.min_area_thresh) & (area_norm < self.max_area_thresh)] + except Exception as e: + print(f"Warning: Failed to filter results by area\n{e}") + + return results def apply_filters(self, results): """Check if the results passed all filters.""" diff --git a/coralnet_toolbox/SAM/QtDeployGenerator.py b/coralnet_toolbox/SAM/QtDeployGenerator.py index d08076b1..3bb4076d 100644 --- a/coralnet_toolbox/SAM/QtDeployGenerator.py +++ b/coralnet_toolbox/SAM/QtDeployGenerator.py @@ -297,7 +297,7 @@ def update_area_label(self): self.area_thresh_min = min_val / 100.0 self.area_thresh_max = max_val / 100.0 self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max) - self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}") + self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}") def load_model(self): """ @@ -358,7 +358,6 @@ def predict(self, image_paths=None): image_paths: List of image paths to process. If None, uses the current image. 
""" if not self.loaded_model: - QMessageBox.critical(self, "Error", "No model loaded") return if not image_paths: @@ -387,17 +386,17 @@ def predict(self, image_paths=None): if self.use_sam_dropdown.currentText() == "True": # Apply SAM to the detection results results = self.sam_dialog.predict_from_results(results, self.class_mapping) + + # Update the progress bar + progress_bar.update_progress() # Create a results processor results_processor = ResultsProcessor(self.main_window, self.class_mapping, - uncertainty_thresh=self.uncertainty_thresh, - iou_thresh=self.iou_thresh, - min_area_thresh=self.area_thresh_min, - max_area_thresh=self.area_thresh_max) - - # Update the progress bar - progress_bar.update_progress() + uncertainty_thresh=self.main_window.get_uncertainty_thresh(), + iou_thresh=self.main_window.get_iou_thresh(), + min_area_thresh=self.main_window.get_area_thresh_min(), + max_area_thresh=self.main_window.get_area_thresh_max()) if self.task.lower() == 'segment' or self.use_sam_dropdown.currentText() == "True": results_processor.process_segmentation_results(results) diff --git a/coralnet_toolbox/SAM/QtDeployPredictor.py b/coralnet_toolbox/SAM/QtDeployPredictor.py index eaf8fa42..28ec24f2 100644 --- a/coralnet_toolbox/SAM/QtDeployPredictor.py +++ b/coralnet_toolbox/SAM/QtDeployPredictor.py @@ -516,10 +516,10 @@ def predict_from_prompts(self, bbox, points, labels): # Create a results processor results_processor = ResultsProcessor(self.main_window, class_mapping=None, - uncertainty_thresh=self.uncertainty_thresh, - iou_thresh=self.iou_thresh, - min_area_thresh=self.area_thresh_min, - max_area_thresh=self.area_thresh_max) + uncertainty_thresh=self.main_window.get_uncertainty_thresh(), + iou_thresh=self.main_window.get_iou_thresh(), + min_area_thresh=self.main_window.get_area_thresh_min(), + max_area_thresh=self.main_window.get_area_thresh_max()) # Post-process the results results = results_processor.from_sam(masks, scores, self.original_image, self.image_path) @@ -540,23 +540,25 @@ def predict_from_results(self, results_generator, class_mapping): # Create a result processor result_processor = ResultsProcessor(self.main_window, class_mapping=class_mapping, - uncertainty_thresh=self.uncertainty_thresh, - iou_thresh=self.iou_thresh, - min_area_thresh=self.area_thresh_min, - max_area_thresh=self.area_thresh_max) + uncertainty_thresh=self.main_window.get_uncertainty_thresh(), + iou_thresh=self.main_window.get_iou_thresh(), + min_area_thresh=self.main_window.get_area_thresh_min(), + max_area_thresh=self.main_window.get_area_thresh_max()) results_dict = {} for results in results_generator: + results = result_processor.apply_filters(results) for result in results: - # Extract the results - image_path, cls_id, cls_name, conf, *bbox = result_processor.extract_detection_result(result) + if result: + # Extract the results + image_path, cls_id, cls_name, conf, *bbox = result_processor.extract_detection_result(result) - if image_path not in results_dict: - results_dict[image_path] = [] + if image_path not in results_dict: + results_dict[image_path] = [] - # Add the results to the dictionary - results_dict[image_path].append(np.array(bbox)) + # Add the results to the dictionary + results_dict[image_path].append(np.array(bbox)) # Loop through each unique image path for image_path in results_dict: diff --git a/coralnet_toolbox/Tools/QtSAMTool.py b/coralnet_toolbox/Tools/QtSAMTool.py index 54ea185c..aec56f82 100644 --- a/coralnet_toolbox/Tools/QtSAMTool.py +++ b/coralnet_toolbox/Tools/QtSAMTool.py @@ 
-386,7 +386,7 @@ def create_annotation(self, scene_pos: QPointF, finished: bool = False):
         if not results:
             return None
 
-        if results.boxes.conf[0] < self.sam_dialog.uncertainty_thresh:
+        if results.boxes.conf[0] < self.annotation_window.main_window.get_uncertainty_thresh():
             return None
 
         # TODO use results processor
diff --git a/docs/index.md b/docs/index.md
index 3dbcd4ec..049176ba 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -43,13 +43,13 @@ computer vision and deep learning built in th
 [here](https://github.com/ultralytics/ultralytics/blob/main/LICENSE).
 
 The `toolbox` also uses the following to create rectangle and polygon annotations:
-- [`Fast-SAM`]()
-- [`RepViT-SAM`]()
+- [`Fast-SAM`](https://github.com/CASIA-IVA-Lab/FastSAM)
+- [`RepViT-SAM`](https://github.com/THU-MIG/RepViT)
 - [`EdgeSAM`](https://github.com/chongzhou96/EdgeSAM)
 - [`MobileSAM`](https://github.com/ChaoningZhang/MobileSAM)
 - [`SAM`](https://github.com/facebookresearch/segment-anything)
 - [`AutoDistill`](https://github.com/autodistill)
-  - [`GroundingDino`]()
+  - [`GroundingDINO`](https://github.com/IDEA-Research/GroundingDINO)
 
 ## Tools
 
@@ -61,8 +61,8 @@ Enhance your CoralNet experience with these tools:
 - 🧩 Patches: Create patches (points)
 - 🔳 Rectangles: Create rectangles (bounding boxes)
 - 🟣 Polygons: Create polygons (instance masks)
-- 🦾 SAM: Use [`FastSAM`](), [`RepViT-SAM`](), [`EdgeSAM`](), [`MobileSAM`](), and [`SAM`]() to create polygons
-- 🧪 AutoDistill: Use [`AutoDistill`](https://github.com/autodistill) to access `GroundingDINO` for creating rectangles
+- 🦾 SAM: Use `FastSAM`, `RepViT-SAM`, `EdgeSAM`, `MobileSAM`, and `SAM` to create polygons
+- 🧪 AutoDistill: Use `AutoDistill` to access `GroundingDINO` for creating rectangles
 - 🧠 Train: Build local patch-based classifiers, object detection, and instance segmentation models
 - 🔮 Deploy: Use trained models for predictions
 - 📊 Evaluation: Evaluate model performance
diff --git a/docs/installation.md b/docs/installation.md
index 876eb9bf..ecac4e80 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -41,9 +41,10 @@ If `CUDA` is installed on your computer, and `torch` was built with it properly,
 `toolbox` instead of a `🐢`; if you have multiple `CUDA` devices available, you should see a `🚀` icon, and if you're
 using a Mac with `Metal`, you should see an `🍎` icon (click on the icon to see the device information).
 
-See here for more details on [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc),
-[`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit), and [`torch`](https://pytorch.org/get-started/locally/)
-versions.
+See here for more details on versions for the following: +- [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc) +- [`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit) +- [`torch`](https://pytorch.org/get-started/locally/) ### Run Finally, you can run the `toolbox` from the command line: From 6a4bead41372e66b0f63671cf5fc98b3c6cfd061 Mon Sep 17 00:00:00 2001 From: Jordan Pierce <115024024+Jordan-Pierce@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:54:29 -0500 Subject: [PATCH 2/2] =?UTF-8?q?Bump=20version:=200.0.4=20=E2=86=92=200.0.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- coralnet_toolbox/__init__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/coralnet_toolbox/__init__.py b/coralnet_toolbox/__init__.py index 616c4660..a092dfc5 100644 --- a/coralnet_toolbox/__init__.py +++ b/coralnet_toolbox/__init__.py @@ -2,7 +2,7 @@ from coralnet_toolbox.main import run -__version__ = "0.0.4" +__version__ = "0.0.5" __author__ = "Jordan Pierce" __email__ = "jordan.pierce@noaa.gov" __credits__ = "National Center for Coastal and Ocean Sciences (NCCOS)" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cc051bd8..dd708fc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "coralnet-toolbox" -version = "0.0.4" +version = "0.0.5" dynamic = [ "dependencies", ] @@ -48,7 +48,7 @@ universal = true [tool.bumpversion] -current_version = "0.0.4" +current_version = "0.0.5" commit = true tag = true
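
---

A minimal sketch of how the new `GroundingDINO` wrapper introduced above is expected to be exercised, assuming `autodistill`, `supervision`, and the `groundingdino` package are importable; the ontology entries and image path below are hypothetical examples, not part of this patch:

    from autodistill.detection import CaptionOntology
    from coralnet_toolbox.AutoDistill.Models.GroundingDINOModel import GroundingDINO

    # CaptionOntology maps text prompts to class labels (prompt -> label)
    ontology = CaptionOntology({"coral": "Coral", "fish": "Fish"})  # hypothetical classes

    # model="SwinT" or "SwinB" selects which checkpoint load_grounding_dino() fetches
    detector = GroundingDINO(ontology=ontology,
                             box_threshold=0.025,
                             text_threshold=0.025,
                             model="SwinT")

    # predict() accepts an image path and returns supervision Detections
    detections = detector.predict("example.jpg")  # hypothetical image path
    print(detections.xyxy, detections.confidence, detections.class_id)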
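
The deploy dialog now lists "GroundingDINO-SwinT" and "GroundingDINO-SwinB", and load_new_model() derives the backbone key from the dropdown text. A sketch of that parsing, assuming the dropdown strings keep the family-backbone form used above:

    # Split the dropdown entry into model family and backbone
    model_name = "GroundingDINO-SwinB"
    backbone = model_name.split("-")[1].strip()  # -> "SwinB"
    assert backbone in ("SwinT", "SwinB")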
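
ResultsProcessor.filter_by_area() keeps detections whose normalized box area lies inside the configured window. A self-contained sketch of that logic with made-up boxes and thresholds:

    import torch

    # xywhn holds (x_center, y_center, width, height), all normalized to [0, 1]
    xywhn = torch.tensor([[0.5, 0.5, 0.10, 0.10],   # area 0.01 -> kept
                          [0.5, 0.5, 0.90, 0.90]])  # area 0.81 -> dropped
    min_area_thresh, max_area_thresh = 0.005, 0.50

    x_norm, y_norm, w_norm, h_norm = xywhn.T
    area_norm = w_norm * h_norm
    keep = (area_norm > min_area_thresh) & (area_norm < max_area_thresh)
    print(keep)  # tensor([ True, False])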