Add inference functionality on a single image or a dir with images #77

Closed · wants to merge 3 commits
36 changes: 22 additions & 14 deletions luxonis_train/attached_modules/base_attached_module.py
@@ -275,7 +275,7 @@ def get_input_tensors(
return inputs[self.node_tasks[self.required_labels[0]]]

def prepare(
self, inputs: Packet[Tensor], labels: Labels
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Unpack[Ts]]:
"""Prepares node outputs for the forward pass of the module.

@@ -325,19 +325,27 @@ def prepare(
set(self.supported_labels) & set(self.node_tasks)
)
x = self.get_input_tensors(inputs)
label, label_type = self._get_label(labels)
if label_type in [LabelType.CLASSIFICATION, LabelType.SEGMENTATION]:
if len(x) == 1:
x = x[0]
else:
logger.warning(
f"Module {self.name} expects a single tensor as input, "
f"but got {len(x)} tensors. Using the last tensor. "
f"If this is not the desired behavior, please override the "
"`prepare` method of the attached module or the `wrap` "
f"method of {self.node.name}."
)
x = x[-1]
# NOTE: Check the logic below, in case x needs to be modified without fulfilling the condition
if labels is not None:
label, label_type = self.get_label(labels)
if label_type in [
LabelType.CLASSIFICATION,
LabelType.SEGMENTATION,
]:
if isinstance(x, list):
if len(x) == 1:
x = x[0]
else:
logger.warning(
f"Module {self.name} expects a single tensor as input, "
f"but got {len(x)} tensors. Using the last tensor. "
f"If this is not the desired behavior, please override the "
"`prepare` method of the attached module or the `wrap` "
f"method of {self.node.name}."
)
x = x[-1]
else:
label = None

return x, label # type: ignore

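Reviewer note: making `labels` optional means every attached module must now tolerate a label-free forward pass. A minimal sketch of the intended contract (the `module` instance and the packet contents are hypothetical, not from this PR):

import torch

# Inference-time call: no ground-truth labels exist, so the attached module
# receives labels=None and is expected to return label=None alongside x.
feature_map = torch.rand(1, 64, 80, 80)       # stand-in for a node output
outputs_packet = {"features": [feature_map]}  # Packet[Tensor] from the node
x, label = module.prepare(outputs_packet, labels=None)
assert label is None  # downstream consumers must branch on this case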
@@ -66,7 +66,7 @@ def run(
label_canvas: Tensor,
prediction_canvas: Tensor,
inputs: Packet[Tensor],
labels: Labels,
labels: Labels | None,
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, list[Tensor]]:
return self(
label_canvas, prediction_canvas, *self.prepare(inputs, labels)
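(The file header for this hunk was lost in extraction; from the `label_canvas`/`prediction_canvas` signature it appears to belong to the visualizers' base module.) With this change, `run` threads the optional labels straight through `prepare`, so a label-free call chain looks roughly like this (editor's sketch; all names assumed):

# labels=None propagates: run -> prepare -> forward(..., targets=None)
viz = visualizer.run(label_canvas, prediction_canvas, outputs_packet, labels=None)
# Visualizers that support it (e.g. the bbox visualizer below) then return a
# single prediction-only canvas instead of a (targets, predictions) pair.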
66 changes: 41 additions & 25 deletions luxonis_train/attached_modules/visualizers/bbox_visualizer.py
@@ -172,8 +172,8 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: list[Tensor],
targets: Tensor,
) -> tuple[Tensor, Tensor]:
targets: Tensor | None,
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the bounding box predictions and
labels.

@@ -188,26 +188,42 @@
@type targets: Tensor | None
@param targets: The target bounding boxes. If None, only the predictions are visualized.
"""
targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
label_dict=self.bbox_labels,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
color_dict=self.colors,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
return targets_viz, predictions_viz.to(targets_viz.device)
if targets is not None:
targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
label_dict=self.bbox_labels,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)

predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
color_dict=self.colors,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
return targets_viz, predictions_viz.to(targets_viz.device)

else:
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
color_dict=self.colors,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
return predictions_viz.to(prediction_canvas.device)
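Reviewer note: the two `draw_predictions` calls are identical, so the branch could be collapsed by drawing predictions once and gating only the targets. An equivalent body as an editor's sketch (not code from this PR):

predictions_viz = self.draw_predictions(
    prediction_canvas,
    predictions,
    label_dict=self.bbox_labels,
    color_dict=self.colors,
    draw_labels=self.draw_labels,
    fill=self.fill,
    font=self.font,
    font_size=self.font_size,
    width=self.width,
)
if targets is None:
    return predictions_viz.to(prediction_canvas.device)
targets_viz = self.draw_targets(
    label_canvas,
    targets,
    color_dict=self.colors,
    label_dict=self.bbox_labels,
    draw_labels=self.draw_labels,
    fill=self.fill,
    font=self.font,
    font_size=self.font_size,
    width=self.width,
)
return targets_viz, predictions_viz.to(targets_viz.device)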
86 changes: 79 additions & 7 deletions luxonis_train/core/core.py
@@ -5,8 +5,10 @@
from pathlib import Path
from typing import Any, Literal, Mapping, overload

import cv2
import lightning.pytorch as pl
import lightning_utilities.core.rank_zero as rank_zero_module
import numpy as np
import rich.traceback
import torch
import torch.utils.data as torch_data
@@ -16,13 +18,18 @@
from luxonis_ml.nn_archive import ArchiveGenerator
from luxonis_ml.nn_archive.config import CONFIG_VERSION
from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
from torch.utils.data import DataLoader
from typeguard import typechecked

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.callbacks import (
LuxonisRichProgressBar,
LuxonisTQDMProgressBar,
)
from luxonis_train.core.utils.infer_utils import (
InferAugmentations,
InferDataset,
)
from luxonis_train.loaders import BaseLoaderTorch, collate_fn
from luxonis_train.models import LuxonisLightningModule
from luxonis_train.utils import Config, DatasetMetadata, LuxonisTrackerPL
@@ -112,6 +119,7 @@ def __init__(
train_rgb=self.cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio,
)

self.val_augmentations = Augmentations(
image_size=self.cfg.trainer.preprocessing.train_image_size,
augmentations=[
@@ -123,6 +131,17 @@
only_normalize=True,
)

self.infer_augmentations = InferAugmentations(
image_size=self.cfg.trainer.preprocessing.train_image_size,
augmentations=[
i.model_dump()
for i in self.cfg.trainer.preprocessing.get_active_augmentations()
],
train_rgb=self.cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio,
only_normalize=True,
)

self.loaders: dict[str, BaseLoaderTorch] = {}
for view in ["train", "val", "test"]:
loader_name = self.cfg.loader.name
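Reviewer note: the train, val, and new infer pipelines repeat the same constructor arguments. A small factory could centralize them; this is an editor's sketch with a hypothetical `_build_augmentations` name, assuming all three pipelines share `get_active_augmentations()` and that `only_normalize` defaults to False:

def _build_augmentations(self, cls=Augmentations, only_normalize: bool = False):
    """Build an augmentation pipeline from the shared preprocessing config."""
    pp = self.cfg.trainer.preprocessing
    return cls(
        image_size=pp.train_image_size,
        augmentations=[
            i.model_dump() for i in pp.get_active_augmentations()
        ],
        train_rgb=pp.train_rgb,
        keep_aspect_ratio=pp.keep_aspect_ratio,
        only_normalize=only_normalize,
    )

# e.g. self.infer_augmentations = self._build_augmentations(InferAugmentations, only_normalize=True)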
@@ -416,27 +435,80 @@ def test(
@typechecked
def infer(
self,
view: Literal["train", "val", "test"] = "val",
view: str | None = "val",
img_src_path: str | None = None,
save_dir: str | Path | None = None,
batch_size: int = 1,
) -> None:
"""Runs inference.

@type view: str
@type view: str | None
@param view: Which split to run the inference on. Valid values
are: 'train', 'val', 'test'. Defaults to "val".
@type img_src_path: str | None
@param img_src_path: Path to an image file or a directory with
images (.png, .jpg, .jpeg, .bmp, .tiff).
@type save_dir: str | Path | None
@param save_dir: Directory where to save the visualizations. If
not specified, visualizations will be rendered on the
screen.
@type batch_size: int
@param batch_size: Batch size to use for inference.
"""
self.lightning_module.eval()

for inputs, labels in self.pytorch_loaders[view]:
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
if img_src_path is not None:
if Path(img_src_path).is_file():
img = cv2.cvtColor(
cv2.imread(str(img_src_path)), cv2.COLOR_BGR2RGB
)
img_aug = self.infer_augmentations([img])
img_aug = np.transpose(img_aug, (2, 0, 1)) # HWC to CHW
img_tensor = torch.Tensor(img_aug)
img_tensor = img_tensor.unsqueeze(0)
img_dict = {"image": img_tensor}
image = get_unnormalized_images(self.cfg, img_dict)
outputs = self.lightning_module.forward(
img_dict, images=image, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)
return

elif Path(img_src_path).is_dir():
infer_dataset = InferDataset(
img_src_path, self.infer_augmentations
)
infer_loader = DataLoader(infer_dataset, batch_size=batch_size)
for idx, inputs in enumerate(infer_loader):
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, images=images, compute_visualizations=True
)
render_visualizations(
outputs.visualizations, save_dir, img_idx=idx
)
return

else:
raise ValueError(
f"Path {img_src_path} is not valid. It must point to either an image file or a directory with images."
)

elif view is not None:
if view not in self.pytorch_loaders:
raise ValueError(
f"View {view} is not valid. Valid views are: 'train', 'val', 'test'."
)
for inputs, labels in self.pytorch_loaders[view]:
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)
else:
raise ValueError(
"Either 'veiw' or 'img_src_path' has to be defined."
)
render_visualizations(outputs.visualizations, save_dir)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
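For reviewers trying the branch out, the new entry point could be exercised roughly like this (editor's sketch; the `Core` class name, import path, and config file are assumptions, not taken from this diff):

from luxonis_train.core import Core  # assumed import path for the class edited above

core = Core(cfg="configs/detection.yaml")  # hypothetical config file

core.infer(img_src_path="example.png", save_dir="infer_out")             # single image
core.infer(img_src_path="images/", batch_size=4, save_dir="infer_out")   # directory of images
core.infer(view="val", save_dir="infer_out")                             # dataset split (previous behavior)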