Commit

Merge pull request #81 from Jordan-Pierce/dev

Dev

Jordan-Pierce authored Dec 4, 2024
2 parents 9e8ee3f + 6a4bead commit 3f60c7f
Showing 18 changed files with 238 additions and 84 deletions.
22 changes: 13 additions & 9 deletions README.md
@@ -4,7 +4,9 @@
<img src="https://raw.githubusercontent.com/Jordan-Pierce/CoralNet-Toolbox/refs/heads/main/figures/CoralNet_Toolbox.png" alt="CoralNet-Toolbox">
</p>

[![image](https://img.shields.io/pypi/v/CoralNet-Toolbox.svg)](https://pypi.python.org/pypi/CoralNet-Toolbox)
[![version](https://img.shields.io/pypi/v/CoralNet-Toolbox.svg)](https://pypi.python.org/pypi/CoralNet-Toolbox)
[![python-version](https://img.shields.io/pypi/pyversions/CoralNet-Toolbox.svg)](https://pypi.org/project/CoralNet-Toolbox)



## Quick Start
@@ -43,13 +45,13 @@ computer vision and deep learning built in `PyTorch`. For more information on th
[here](https://github.com/ultralytics/ultralytics/blob/main/LICENSE).

The `toolbox` also uses the following to create rectangle and polygon annotations:
- [`Fast-SAM`]()
- [`RepViT-SAM`]()
- [`Fast-SAM`](https://github.com/CASIA-IVA-Lab/FastSAM)
- [`RepViT-SAM`](https://github.com/THU-MIG/RepViT)
- [`EdgeSAM`](https://github.com/chongzhou96/EdgeSAM)
- [`MobileSAM`](https://github.com/ChaoningZhang/MobileSAM)
- [`SAM`](https://github.com/facebookresearch/segment-anything)
- [`AutoDistill`](https://github.com/autodistill)
- [`GroundingDino`]()
- [`GroundingDino`](https://github.com/IDEA-Research/GroundingDINO)


## Tools
@@ -61,8 +63,8 @@ Enhance your CoralNet experience with these tools:
- 🧩 Patches: Create patches (points)
- 🔳 Rectangles: Create rectangles (bounding boxes)
- 🟣 Polygons: Create polygons (instance masks)
- 🦾 SAM: Use [`FastSAM`](), [`RepViT-SAM`](), [`EdgeSAM`](), [`MobileSAM`](), and [`SAM`]() to create polygons
- 🧪 AutoDistill: Use [`AutoDistill`](https://github.com/autodistill) to access `GroundingDINO` for creating rectangles
- 🦾 SAM: Use `FastSAM`, `RepViT-SAM`, `EdgeSAM`, `MobileSAM`, and `SAM` to create polygons
- 🧪 AutoDistill: Use `AutoDistill` to access `GroundingDINO` for creating rectangles
- 🧠 Train: Build local patch-based classifiers, object detection, and instance segmentation models
- 🔮 Deploy: Use trained models for predictions
- 📊 Evaluation: Evaluate model performance
@@ -121,9 +123,11 @@ If `CUDA` is installed on your computer, and `torch` was built with it properly,
`toolbox` instead of a `🐢`; if you have multiple `CUDA` devices available, you should see a `🚀` icon,
and if you're using a Mac with `Metal`, you should see an `🍎` icon (click on the icon to see the device information).
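
As a quick sanity check outside the `toolbox`, you can ask `torch` directly which accelerator it sees; a minimal sketch using the standard `torch` API (not toolbox code):

```python
import torch

# Rough mapping to the icons described above (for illustration only).
if torch.cuda.is_available():
    print(f"CUDA available with {torch.cuda.device_count()} device(s)")  # multiple devices ~ 🚀
elif torch.backends.mps.is_available():
    print("Apple Metal (MPS) backend available")  # 🍎
else:
    print("No accelerator found, running on CPU")  # 🐢
```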

See here for more details on [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc),
[`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit), and [`torch`](https://pytorch.org/get-started/locally/)
versions.
See the following for more details on compatible versions:
- [`cuda-nvcc`](https://anaconda.org/nvidia/cuda-nvcc)
- [`cudatoolkit`](https://anaconda.org/nvidia/cuda-toolkit)
- [`torch`](https://pytorch.org/get-started/locally/)
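
To check which versions are actually in your environment, a short introspection sketch (plain `torch`, not toolbox code):

```python
import torch

print("torch:", torch.__version__)
print("CUDA (built against):", torch.version.cuda)  # None on CPU-only builds
print("cuDNN:", torch.backends.cudnn.version())     # None if cuDNN is unavailable
```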


### Run
Finally, you can run the `toolbox` from the command line:
120 changes: 120 additions & 0 deletions coralnet_toolbox/AutoDistill/Models/GroundingDINOModel.py
@@ -0,0 +1,120 @@
import os
import urllib.request
from dataclasses import dataclass

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch

torch.use_deterministic_algorithms(False)

import supervision as sv
from autodistill.detection import CaptionOntology, DetectionBaseModel
from autodistill.helpers import load_image
from groundingdino.util.inference import Model

from autodistill_grounding_dino.helpers import (combine_detections)

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ----------------------------------------------------------------------------------------------------------------------
# Functions
# ----------------------------------------------------------------------------------------------------------------------


def load_grounding_dino(model="SwinT"):
    """Load the grounding DINO model."""
    # Define the paths
    AUTODISTILL_CACHE_DIR = os.path.expanduser("~/.cache/autodistill")
    GROUDNING_DINO_CACHE_DIR = os.path.join(AUTODISTILL_CACHE_DIR, "groundingdino")

    if model == "SwinT":
        GROUNDING_DINO_CONFIG_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py")
        GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth")
    else:
        GROUNDING_DINO_CONFIG_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, "GroundingDINO_SwinB_OGC.py")
        GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, "groundingdino_swinb_cogcoor.pth")

    try:
        print("trying to load grounding dino directly")
        grounding_dino_model = Model(
            model_config_path=GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
            device=DEVICE,
        )
        return grounding_dino_model

    except Exception:
        print("downloading dino model weights")
        if not os.path.exists(GROUDNING_DINO_CACHE_DIR):
            os.makedirs(GROUDNING_DINO_CACHE_DIR)

        if model == "SwinT":
            if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
                urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)

            if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.py"
                urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)
        else:
            if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
                urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)

            if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
                url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB_cfg.py"
                urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)

        grounding_dino_model = Model(
            model_config_path=GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
            device=DEVICE,
        )

        return grounding_dino_model


# ----------------------------------------------------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------------------------------------------------


@dataclass
class GroundingDINO(DetectionBaseModel):
    ontology: CaptionOntology
    grounding_dino_model: Model
    box_threshold: float
    text_threshold: float

    def __init__(
        self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25, model="SwinB",
    ):
        self.ontology = ontology
        self.grounding_dino_model = load_grounding_dino(model)
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

    def predict(self, input: str) -> sv.Detections:
        image = load_image(input, return_format="cv2")

        detections_list = []

        for _, description in enumerate(self.ontology.prompts()):
            detections = self.grounding_dino_model.predict_with_classes(
                image=image,
                classes=[description],
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold,
            )

            detections_list.append(detections)

        detections = combine_detections(
            detections_list, overwrite_class_ids=range(len(detections_list))
        )

        return detections
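
For context, a hedged usage sketch of the class above; the image path and the prompt-to-label mapping are placeholders, not values from this commit:

```python
from autodistill.detection import CaptionOntology

from coralnet_toolbox.AutoDistill.Models.GroundingDINOModel import GroundingDINO

# Placeholder ontology: caption prompts on the left, output labels on the right.
ontology = CaptionOntology({"coral": "Coral", "rock": "Rock"})

model = GroundingDINO(ontology=ontology,
                      box_threshold=0.35,
                      text_threshold=0.25,
                      model="SwinT")  # "SwinB" is the default above

detections = model.predict("reef_image.jpg")  # returns a supervision.Detections object
print(f"{len(detections)} detections")
```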
5 changes: 5 additions & 0 deletions coralnet_toolbox/AutoDistill/Models/__init__.py
@@ -0,0 +1,5 @@
# coralnet_toolbox/AutoDistill/Models/__init__.py

from .GroundingDINOModel import GroundingDINO

__all__ = ["GroundingDINO"]
27 changes: 14 additions & 13 deletions coralnet_toolbox/AutoDistill/QtDeployModel.py
@@ -12,6 +12,7 @@
from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog,
QFormLayout, QHBoxLayout, QLabel, QLineEdit,
QMessageBox, QPushButton, QSlider, QVBoxLayout, QGroupBox)

from torch.cuda import empty_cache

from coralnet_toolbox.QtProgressBar import ProgressBar
@@ -112,7 +113,7 @@ def setup_models_layout(self):
layout = QVBoxLayout()

self.model_dropdown = QComboBox()
self.model_dropdown.addItems(["GroundingDINO"])
self.model_dropdown.addItems(["GroundingDINO-SwinT", "GroundingDINO-SwinB"])
layout.addWidget(self.model_dropdown)

group_box.setLayout(layout)
@@ -343,14 +344,11 @@ def load_model(self):
# Set the class mapping
self.class_mapping = {k: v for k, v in enumerate(self.ontology.classes())}

# Threshold for confidence
uncertainty_thresh = self.get_uncertainty_threshold()

# Get the name of the model to load
model_name = self.model_dropdown.currentText()

if model_name != self.model_name:
self.load_new_model(model_name, uncertainty_thresh)
self.load_new_model(model_name)
else:
# Update the model with the new ontology
self.loaded_model.ontology = self.ontology
@@ -382,20 +380,23 @@ def get_ontology_mapping(self):
ontology_mapping[text_input.text()] = label_dropdown.currentText()
return ontology_mapping

def load_new_model(self, model_name, uncertainty_thresh):
def load_new_model(self, model_name):
"""
Load a new model based on the selected model name.
Args:
model_name: Name of the model to load.
uncertainty_thresh: Threshold for uncertainty.
"""
if model_name == "GroundingDINO":
from autodistill_grounding_dino import GroundingDINO
if "GroundingDINO" in model_name:
from coralnet_toolbox.AutoDistill.Models.GroundingDINOModel import GroundingDINO

model = model_name.split("-")[1].strip()
self.model_name = model_name
self.loaded_model = GroundingDINO(ontology=self.ontology,
box_threshold=0.025,
text_threshold=0.025)
text_threshold=0.025,
model=model)
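
A small illustration (not part of the commit) of how the dropdown strings added earlier are parsed into the `model` argument:

```python
# Illustration only: "GroundingDINO-SwinT" / "GroundingDINO-SwinB" -> backbone variant.
for choice in ["GroundingDINO-SwinT", "GroundingDINO-SwinB"]:
    variant = choice.split("-")[1].strip()
    print(choice, "->", variant)  # SwinT / SwinB, passed as GroundingDINO(model=variant)
```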

def predict(self, image_paths=None):
"""
@@ -427,10 +428,10 @@ def predict(self, image_paths=None):
# Create a results processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
iou_thresh=self.main_window.get_iou_thresh(),
min_area_thresh=self.main_window.get_area_thresh_min(),
max_area_thresh=self.main_window.get_area_thresh_max())

results = results_processor.from_supervision(results, image, image_path, self.class_mapping)
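
The same pattern appears in the classify, detect, and segment deploy dialogs below: thresholds are now read from the main window at predict time rather than cached on the dialog, presumably so slider changes take effect immediately. A hypothetical stub (not toolbox code) showing the difference:

```python
# Hypothetical stub illustrating cached vs. live threshold reads.
class MainWindowStub:
    def __init__(self):
        self.uncertainty_thresh = 0.30

    def get_uncertainty_thresh(self):
        return self.uncertainty_thresh

main_window = MainWindowStub()
cached = main_window.get_uncertainty_thresh()  # frozen at load time (old behaviour)
main_window.uncertainty_thresh = 0.50          # user moves the slider
live = main_window.get_uncertainty_thresh()    # read at predict time (new behaviour)
print(cached, live)                            # 0.3 0.5
```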

7 changes: 0 additions & 7 deletions coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
@@ -336,13 +336,6 @@ def create_generic_labels(self):
)
label = self.label_window.get_label_by_short_code(class_name)
self.class_mapping[class_name] = label.to_dict()

def get_uncertainty_threshold(self):
"""
Get the confidence threshold for classification predictions
"""
threshold = self.main_window.get_uncertainty_thresh()
return threshold if threshold > 0.10 else 0.10

def predict(self, inputs):
"""
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
@@ -120,13 +120,13 @@ def predict(self, inputs=None):

# Predict the classification results
results = self.loaded_model(images_np,
conf=self.uncertainty_thresh,
conf=self.main_window.get_uncertainty_thresh(),
device=self.main_window.device,
stream=True)
# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.uncertainty_thresh)
uncertainty_thresh=self.main_window.get_uncertainty_thresh())

# Process the classification results
results_processor.process_classification_results(results, inputs)
12 changes: 6 additions & 6 deletions coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
@@ -143,18 +143,18 @@ def predict(self, inputs=None):
# Predict the detection results
results = self.loaded_model(inputs,
agnostic_nms=True,
conf=self.uncertainty_thresh,
iou=self.iou_thresh,
conf=self.main_window.get_uncertainty_thresh(),
iou=self.main_window.get_iou_thresh(),
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
iou_thresh=self.main_window.get_iou_thresh(),
min_area_thresh=self.main_window.get_area_thresh_min(),
max_area_thresh=self.main_window.get_area_thresh_max())

# Check if SAM model is deployed
if self.use_sam_dropdown.currentText() == "True":
13 changes: 7 additions & 6 deletions coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
@@ -143,18 +143,19 @@ def predict(self, inputs=None):
# Predict the segmentation results
results = self.loaded_model(inputs,
agnostic_nms=True,
conf=self.uncertainty_thresh,
iou=self.iou_thresh,
conf=self.main_window.get_uncertainty_thresh(),
iou=self.main_window.get_iou_thresh(),
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
iou_thresh=self.main_window.get_iou_thresh(),
min_area_thresh=self.main_window.get_area_thresh_min(),
max_area_thresh=self.main_window.get_area_thresh_max())

# Check if SAM model is deployed
if self.use_sam_dropdown.currentText() == "True":
# Apply SAM to the segmentation results
1 change: 0 additions & 1 deletion coralnet_toolbox/QtLabelWindow.py
@@ -309,7 +309,6 @@ def set_label_transparency(self, transparency):
# Update the active label's transparency
self.active_label.update_transparency(transparency)
# Update the transparency of all annotations with the active label
scene = self.annotation_window.scene
for annotation in self.annotation_window.annotations_dict.values():
if annotation.label.id == self.active_label.id:
annotation.update_transparency(transparency)
8 changes: 8 additions & 0 deletions coralnet_toolbox/QtMainWindow.py
@@ -947,6 +947,14 @@ def get_area_thresh(self):
"""Get the current area threshold values"""
return self.area_thresh_min, self.area_thresh_max

def get_area_thresh_min(self):
"""Get the current minimum area threshold value"""
return self.area_thresh_min

def get_area_thresh_max(self):
"""Get the current maximum area threshold value"""
return self.area_thresh_max

def update_area_thresh(self, min_val, max_val):
"""Update the area threshold values"""
if self.area_thresh_min != min_val or self.area_thresh_max != max_val: