diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py
index 6cada3dc..80b043c9 100644
--- a/luxonis_train/__main__.py
+++ b/luxonis_train/__main__.py
@@ -2,7 +2,7 @@
 from enum import Enum
 from importlib.metadata import version
 from pathlib import Path
-from typing import Annotated, Optional
+from typing import Annotated
 
 import typer
 import yaml
@@ -25,7 +25,7 @@ class _ViewType(str, Enum):
 
 
 ConfigType = Annotated[
-    Optional[str],
+    str | None,
    typer.Option(
         help="Path to the configuration file.",
         show_default=False,
@@ -34,7 +34,7 @@ class _ViewType(str, Enum):
 ]
 
 OptsType = Annotated[
-    Optional[list[str]],
+    list[str] | None,
     typer.Argument(
         help="A list of optional CLI overrides of the config file.",
         show_default=False,
@@ -46,16 +46,23 @@ class _ViewType(str, Enum):
 ]
 
 SaveDirType = Annotated[
-    Optional[Path],
+    Path | None,
     typer.Option(help="Where to save the inference results."),
 ]
 
+ImgPathType = Annotated[
+    str | None,
+    typer.Option(
+        help="Path to an image file or a directory containing images for inference."
+    ),
+]
+
 
 @app.command()
 def train(
     config: ConfigType = None,
     resume: Annotated[
-        Optional[str],
+        str | None,
         typer.Option(help="Resume training from this checkpoint."),
     ] = None,
     opts: OptsType = None,
@@ -99,12 +106,15 @@ def infer(
     config: ConfigType = None,
     view: ViewType = _ViewType.VAL,
     save_dir: SaveDirType = None,
+    img_path: ImgPathType = None,
     opts: OptsType = None,
 ):
     """Run inference."""
     from luxonis_train.core import LuxonisModel
 
-    LuxonisModel(config, opts).infer(view=view.value, save_dir=save_dir)
+    LuxonisModel(config, opts).infer(
+        view=view.value, save_dir=save_dir, img_path=img_path
+    )
 
 
 @app.command()
@@ -200,7 +210,7 @@ def common(
         ),
     ] = False,
     source: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             help="Path to a python file with custom components. "
             "Will be sourced before running the command.",
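
Note: typer renders the new img_path option above as an --img-path flag. A
minimal usage sketch, assuming the package is executed as a module and with
placeholder config/image paths:

    # CLI:
    #   python -m luxonis_train infer --config configs/detection.yaml --img-path ./samples
    # Programmatic equivalent, mirroring the body of the `infer` command:
    from luxonis_train.core import LuxonisModel

    model = LuxonisModel("configs/detection.yaml", None)  # opts=None, as in the CLI default
    model.infer(view="val", img_path="./samples", save_dir="inference_out")
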
""" self.lightning_module.eval() - for inputs, labels in self.pytorch_loaders[view]: - images = get_unnormalized_images(self.cfg, inputs) - outputs = self.lightning_module.forward( - inputs, labels, images=images, compute_visualizations=True - ) - render_visualizations(outputs.visualizations, save_dir) + if img_path: + img_path_obj = Path(img_path) + if img_path_obj.is_file(): + process_images(self, [img_path_obj], view, save_dir) + elif img_path_obj.is_dir(): + image_files = [ + f + for f in img_path_obj.iterdir() + if f.suffix.lower() in {".png", ".jpg", ".jpeg"} + ] + process_images(self, image_files, view, save_dir) + else: + process_dataset_images(self, view, save_dir) def tune(self) -> None: """Runs Optuna tunning of hyperparameters.""" diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 17696705..aa78743e 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -1,12 +1,17 @@ from pathlib import Path import cv2 +import torch from torch import Tensor +from luxonis_train.attached_modules.visualizers import get_unnormalized_images +from luxonis_train.enums import TaskType + def render_visualizations( visualizations: dict[str, dict[str, Tensor]], save_dir: str | Path | None ) -> None: + """Render or save visualizations.""" save_dir = Path(save_dir) if save_dir is not None else None if save_dir is not None: save_dir.mkdir(exist_ok=True, parents=True) @@ -28,3 +33,66 @@ def render_visualizations( if save_dir is None: if cv2.waitKey(0) == ord("q"): exit() + + +def process_images( + model, img_paths: list[Path], view: str, save_dir: str | Path | None +) -> None: + """Handles inference on one or more images.""" + first_image = cv2.cvtColor( + cv2.imread(str(img_paths[0])), cv2.COLOR_BGR2RGB + ) + labels = create_dummy_labels(model, view, first_image.shape) + for img_path in img_paths: + img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB) + img, _ = ( + model.train_augmentations([(img, {})]) + if view == "train" + else model.val_augmentations([(img, {})]) + ) + + inputs = { + "image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float() + } + images = get_unnormalized_images(model.cfg, inputs) + + outputs = model.lightning_module.forward( + inputs, labels, images=images, compute_visualizations=True + ) + render_visualizations(outputs.visualizations, save_dir) + + +def process_dataset_images( + model, view: str, save_dir: str | Path | None +) -> None: + """Handles the inference on dataset images.""" + for inputs, labels in model.pytorch_loaders[view]: + images = get_unnormalized_images(model.cfg, inputs) + outputs = model.lightning_module.forward( + inputs, labels, images=images, compute_visualizations=True + ) + render_visualizations(outputs.visualizations, save_dir) + + +def create_dummy_labels(model, view: str, img_shape: tuple) -> dict: + """Prepares the labels for different tasks (classification, + keypoints, etc.).""" + tasks = list(model.loaders["train"].get_classes().keys()) + h, w, _ = img_shape + labels = {} + nk = model.loaders[view].get_n_keypoints()["keypoints"] + + for task in tasks: + if task == "classification": + labels[task] = [-1, TaskType.CLASSIFICATION] + elif task == "keypoints": + labels[task] = [torch.zeros((1, nk * 3 + 2)), TaskType.KEYPOINTS] + elif task == "segmentation": + labels[task] = [torch.zeros((1, h, w)), TaskType.SEGMENTATION] + elif task == "boundingbox": + labels[task] = [ + torch.tensor([[-1, 0, 0, 0, 0, 0]]), + TaskType.BOUNDINGBOX, + ] + + 
diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py
index 17696705..aa78743e 100644
--- a/luxonis_train/core/utils/infer_utils.py
+++ b/luxonis_train/core/utils/infer_utils.py
@@ -1,12 +1,17 @@
 from pathlib import Path
 
 import cv2
+import torch
 from torch import Tensor
 
+from luxonis_train.attached_modules.visualizers import get_unnormalized_images
+from luxonis_train.enums import TaskType
+
 
 def render_visualizations(
     visualizations: dict[str, dict[str, Tensor]], save_dir: str | Path | None
 ) -> None:
+    """Renders visualizations on screen or saves them to disk."""
     save_dir = Path(save_dir) if save_dir is not None else None
     if save_dir is not None:
         save_dir.mkdir(exist_ok=True, parents=True)
@@ -28,3 +33,66 @@ def render_visualizations(
     if save_dir is None:
         if cv2.waitKey(0) == ord("q"):
             exit()
+
+
+def process_images(
+    model, img_paths: list[Path], view: str, save_dir: str | Path | None
+) -> None:
+    """Handles inference on one or more images."""
+    first_image = cv2.cvtColor(
+        cv2.imread(str(img_paths[0])), cv2.COLOR_BGR2RGB
+    )
+    labels = create_dummy_labels(model, view, first_image.shape)
+    for img_path in img_paths:
+        img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
+        img, _ = (
+            model.train_augmentations([(img, {})])
+            if view == "train"
+            else model.val_augmentations([(img, {})])
+        )
+
+        inputs = {
+            "image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
+        }
+        images = get_unnormalized_images(model.cfg, inputs)
+
+        outputs = model.lightning_module.forward(
+            inputs, labels, images=images, compute_visualizations=True
+        )
+        render_visualizations(outputs.visualizations, save_dir)
+
+
+def process_dataset_images(
+    model, view: str, save_dir: str | Path | None
+) -> None:
+    """Handles inference on dataset images."""
+    for inputs, labels in model.pytorch_loaders[view]:
+        images = get_unnormalized_images(model.cfg, inputs)
+        outputs = model.lightning_module.forward(
+            inputs, labels, images=images, compute_visualizations=True
+        )
+        render_visualizations(outputs.visualizations, save_dir)
+
+
+def create_dummy_labels(model, view: str, img_shape: tuple) -> dict:
+    """Prepares placeholder labels for the different tasks
+    (classification, keypoints, etc.)."""
+    tasks = list(model.loaders["train"].get_classes().keys())
+    h, w, _ = img_shape
+    labels = {}
+
+    for task in tasks:
+        if task == "classification":
+            labels[task] = [-1, TaskType.CLASSIFICATION]
+        elif task == "keypoints":
+            nk = model.loaders[view].get_n_keypoints()["keypoints"]
+            labels[task] = [torch.zeros((1, nk * 3 + 2)), TaskType.KEYPOINTS]
+        elif task == "segmentation":
+            labels[task] = [torch.zeros((1, h, w)), TaskType.SEGMENTATION]
+        elif task == "boundingbox":
+            labels[task] = [
+                torch.tensor([[-1, 0, 0, 0, 0, 0]]),
+                TaskType.BOUNDINGBOX,
+            ]
+
+    return labels
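
process_images above feeds the network a 1xCxHxW float batch built from an
HWC image; any normalization is assumed to happen inside the train/val
augmentations applied just before. A minimal, runnable sketch of that
tensor preparation, with placeholder dimensions:

    import numpy as np
    import torch

    img = np.zeros((256, 320, 3), dtype=np.uint8)  # stand-in for a decoded RGB image
    batch = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
    assert batch.shape == (1, 3, 256, 320)  # NCHW, ready for model input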