Feat/inference #82
@@ -3,8 +3,9 @@
 import threading
 from logging import getLogger
 from pathlib import Path
-from typing import Any, Literal, Mapping, overload
+from typing import Any, Literal, Mapping, overload, Optional
+import os
+import cv2
 import lightning.pytorch as pl
 import lightning_utilities.core.rank_zero as rank_zero_module
 import rich.traceback
@@ -17,7 +18,7 @@
 from luxonis_ml.nn_archive.config import CONFIG_VERSION
 from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
 from typeguard import typechecked
+from luxonis_ml.data import LabelType
 from luxonis_train.attached_modules.visualizers import get_unnormalized_images
 from luxonis_train.callbacks import (
     LuxonisRichProgressBar,
@@ -419,6 +420,7 @@ def infer(
         self,
         view: Literal["train", "val", "test"] = "val",
         save_dir: str | Path | None = None,
+        img_path: Optional[str] = None
Review comment: Instead of `Optional[str]`, use `str | None`.
     ) -> None:
         """Runs inference.
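Regarding the review comment above, here is a minimal sketch of the annotation style the reviewer appears to suggest, matching the existing `save_dir` parameter. The class name is a made-up stand-in, since the class this diff modifies is not shown here:

```python
from pathlib import Path
from typing import Literal

class LuxonisModelSketch:  # hypothetical stand-in for the class this diff modifies
    def infer(
        self,
        view: Literal["train", "val", "test"] = "val",
        save_dir: str | Path | None = None,
        img_path: str | None = None,  # PEP 604 union instead of Optional[str]
    ) -> None: ...
```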
@@ -429,16 +431,68 @@ def infer(
         @param save_dir: Directory where to save the visualizations. If
             not specified, visualizations will be rendered on the
             screen.
+        @type img_path: Optional[str]
+        @param img_path: Path to the image file or directory for inference.
+            If None, defaults to using dataset images.
         """
         self.lightning_module.eval()

+        if img_path:
+            img_path_obj = Path(img_path)
+            if img_path_obj.is_file():
+                self._process_single_image(img_path_obj, view, save_dir)
+            elif img_path_obj.is_dir():
+                self._process_directory_images(img_path_obj, view, save_dir)
+        else:
+            self._process_dataset_images(view, save_dir)
+
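A hedged usage sketch of the dispatch above; `model` stands in for an instance of the class this diff modifies (its name is not shown in this diff), and the paths are made up:

```python
# Single image file: goes through _process_single_image()
model.infer(view="val", img_path="data/sample.jpg", save_dir="inference_out")

# Directory of .png/.jpg/.jpeg files: goes through _process_directory_images()
model.infer(view="val", img_path="data/images/", save_dir="inference_out")

# No img_path: falls back to _process_dataset_images(), i.e. the dataset loaders as before
model.infer(view="val", save_dir="inference_out")
```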
+    def _process_single_image(self, img_path: Path, view: str, save_dir: Optional[str | Path]) -> None:
Review comment: It might be better to move these methods to …
"""Handles the inference on a single image.""" | ||
img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB) | ||
img, _ = self.val_augmentations([(img, {})]) | ||
labels = self._prepare_labels(view, img.shape) | ||
inputs = {'image': torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()} | ||
images = get_unnormalized_images(self.cfg, inputs) | ||
|
||
outputs = self.lightning_module.forward( | ||
inputs, labels, images=images, compute_visualizations=True | ||
) | ||
render_visualizations(outputs.visualizations, save_dir) | ||
|
||
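For readers unfamiliar with the OpenCV-to-PyTorch handoff above, a self-contained sketch of the same preprocessing, leaving out the framework-specific `val_augmentations` step (the function name below is made up for illustration):

```python
import cv2
import torch

def load_image_as_batch(path: str) -> torch.Tensor:
    """Read an image and return a 1 x C x H x W float tensor, as in _process_single_image."""
    bgr = cv2.imread(path)                       # OpenCV reads images as BGR, H x W x C
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)   # convert to the RGB order the model expects
    batch = torch.from_numpy(rgb).unsqueeze(0)   # add a batch dimension -> 1 x H x W x C
    return batch.permute(0, 3, 1, 2).float()     # channels-first layout -> 1 x C x H x W
```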
+    def _process_directory_images(self, dir_path: Path, view: str, save_dir: Optional[str | Path]) -> None:
+        """Handles inference for multiple images in a directory."""
+        image_files = [f for f in dir_path.iterdir() if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
+        for image_file in image_files:
+            self._process_single_image(image_file, view, save_dir)
+
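Note that `Path.iterdir()` is non-recursive, so only images directly inside the directory are picked up. A hedged sketch of a recursive variant, should nested folders need to be supported (this is a suggestion, not part of the PR):

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

def find_images(dir_path: Path, recursive: bool = False) -> list[Path]:
    """Collect image files by suffix; rglob walks subdirectories, iterdir does not."""
    candidates = dir_path.rglob("*") if recursive else dir_path.iterdir()
    return sorted(f for f in candidates if f.suffix.lower() in IMAGE_SUFFIXES)
```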
+    def _process_dataset_images(self, view: str, save_dir: Optional[str | Path]) -> None:
+        """Handles the inference on dataset images."""
+        for inputs, labels in self.pytorch_loaders[view]:
+            images = get_unnormalized_images(self.cfg, inputs)
+            outputs = self.lightning_module.forward(
+                inputs, labels, images=images, compute_visualizations=True
+            )
+            render_visualizations(outputs.visualizations, save_dir)
+
+    def _prepare_labels(self, view: str, img_shape: tuple) -> dict:
"""Prepares the labels for different tasks (classification, keypoints, etc.).""" | ||
tasks = list(self.loaders["train"].get_classes().keys()) | ||
h, w, _ = img_shape | ||
labels = {} | ||
nk = self.loaders[view].get_n_keypoints()['keypoints'] | ||
|
||
for task in tasks: | ||
if task == "classification": | ||
labels[task] = [-1, LabelType.CLASSIFICATION] | ||
elif task == "keypoints": | ||
labels[task] = [torch.zeros((1, nk * 3 + 2)), LabelType.KEYPOINTS] | ||
elif task == "segmentation": | ||
labels[task] = [torch.zeros((1, h, w)), LabelType.SEGMENTATION] | ||
elif task == "boundingbox": | ||
labels[task] = [torch.tensor([[-1, 0, 0, 0, 0, 0]]), LabelType.BOUNDINGBOX] | ||
|
||
return labels | ||
|
||
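To make the placeholder structure above concrete, here is a standalone sketch of the dict `_prepare_labels` returns, with the `LabelType` enum members replaced by plain strings so the snippet runs without luxonis_ml (an assumption of this sketch):

```python
import torch

def make_placeholder_labels(img_h: int, img_w: int, n_keypoints: int) -> dict:
    """Dummy labels keyed by task, mirroring _prepare_labels above."""
    return {
        "classification": [-1, "classification"],                            # no ground-truth class
        "keypoints": [torch.zeros((1, n_keypoints * 3 + 2)), "keypoints"],   # zeroed keypoint row
        "segmentation": [torch.zeros((1, img_h, img_w)), "segmentation"],    # empty mask
        "boundingbox": [torch.tensor([[-1, 0, 0, 0, 0, 0]]), "boundingbox"], # sentinel box row
    }

# e.g. make_placeholder_labels(256, 256, 17)["keypoints"][0].shape == torch.Size([1, 53])
```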
     def tune(self) -> None:
         """Runs Optuna tunning of hyperparameters."""
         import optuna