
Dev #80

Merged 2 commits on Dec 4, 2024. Changes from all commits are shown below.
19 changes: 3 additions & 16 deletions coralnet_toolbox/AutoDistill/QtDeployModel.py
@@ -382,18 +382,6 @@ def get_ontology_mapping(self):
ontology_mapping[text_input.text()] = label_dropdown.currentText()
return ontology_mapping

- def get_uncertainty_threshold(self):
- """
- Get the uncertainty threshold, limiting it to a maximum of 0.10.
-
- Returns:
- Adjusted uncertainty threshold value.
- """
- if self.main_window.get_uncertainty_thresh() < 0.10:
- return self.main_window.get_uncertainty_thresh()
- else:
- return 0.10 # Arbitrary value to prevent too many detections

def load_new_model(self, model_name, uncertainty_thresh):
"""
Load a new model based on the selected model name.
@@ -406,8 +394,8 @@ def load_new_model(self, model_name, uncertainty_thresh):
from autodistill_grounding_dino import GroundingDINO
self.model_name = model_name
self.loaded_model = GroundingDINO(ontology=self.ontology,
- box_threshold=uncertainty_thresh,
- text_threshold=uncertainty_thresh)
+ box_threshold=0.025,
+ text_threshold=0.025)

def predict(self, image_paths=None):
"""
@@ -417,7 +405,6 @@ def predict(self, image_paths=None):
image_paths: List of image paths to process. If None, uses the current image.
"""
if not self.loaded_model:
- QMessageBox.critical(self, "Error", "No model loaded")
return

if not image_paths:
@@ -440,7 +427,7 @@
# Create a results processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
- uncertainty_thresh=self.get_uncertainty_threshold(),
+ uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
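The effect of this change: GroundingDINO now always runs with deliberately low, fixed box/text thresholds (0.025), and the user-facing uncertainty threshold is applied afterwards by the ResultsProcessor. A minimal sketch of that pattern, assuming the autodistill-grounding-dino package; the ontology and image path here are made up:

```python
# Sketch: run the detector permissively, filter by confidence afterwards.
from autodistill.detection import CaptionOntology
from autodistill_grounding_dino import GroundingDINO

# Hypothetical prompt-to-label ontology.
ontology = CaptionOntology({"coral": "Coral", "sea urchin": "Urchin"})

model = GroundingDINO(ontology=ontology,
                      box_threshold=0.025,   # fixed, intentionally low
                      text_threshold=0.025)

detections = model.predict("reef.jpg")       # supervision.Detections

# The user-chosen threshold is applied downstream, not inside the detector.
uncertainty_thresh = 0.3
detections = detections[detections.confidence > uncertainty_thresh]
```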
40 changes: 19 additions & 21 deletions coralnet_toolbox/MachineLearning/BatchInference/QtClassify.py
@@ -101,27 +101,21 @@ def preprocess_patch_annotations(self):
progress_bar.show()
progress_bar.start_progress(len(self.image_paths))

- def crop(image_path, image_annotations):
- # Crop the image based on the annotations
- return self.annotation_window.crop_these_image_annotations(image_path, image_annotations)
-
- # Group annotations by image path
- groups = groupby(sorted(self.annotations, key=attrgetter('image_path')), key=attrgetter('image_path'))
-
- with ThreadPoolExecutor() as executor:
- future_to_image = {}
- for path, group in groups:
- future = executor.submit(crop, path, list(group))
- future_to_image[future] = path
-
- for future in as_completed(future_to_image):
- image_path = future_to_image[future]
- try:
- self.prepared_patches.extend(future.result())
- except Exception as exc:
- print(f'{image_path} generated an exception: {exc}')
- finally:
- progress_bar.update_progress()
+ grouped_annotations = groupby(sorted(self.annotations, key=attrgetter('image_path')),
+ key=attrgetter('image_path'))
+
+ for image_path, group in grouped_annotations:
+ try:
+ # Process image annotations
+ image_annotations = list(group)
+ image_annotations = self.annotation_window.crop_these_image_annotations(image_path, image_annotations)
+ self.prepared_patches.extend(image_annotations)
+
+ except Exception as exc:
+ print(f'{image_path} generated an exception: {exc}')
+ finally:
+ progress_bar.update_progress()

progress_bar.stop_progress()
progress_bar.close()
@@ -147,4 +141,8 @@ def batch_inference(self):
progress_bar.update_progress()

progress_bar.stop_progress()
- progress_bar.close()
+ progress_bar.close()
+
+ # Clear the list of annotations
+ self.annotations = []
+ self.prepared_patches = []
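The thread pool is replaced by a simple sequential loop over annotations grouped by image. One detail worth keeping in mind: itertools.groupby only merges consecutive items, so sorting by the same key first, as the new code does, is required for correct grouping. A self-contained sketch of the pattern:

```python
# Sketch of the grouping pattern above: groupby only groups adjacent
# items, so the input must be sorted by the same key first.
from dataclasses import dataclass
from itertools import groupby
from operator import attrgetter

@dataclass
class Annotation:                 # stand-in for the toolbox's annotations
    image_path: str
    label: str

annotations = [Annotation("b.jpg", "coral"),
               Annotation("a.jpg", "algae"),
               Annotation("b.jpg", "urchin")]

by_path = attrgetter("image_path")
for image_path, group in groupby(sorted(annotations, key=by_path), key=by_path):
    print(image_path, [a.label for a in group])
# a.jpg ['algae']
# b.jpg ['coral', 'urchin']
```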
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtBase.py
@@ -339,10 +339,10 @@ def create_generic_labels(self):

def get_uncertainty_threshold(self):
"""
- Get the confidence threshold for predictions
+ Get the confidence threshold for classification predictions
"""
threshold = self.main_window.get_uncertainty_thresh()
- return threshold if threshold < 0.10 else 0.10
+ return threshold if threshold > 0.10 else 0.10

def predict(self, inputs):
"""
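Note the flipped comparison changes the semantics: the old version capped the threshold at 0.10, i.e. min(threshold, 0.10), while the new version enforces a floor of 0.10, i.e. max(threshold, 0.10). A small sketch of the new behavior:

```python
def get_uncertainty_threshold(threshold: float) -> float:
    """New behavior: clamp the threshold to a floor of 0.10."""
    return threshold if threshold > 0.10 else 0.10  # i.e. max(threshold, 0.10)

assert get_uncertainty_threshold(0.05) == 0.10  # raised to the floor
assert get_uncertainty_threshold(0.30) == 0.30  # left unchanged
```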
6 changes: 5 additions & 1 deletion coralnet_toolbox/MachineLearning/DeployModel/QtClassify.py
@@ -110,19 +110,23 @@ def predict(self, inputs=None):
if not inputs:
# If no annotations are selected, predict all annotations in the image
inputs = self.annotation_window.get_image_review_annotations()
+ if not inputs:
+ # If no annotations are available, return
+ return

images_np = []
for annotation in inputs:
images_np.append(pixmap_to_numpy(annotation.cropped_image))

# Predict the classification results
results = self.loaded_model(images_np,
+ conf=self.uncertainty_thresh,
device=self.main_window.device,
stream=True)
# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
- uncertainty_thresh=self.get_uncertainty_threshold())
+ uncertainty_thresh=self.uncertainty_thresh)

# Process the classification results
results_processor.process_classification_results(results, inputs)
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py
@@ -143,15 +143,15 @@ def predict(self, inputs=None):
# Predict the detection results
results = self.loaded_model(inputs,
agnostic_nms=True,
- conf=self.get_uncertainty_threshold(),
+ conf=self.uncertainty_thresh,
iou=self.iou_thresh,
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
- uncertainty_thresh=self.get_uncertainty_threshold(),
+ uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
4 changes: 2 additions & 2 deletions coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py
@@ -143,15 +143,15 @@ def predict(self, inputs=None):
# Predict the segmentation results
results = self.loaded_model(inputs,
agnostic_nms=True,
- conf=self.get_uncertainty_threshold(),
+ conf=self.uncertainty_thresh,
iou=self.iou_thresh,
device=self.main_window.device,
stream=True)

# Create a result processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
- uncertainty_thresh=self.get_uncertainty_threshold(),
+ uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)
25 changes: 6 additions & 19 deletions coralnet_toolbox/QtMainWindow.py
@@ -1145,25 +1145,6 @@ def open_segment_deploy_model_dialog(self):
self.segment_deploy_model_dialog.exec_()
except Exception as e:
QMessageBox.critical(self, "Critical Error", f"{e}")

- def open_batch_inference_dialog(self):
- if not self.image_window.image_paths:
- QMessageBox.warning(self,
- "Batch Inference",
- "No images are present in the project.")
- return
-
- if not any(list(self.deploy_model_dialog.loaded_models.values())):
- QMessageBox.warning(self,
- "Batch Inference",
- "Please deploy a model before running batch inference.")
- return
-
- try:
- self.untoggle_all_tools()
- self.batch_inference_dialog.exec_()
- except Exception as e:
- QMessageBox.critical(self, "Critical Error", f"{e}")

def open_classify_batch_inference_dialog(self):
if not self.image_window.image_paths:
@@ -1177,6 +1158,12 @@ def open_classify_batch_inference_dialog(self):
"Batch Inference",
"Please deploy a model before running batch inference.")
return

+ if not self.annotation_window.annotations_dict:
+ QMessageBox.warning(self,
+ "Batch Inference",
+ "No annotations are present in the project.")
+ return

try:
self.untoggle_all_tools()
56 changes: 23 additions & 33 deletions coralnet_toolbox/ResultsProcessor.py
@@ -3,6 +3,7 @@

from PyQt5.QtCore import QPointF

+ from torchvision.ops import nms
from ultralytics.engine.results import Results
from ultralytics.models.sam.amg import batched_mask_to_box
from ultralytics.utils import ops
@@ -23,9 +24,9 @@ def __init__(self,
main_window,
class_mapping,
uncertainty_thresh=0.3,
- iou_thresh=0.7,
- min_area_thresh=0.01,
- max_area_thresh=0.5):
+ iou_thresh=0.2,
+ min_area_thresh=0.00,
+ max_area_thresh=0.40):
self.main_window = main_window
self.label_window = main_window.label_window
self.annotation_window = main_window.annotation_window
@@ -37,39 +38,30 @@
self.min_area_thresh = min_area_thresh
self.max_area_thresh = max_area_thresh

- def filter_by_uncertainty(self, result):
+ def filter_by_uncertainty(self, results):
"""
Filter the results based on the uncertainty threshold.
"""
- # Get the confidence score
- conf = float(result.boxes.conf.cpu().numpy()[0])
- # Check if the confidence is within the threshold
- if conf < self.uncertainty_thresh:
- return False
- return True
+ return results[results.boxes.conf > self.uncertainty_thresh]

- def filter_by_area(self, result):
+ def filter_by_iou(self, results):
+ """Filter the results based on the IoU threshold."""
+ return results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)]
+
+ def filter_by_area(self, results):
"""
Filter the results based on the area threshold.
"""
- # Get the normalized bounding box coordinates
- x_norm, y_norm, w_norm, h_norm = map(float, result.boxes.xywhn.cpu().numpy()[0])
- # Calculate the normalized area
+ x_norm, y_norm, w_norm, h_norm = results.boxes.xywhn.T
area_norm = w_norm * h_norm
- # Check if the area is within the threshold
- if area_norm < self.min_area_thresh:
- return False
- if area_norm > self.max_area_thresh:
- return False
- return True
+ return results[(area_norm > self.min_area_thresh) & (area_norm < self.max_area_thresh)]

- def passed_filters(self, result):
+ def apply_filters(self, results):
"""Check if the results passed all filters."""
- if not self.filter_by_uncertainty(result):
- return False
- if not self.filter_by_area(result):
- return False
- return True
+ results = self.filter_by_uncertainty(results)
+ results = self.filter_by_iou(results)
+ results = self.filter_by_area(results)
+ return results

def extract_classification_result(self, result):
"""
@@ -141,10 +133,7 @@ def extract_detection_result(self, result):
def process_single_detection_result(self, result):
"""
Process a single detection result.
"""
if not self.passed_filters(result):
return

"""
# Get image path, class, class name, confidence, and bounding box coordinates
image_path, cls, cls_name, conf, x_min, y_min, x_max, y_max = self.extract_detection_result(result)
# Get the short label given the class name and confidence
@@ -166,6 +155,8 @@ def process_detection_results(self, results_generator):
progress_bar.show()

for results in results_generator:
+ # Apply filtering to the results
+ results = self.apply_filters(results)
for result in results:
if result:
self.process_single_detection_result(result)
@@ -193,9 +184,6 @@ def process_single_segmentation_result(self, result):
"""
Process a single segmentation result.
"""
- if not self.passed_filters(result):
- return
-
# Get image path, class, class name, confidence, and polygon points
image_path, cls, cls_name, conf, points = self.extract_segmentation_result(result)
# Get the short label given the class name and confidence
@@ -217,6 +205,8 @@ def process_segmentation_results(self, results_generator):
progress_bar.show()

for results in results_generator:
+ # Apply filtering to the results
+ results = self.apply_filters(results)
for result in results:
if result:
self.process_single_segmentation_result(result)
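The per-result checks are replaced by vectorized filtering over a whole Results batch: a confidence mask, torchvision's nms for overlap, and a normalized-area mask, applied in that order. A self-contained sketch of the same idea on raw tensors (the toolbox applies it through Ultralytics' indexable Results object); the boxes and image size below are made up:

```python
# Sketch of the vectorized filter chain on raw tensors.
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 50.0, 50.0],
                      [2.0, 2.0, 52.0, 52.0],        # heavy overlap with box 0
                      [300.0, 300.0, 310.0, 310.0]])
scores = torch.tensor([0.9, 0.6, 0.2])
img_w = img_h = 640.0

uncertainty_thresh, iou_thresh = 0.3, 0.2
min_area_thresh, max_area_thresh = 0.00, 0.40        # normalized area bounds

# 1) uncertainty: keep confident boxes
keep = scores > uncertainty_thresh
boxes, scores = boxes[keep], scores[keep]

# 2) IoU: nms() returns the indices of the boxes to keep
keep = nms(boxes, scores, iou_thresh)
boxes, scores = boxes[keep], scores[keep]

# 3) area: keep boxes within the normalized-area bounds
w = (boxes[:, 2] - boxes[:, 0]) / img_w
h = (boxes[:, 3] - boxes[:, 1]) / img_h
area_norm = w * h
keep = (area_norm > min_area_thresh) & (area_norm < max_area_thresh)
print(boxes[keep], scores[keep])                     # only box 0 survives
```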
23 changes: 6 additions & 17 deletions coralnet_toolbox/SAM/QtDeployGenerator.py
@@ -318,11 +318,12 @@ def load_model(self):
task=self.task,
mode='predict',
save=False,
- max_det=1000,
+ max_det=500,
imgsz=self.get_imgsz(),
- conf=self.get_uncertainty_threshold(),
- iou=self.iou_thresh,
+ conf=0.05,
+ iou=1.0,
device=self.main_window.device)

+ # Load the model
self.loaded_model = FastSAMPredictor(overrides=overrides)

@@ -349,18 +350,6 @@ def get_imgsz(self):
self.imgsz = self.imgsz_spinbox.value()
return self.imgsz

- def get_uncertainty_threshold(self):
- """
- Get the uncertainty threshold, limiting it to a maximum of 0.10.
-
- Returns:
- Adjusted uncertainty threshold value.
- """
- if self.main_window.get_uncertainty_thresh() < 0.10:
- return self.main_window.get_uncertainty_thresh()
- else:
- return 0.10 # Arbitrary value to prevent too many detections

def predict(self, image_paths=None):
"""
Make predictions on the given image paths using the loaded model.
@@ -402,15 +391,15 @@ def predict(self, image_paths=None):
# Create a results processor
results_processor = ResultsProcessor(self.main_window,
self.class_mapping,
- uncertainty_thresh=self.get_uncertainty_threshold(),
+ uncertainty_thresh=self.uncertainty_thresh,
iou_thresh=self.iou_thresh,
min_area_thresh=self.area_thresh_min,
max_area_thresh=self.area_thresh_max)

# Update the progress bar
progress_bar.update_progress()

- if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
+ if self.task.lower() == 'segment' or self.use_sam_dropdown.currentText() == "True":
results_processor.process_segmentation_results(results)
else:
results_processor.process_detection_results(results)
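Here the predictor itself is made deliberately permissive: conf=0.05 keeps nearly every candidate and iou=1.0 effectively disables NMS inside the predictor, deferring all real filtering to the ResultsProcessor. A hedged sketch of constructing a FastSAM generator this way, assuming an Ultralytics build that ships FastSAMPredictor; the weights file, image size, and image path are illustrative:

```python
# Sketch (names marked illustrative are assumptions, not the toolbox's code).
from ultralytics.models.fastsam import FastSAMPredictor

overrides = dict(task="segment",
                 mode="predict",
                 model="FastSAM-s.pt",   # illustrative weights file
                 save=False,
                 max_det=500,
                 imgsz=1024,
                 conf=0.05,              # permissive: keep nearly everything...
                 iou=1.0,                # ...and let downstream filters decide
                 device="cpu")

predictor = FastSAMPredictor(overrides=overrides)
results = predictor("reef.jpg")          # later filtered by ResultsProcessor
```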
2 changes: 1 addition & 1 deletion coralnet_toolbox/__init__.py
@@ -2,7 +2,7 @@

from coralnet_toolbox.main import run

- __version__ = "0.0.3"
+ __version__ = "0.0.4"
__author__ = "Jordan Pierce"
__email__ = "[email protected]"
__credits__ = "National Center for Coastal and Ocean Sciences (NCCOS)"