Feat/inference (#82)
JSabadin authored Sep 26, 2024
1 parent e16cde6 commit 706d3d0
Showing 3 changed files with 106 additions and 15 deletions.
24 changes: 17 additions & 7 deletions luxonis_train/__main__.py
@@ -2,7 +2,7 @@
from enum import Enum
from importlib.metadata import version
from pathlib import Path
from typing import Annotated, Optional
from typing import Annotated

import typer
import yaml
@@ -25,7 +25,7 @@ class _ViewType(str, Enum):


ConfigType = Annotated[
Optional[str],
str | None,
typer.Option(
help="Path to the configuration file.",
show_default=False,
@@ -34,7 +34,7 @@ class _ViewType(str, Enum):
]

OptsType = Annotated[
Optional[list[str]],
list[str] | None,
typer.Argument(
help="A list of optional CLI overrides of the config file.",
show_default=False,
@@ -46,16 +46,23 @@ class _ViewType(str, Enum):
]

SaveDirType = Annotated[
Optional[Path],
Path | None,
typer.Option(help="Where to save the inference results."),
]

ImgPathType = Annotated[
str | None,
typer.Option(
help="Path to an image file or a directory containing images for inference."
),
]


@app.command()
def train(
config: ConfigType = None,
resume: Annotated[
Optional[str],
str | None,
typer.Option(help="Resume training from this checkpoint."),
] = None,
opts: OptsType = None,
@@ -99,12 +106,15 @@ def infer(
config: ConfigType = None,
view: ViewType = _ViewType.VAL,
save_dir: SaveDirType = None,
img_path: ImgPathType = None,
opts: OptsType = None,
):
"""Run inference."""
from luxonis_train.core import LuxonisModel

LuxonisModel(config, opts).infer(view=view.value, save_dir=save_dir)
LuxonisModel(config, opts).infer(
view=view.value, save_dir=save_dir, img_path=img_path
)


@app.command()
@@ -200,7 +210,7 @@ def common(
),
] = False,
source: Annotated[
Optional[Path],
Path | None,
typer.Option(
help="Path to a python file with custom components. "
"Will be sourced before running the command.",
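The new img_path option wired into the infer command above is forwarded straight to LuxonisModel.infer. A minimal programmatic sketch of the same call path (the config file and image directory below are illustrative placeholders, not part of this commit):

from luxonis_train.core import LuxonisModel

# Hypothetical config file; any valid luxonis_train config works here.
model = LuxonisModel("configs/example.yaml", opts=None)

# Roughly what the CLI call "infer --view val --save-dir results --img-path samples/" does:
# visualizations are written to results/ instead of being shown on screen.
model.infer(view="val", save_dir="results", img_path="samples/")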
29 changes: 21 additions & 8 deletions luxonis_train/core/core.py
@@ -18,7 +18,6 @@
from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
from typeguard import typechecked

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.callbacks import (
LuxonisRichProgressBar,
LuxonisTQDMProgressBar,
@@ -35,7 +34,10 @@
replace_weights,
try_onnx_simplify,
)
from .utils.infer_utils import render_visualizations
from .utils.infer_utils import (
process_dataset_images,
process_images,
)
from .utils.train_utils import create_trainer

logger = getLogger(__name__)
@@ -419,6 +421,7 @@ def infer(
self,
view: Literal["train", "val", "test"] = "val",
save_dir: str | Path | None = None,
img_path: str | None = None,
) -> None:
"""Runs inference.
@@ -429,15 +432,25 @@
@param save_dir: Directory where to save the visualizations. If
not specified, visualizations will be rendered on the
screen.
@type img_path: str | None
@param img_path: Path to the image file or directory for inference.
If None, defaults to using dataset images.
"""
self.lightning_module.eval()

for inputs, labels in self.pytorch_loaders[view]:
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)
if img_path:
img_path_obj = Path(img_path)
if img_path_obj.is_file():
process_images(self, [img_path_obj], view, save_dir)
elif img_path_obj.is_dir():
image_files = [
f
for f in img_path_obj.iterdir()
if f.suffix.lower() in {".png", ".jpg", ".jpeg"}
]
process_images(self, image_files, view, save_dir)
else:
process_dataset_images(self, view, save_dir)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
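The updated infer now dispatches on img_path: a file path runs inference on that single image, a directory runs it on every .png/.jpg/.jpeg file inside, and no path keeps the previous behavior of iterating the dataset loader for the requested view. In sketch form (model is a constructed LuxonisModel; the paths are placeholders):

# Single image file.
model.infer(view="test", img_path="images/frame_0001.jpg", save_dir="out")

# Directory: every .png/.jpg/.jpeg file inside is processed in turn.
model.infer(view="test", img_path="images/", save_dir="out")

# No img_path: falls back to the dataset loader, exactly as before this change.
model.infer(view="test", save_dir="out")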
68 changes: 68 additions & 0 deletions luxonis_train/core/utils/infer_utils.py
@@ -1,12 +1,17 @@
from pathlib import Path

import cv2
import torch
from torch import Tensor

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.enums import TaskType


def render_visualizations(
visualizations: dict[str, dict[str, Tensor]], save_dir: str | Path | None
) -> None:
"""Render or save visualizations."""
save_dir = Path(save_dir) if save_dir is not None else None
if save_dir is not None:
save_dir.mkdir(exist_ok=True, parents=True)
@@ -28,3 +33,66 @@ def render_visualizations(
if save_dir is None:
if cv2.waitKey(0) == ord("q"):
exit()


def process_images(
model, img_paths: list[Path], view: str, save_dir: str | Path | None
) -> None:
"""Handles inference on one or more images."""
first_image = cv2.cvtColor(
cv2.imread(str(img_paths[0])), cv2.COLOR_BGR2RGB
)
labels = create_dummy_labels(model, view, first_image.shape)
for img_path in img_paths:
img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
img, _ = (
model.train_augmentations([(img, {})])
if view == "train"
else model.val_augmentations([(img, {})])
)

inputs = {
"image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
}
images = get_unnormalized_images(model.cfg, inputs)

outputs = model.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)


def process_dataset_images(
model, view: str, save_dir: str | Path | None
) -> None:
"""Handles the inference on dataset images."""
for inputs, labels in model.pytorch_loaders[view]:
images = get_unnormalized_images(model.cfg, inputs)
outputs = model.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)


def create_dummy_labels(model, view: str, img_shape: tuple) -> dict:
"""Prepares the labels for different tasks (classification,
keypoints, etc.)."""
tasks = list(model.loaders["train"].get_classes().keys())
h, w, _ = img_shape
labels = {}
nk = model.loaders[view].get_n_keypoints()["keypoints"]

for task in tasks:
if task == "classification":
labels[task] = [-1, TaskType.CLASSIFICATION]
elif task == "keypoints":
labels[task] = [torch.zeros((1, nk * 3 + 2)), TaskType.KEYPOINTS]
elif task == "segmentation":
labels[task] = [torch.zeros((1, h, w)), TaskType.SEGMENTATION]
elif task == "boundingbox":
labels[task] = [
torch.tensor([[-1, 0, 0, 0, 0, 0]]),
TaskType.BOUNDINGBOX,
]

return labels
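
Because the forward pass still expects a labels dict even when no ground-truth annotations exist, create_dummy_labels fabricates placeholder entries for each task returned by the loader. An illustrative shape of the resulting dict for a 640x480 image and 17 keypoints (the sizes are made up for this example and the values carry no semantic meaning):

import torch

from luxonis_train.enums import TaskType

# Only tasks present in the dataset's classes are included by create_dummy_labels.
labels = {
    "classification": [-1, TaskType.CLASSIFICATION],
    "keypoints": [torch.zeros((1, 17 * 3 + 2)), TaskType.KEYPOINTS],      # one row of nk * 3 + 2 values
    "segmentation": [torch.zeros((1, 480, 640)), TaskType.SEGMENTATION],  # (1, H, W) mask
    "boundingbox": [torch.tensor([[-1, 0, 0, 0, 0, 0]]), TaskType.BOUNDINGBOX],
}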
