Skip to content

Commit

Permalink
Merge pull request #80 from Jordan-Pierce/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
Jordan-Pierce authored Dec 4, 2024
2 parents b7172bd + 33082fa commit 9e8ee3f
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 116 deletions.
19 changes: 3 additions & 16 deletions coralnet_toolbox/AutoDistill/QtDeployModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,18 +382,6 @@ def get_ontology_mapping(self):
ontology_mapping[text_input.text()] = label_dropdown.currentText()
return ontology_mapping

def get_uncertainty_threshold(self):
"""
Get the uncertainty threshold, limiting it to a maximum of 0.10.
Returns:
Adjusted uncertainty threshold value.
"""
if self.main_window.get_uncertainty_thresh() < 0.10:
return self.main_window.get_uncertainty_thresh()
else:
return 0.10 # Arbitrary value to prevent too many detections

def load_new_model(self, model_name, uncertainty_thresh):
"""
Load a new model based on the selected model name.
Expand All @@ -406,8 +394,8 @@ def load_new_model(self, model_name, uncertainty_thresh):
from autodistill_grounding_dino import GroundingDINO
self.model_name = model_name
self.loaded_model = GroundingDINO(ontology=self.ontology,
box_threshold=uncertainty_thresh,
text_threshold=uncertainty_thresh)
box_threshold=0.025,
text_threshold=0.025)

def predict(self, image_paths=None):
"""
Expand All @@ -417,7 +405,6 @@ def predict(self, image_paths=None):
image_paths: List of image paths to process. If None, uses the current image.
"""
if not self.loaded_model:
QMessageBox.critical(self, "Error", "No model loaded")
return

if not image_paths:
Expand All @@ -440,7 +427,7 @@ def predict(self, image_paths=None):
# Create a results processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.get_uncertainty_threshold(),
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
Expand Down
40 changes: 19 additions & 21 deletions coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,21 @@ def preprocess_patch_annotations(self):
progress_bar.show()
progress_bar.start_progress(len(self.image_paths))

def crop(image_path, image_annotations):
# Crop the image based on the annotations
return self.annotation_window.crop_these_image_annotations(image_path, image_annotations)

# Group annotations by image path
groups = groupby(sorted(self.annotations, key=attrgetter('image_path')), key=attrgetter('image_path'))

with ThreadPoolExecutor() as executor:
future_to_image = {}
for path, group in groups:
future = executor.submit(crop, path, list(group))
future_to_image[future] = path

for future in as_completed(future_to_image):
image_path = future_to_image[future]
try:
self.prepared_patches.extend(future.result())
except Exception as exc:
print(f'{image_path} generated an exception: {exc}')
finally:
progress_bar.update_progress()
grouped_annotations = groupby(sorted(self.annotations, key=attrgetter('image_path')),
key=attrgetter('image_path'))

for image_path, group in grouped_annotations:
try:
# Process image annotations
image_annotations = list(group)
image_annotations = self.annotation_window.crop_these_image_annotations(image_path, image_annotations)
self.prepared_patches.extend(image_annotations)

except Exception as exc:
print(f'{image_path} generated an exception: {exc}')
finally:
progress_bar.update_progress()

progress_bar.stop_progress()
progress_bar.close()
Expand All @@ -147,4 +141,8 @@ def batch_inference(self):
progress_bar.update_progress()

progress_bar.stop_progress()
progress_bar.close()
progress_bar.close()

# Clear the list of annotations
self.annotations = []
self.prepared_patches = []
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ def create_generic_labels(self):

def get_uncertainty_threshold(self):
"""
Get the confidence threshold for predictions
Get the confidence threshold for classification predictions
"""
threshold = self.main_window.get_uncertainty_thresh()
return threshold if threshold < 0.10 else 0.10
return threshold if threshold > 0.10 else 0.10

def predict(self, inputs):
"""
Expand Down
6 changes: 5 additions & 1 deletion coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,23 @@ def predict(self, inputs=None):
if not inputs:
# If no annotations are selected, predict all annotations in the image
inputs = self.annotation_window.get_image_review_annotations()
if not inputs:
# If no annotations are available, return
return

images_np = []
for annotation in inputs:
images_np.append(pixmap_to_numpy(annotation.cropped_image))

# Predict the classification results
results = self.loaded_model(images_np,
conf=self.uncertainty_thresh,
device=self.main_window.device,
stream=True)
# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.get_uncertainty_threshold())
uncertainty_thresh=self.uncertainty_thresh)

# Process the classification results
results_processor.process_classification_results(results, inputs)
Expand Down
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,15 @@ def predict(self, inputs=None):
# Predict the detection results
results = self.loaded_model(inputs,
agnostic_nms=True,
conf=self.get_uncertainty_threshold(),
conf=self.uncertainty_thresh,
iou=self.iou_thresh,
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.get_uncertainty_threshold(),
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
Expand Down
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,15 @@ def predict(self, inputs=None):
# Predict the segmentation results
results = self.loaded_model(inputs,
agnostic_nms=True,
conf=self.get_uncertainty_threshold(),
conf=self.uncertainty_thresh,
iou=self.iou_thresh,
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.get_uncertainty_threshold(),
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
Expand Down
25 changes: 6 additions & 19 deletions coralnet_toolbox/QtMainWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,25 +1145,6 @@ def open_segment_deploy_model_dialog(self):
self.segment_deploy_model_dialog.exec_()
except Exception as e:
QMessageBox.critical(self, "Critical Error", f"{e}")

def open_batch_inference_dialog(self):
if not self.image_window.image_paths:
QMessageBox.warning(self,
"Batch Inference",
"No images are present in the project.")
return

if not any(list(self.deploy_model_dialog.loaded_models.values())):
QMessageBox.warning(self,
"Batch Inference",
"Please deploy a model before running batch inference.")
return

try:
self.untoggle_all_tools()
self.batch_inference_dialog.exec_()
except Exception as e:
QMessageBox.critical(self, "Critical Error", f"{e}")

def open_classify_batch_inference_dialog(self):
if not self.image_window.image_paths:
Expand All @@ -1177,6 +1158,12 @@ def open_classify_batch_inference_dialog(self):
"Batch Inference",
"Please deploy a model before running batch inference.")
return

if not self.annotation_window.annotations_dict:
QMessageBox.warning(self,
"Batch Inference",
"No annotations are present in the project.")
return

try:
self.untoggle_all_tools()
Expand Down
56 changes: 23 additions & 33 deletions coralnet_toolbox/ResultsProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from PyQt5.QtCore import QPointF

from torchvision.ops import nms
from ultralytics.engine.results import Results
from ultralytics.models.sam.amg import batched_mask_to_box
from ultralytics.utils import ops
Expand All @@ -23,9 +24,9 @@ def __init__(self,
main_window,
class_mapping,
uncertainty_thresh=0.3,
iou_thresh=0.7,
min_area_thresh=0.01,
max_area_thresh=0.5):
iou_thresh=0.2,
min_area_thresh=0.00,
max_area_thresh=0.40):
self.main_window = main_window
self.label_window = main_window.label_window
self.annotation_window = main_window.annotation_window
Expand All @@ -37,39 +38,30 @@ def __init__(self,
self.min_area_thresh = min_area_thresh
self.max_area_thresh = max_area_thresh

def filter_by_uncertainty(self, result):
def filter_by_uncertainty(self, results):
"""
Filter the results based on the uncertainty threshold.
"""
# Get the confidence score
conf = float(result.boxes.conf.cpu().numpy()[0])
# Check if the confidence is within the threshold
if conf < self.uncertainty_thresh:
return False
return True
return results[results.boxes.conf > self.uncertainty_thresh]

def filter_by_area(self, result):
def filter_by_iou(self, results):
"""Filter the results based on the IoU threshold."""
return results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)]

def filter_by_area(self, results):
"""
Filter the results based on the area threshold.
"""
# Get the normalized bounding box coordinates
x_norm, y_norm, w_norm, h_norm = map(float, result.boxes.xywhn.cpu().numpy()[0])
# Calculate the normalized area
x_norm, y_norm, w_norm, h_norm = results.boxes.xywhn.T
area_norm = w_norm * h_norm
# Check if the area is within the threshold
if area_norm < self.min_area_thresh:
return False
if area_norm > self.max_area_thresh:
return False
return True
return results[(area_norm > self.min_area_thresh) & (area_norm < self.max_area_thresh)]

def passed_filters(self, result):
def apply_filters(self, results):
"""Check if the results passed all filters."""
if not self.filter_by_uncertainty(result):
return False
if not self.filter_by_area(result):
return False
return True
results = self.filter_by_uncertainty(results)
results = self.filter_by_iou(results)
results = self.filter_by_area(results)
return results

def extract_classification_result(self, result):
"""
Expand Down Expand Up @@ -141,10 +133,7 @@ def extract_detection_result(self, result):
def process_single_detection_result(self, result):
"""
Process a single detection result.
"""
if not self.passed_filters(result):
return

"""
# Get image path, class, class name, confidence, and bounding box coordinates
image_path, cls, cls_name, conf, x_min, y_min, x_max, y_max = self.extract_detection_result(result)
# Get the short label given the class name and confidence
Expand All @@ -166,6 +155,8 @@ def process_detection_results(self, results_generator):
progress_bar.show()

for results in results_generator:
# Apply filtering to the results
results = self.apply_filters(results)
for result in results:
if result:
self.process_single_detection_result(result)
Expand Down Expand Up @@ -193,9 +184,6 @@ def process_single_segmentation_result(self, result):
"""
Process a single segmentation result.
"""
if not self.passed_filters(result):
return

# Get image path, class, class name, confidence, and polygon points
image_path, cls, cls_name, conf, points = self.extract_segmentation_result(result)
# Get the short label given the class name and confidence
Expand All @@ -217,6 +205,8 @@ def process_segmentation_results(self, results_generator):
progress_bar.show()

for results in results_generator:
# Apply filtering to the results
results = self.apply_filters(results)
for result in results:
if result:
self.process_single_segmentation_result(result)
Expand Down
23 changes: 6 additions & 17 deletions coralnet_toolbox/SAM/QtDeployGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,12 @@ def load_model(self):
task=self.task,
mode='predict',
save=False,
max_det=1000,
max_det=500,
imgsz=self.get_imgsz(),
conf=self.get_uncertainty_threshold(),
iou=self.iou_thresh,
conf=0.05,
iou=1.0,
device=self.main_window.device)

# Load the model
self.loaded_model = FastSAMPredictor(overrides=overrides)

Expand All @@ -349,18 +350,6 @@ def get_imgsz(self):
self.imgsz = self.imgsz_spinbox.value()
return self.imgsz

def get_uncertainty_threshold(self):
"""
Get the uncertainty threshold, limiting it to a maximum of 0.10.
Returns:
Adjusted uncertainty threshold value.
"""
if self.main_window.get_uncertainty_thresh() < 0.10:
return self.main_window.get_uncertainty_thresh()
else:
return 0.10 # Arbitrary value to prevent too many detections

def predict(self, image_paths=None):
"""
Make predictions on the given image paths using the loaded model.
Expand Down Expand Up @@ -402,15 +391,15 @@ def predict(self, image_paths=None):
# Create a results processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
uncertainty_thresh=self.get_uncertainty_threshold(),
uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)

# Update the progress bar
progress_bar.update_progress()

if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
if self.task.lower() == 'segment' or self.use_sam_dropdown.currentText() == "True":
results_processor.process_segmentation_results(results)
else:
results_processor.process_detection_results(results)
Expand Down
2 changes: 1 addition & 1 deletion coralnet_toolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from coralnet_toolbox.main import run

__version__ = "0.0.3"
__version__ = "0.0.4"
__author__ = "Jordan Pierce"
__email__ = "[email protected]"
__credits__ = "National Center for Coastal and Ocean Sciences (NCCOS)"
Loading

0 comments on commit 9e8ee3f

Please sign in to comment.