Add video inference #84

Merged
8 commits merged on Oct 2, 2024
8 changes: 4 additions & 4 deletions luxonis_train/__main__.py
@@ -50,10 +50,10 @@ class _ViewType(str, Enum):
     typer.Option(help="Where to save the inference results."),
 ]

-ImgPathType = Annotated[
+SourcePathType = Annotated[
     str | None,
     typer.Option(
-        help="Path to an image file or a directory containing images for inference."
+        help="Path to an image file, a directory containing images or a video file for inference.",
     ),
 ]

@@ -106,14 +106,14 @@ def infer(
     config: ConfigType = None,
     view: ViewType = _ViewType.VAL,
     save_dir: SaveDirType = None,
-    img_path: ImgPathType = None,
+    source_path: SourcePathType = None,
     opts: OptsType = None,
 ):
     """Run inference."""
     from luxonis_train.core import LuxonisModel

     LuxonisModel(config, opts).infer(
-        view=view.value, save_dir=save_dir, img_path=img_path
+        view=view.value, save_dir=save_dir, source_path=source_path
     )
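For reference, a hedged example of invoking the updated command (assuming the typer app is run as python -m luxonis_train and that typer exposes source_path as --source-path; the config, video and output paths are placeholders):

python -m luxonis_train infer --config configs/example.yaml --source-path clip.mp4 --save-dir infer_out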


29 changes: 19 additions & 10 deletions luxonis_train/core/core.py
@@ -35,8 +35,11 @@
     try_onnx_simplify,
 )
 from .utils.infer_utils import (
+    IMAGE_FORMATS,
+    VIDEO_FORMATS,
     process_dataset_images,
     process_images,
+    process_video,
 )
 from .utils.train_utils import create_trainer

@@ -421,7 +424,7 @@
         self,
         view: Literal["train", "val", "test"] = "val",
         save_dir: str | Path | None = None,
-        img_path: str | None = None,
+        source_path: str | None = None,
     ) -> None:
         """Runs inference.

@@ -432,23 +435,29 @@
         @param save_dir: Directory where to save the visualizations. If
             not specified, visualizations will be rendered on the
             screen.
-        @type img_path: str | None
-        @param img_path: Path to the image file or directory for inference.
+        @type source_path: str | None
+        @param source_path: Path to the image file, video file or directory.
+            If None, defaults to using dataset images.
         """
         self.lightning_module.eval()

-        if img_path:
-            img_path_obj = Path(img_path)
-            if img_path_obj.is_file():
-                process_images(self, [img_path_obj], view, save_dir)
-            elif img_path_obj.is_dir():
+        if source_path:
+            source_path_obj = Path(source_path)
+            if source_path_obj.suffix.lower() in VIDEO_FORMATS:
+                process_video(self, source_path_obj, view, save_dir)
+            elif source_path_obj.is_file():
+                process_images(self, [source_path_obj], view, save_dir)
+            elif source_path_obj.is_dir():
                 image_files = [
                     f
-                    for f in img_path_obj.iterdir()
-                    if f.suffix.lower() in {".png", ".jpg", ".jpeg"}
+                    for f in source_path_obj.iterdir()
+                    if f.suffix.lower() in IMAGE_FORMATS
                 ]
                 process_images(self, image_files, view, save_dir)
             else:
                 raise ValueError(
+                    f"Source path {source_path} is not a valid file or directory."
                 )
         else:
             process_dataset_images(self, view, save_dir)
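To make the new dispatch concrete, a hedged usage sketch of the Python API (the config path, video name and output directory are placeholders; opts is assumed to be optional, as in __main__.py):

from luxonis_train.core import LuxonisModel

model = LuxonisModel("configs/example.yaml")  # placeholder config
# A recognised video suffix routes to process_video, a single image file to
# process_images, a directory to a batch over IMAGE_FORMATS files, and
# source_path=None falls back to the dataset images of the chosen view.
model.infer(view="val", save_dir="infer_out", source_path="clip.mp4")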

127 changes: 107 additions & 20 deletions luxonis_train/core/utils/infer_utils.py
@@ -1,21 +1,41 @@
+from collections import defaultdict
 from pathlib import Path

 import cv2
+import numpy as np
 import torch
+import tqdm
 from torch import Tensor

 from luxonis_train.attached_modules.visualizers import get_unnormalized_images
 from luxonis_train.enums import TaskType

+IMAGE_FORMATS = {
+    ".bmp",
+    ".jpg",
+    ".jpeg",
+    ".png",
+    ".tif",
+    ".tiff",
+    ".dng",
+    ".webp",
+    ".mpo",
+    ".pfm",
+}
+VIDEO_FORMATS = {".mp4", ".mov", ".avi", ".mkv"}


 def render_visualizations(
-    visualizations: dict[str, dict[str, Tensor]], save_dir: str | Path | None
-) -> None:
+    visualizations: dict[str, dict[str, Tensor]],
+    save_dir: str | Path | None,
+    show: bool = True,
+) -> dict[str, list[np.ndarray]]:
     """Render or save visualizations."""
     save_dir = Path(save_dir) if save_dir is not None else None
     if save_dir is not None:
         save_dir.mkdir(exist_ok=True, parents=True)

+    rendered_visualizations = defaultdict(list)
     i = 0
     for node_name, vzs in visualizations.items():
         for viz_name, viz_batch in vzs.items():
@@ -27,13 +47,93 @@
                 name = name.replace("/", "_")
                 cv2.imwrite(str(save_dir / f"{name}_{i}.png"), viz_arr)
                 i += 1
-            else:
+            elif show:
                 cv2.imshow(name, viz_arr)
+            else:
+                rendered_visualizations[name].append(viz_arr)

-    if save_dir is None:
+    if save_dir is None and show:
         if cv2.waitKey(0) == ord("q"):
             exit()

+    return rendered_visualizations
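As a quick illustration of the new return path (a minimal sketch; outputs is assumed to hold the visualizations produced by a forward pass, as in prepare_and_infer_image below):

# With save_dir=None and show=False the frames are collected rather than
# shown with cv2.imshow or written to disk.
frames = render_visualizations(outputs.visualizations, save_dir=None, show=False)
for name, arrays in frames.items():
    print(name, len(arrays))  # one rendered array per visualization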


+def prepare_and_infer_image(model, img: np.ndarray, labels: dict, view: str):
+    """Prepares the image for inference and runs the model."""
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img, _ = (
+        model.train_augmentations([(img, {})])
+        if view == "train"
+        else model.val_augmentations([(img, {})])
+    )
+
+    inputs = {
+        "image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
+    }
+    images = get_unnormalized_images(model.cfg, inputs)
+
+    outputs = model.lightning_module.forward(
+        inputs, labels, images=images, compute_visualizations=True
+    )
+    return outputs

+def process_video(
+    model,
+    video_path: str | Path,
+    view: str,
+    save_dir: str | Path | None,
+    show: bool = False,
+) -> None:
+    """Handles inference on a video."""
+    cap = cv2.VideoCapture(filename=str(video_path))  # type: ignore
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    progress_bar = tqdm.tqdm(
+        total=total_frames, position=0, leave=True, desc="Processing video"
+    )
+
+    if save_dir is not None:
+        out_writers = {}
+        save_dir = Path(save_dir)
+        save_dir.mkdir(exist_ok=True, parents=True)
+
+    labels = create_dummy_labels(
+        model, view, (int(cap.get(4)), int(cap.get(3)), 3)
+    )
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        outputs = prepare_and_infer_image(model, frame, labels, view)
+        rendered_visualizations = render_visualizations(
+            outputs.visualizations, None, show
+        )
+        if save_dir is not None:
+            for name, viz_arrs in rendered_visualizations.items():
+                if name not in out_writers:
+                    out_writers[name] = cv2.VideoWriter(
+                        filename=str(  # type: ignore
+                            save_dir / f"{name.replace('/', '-')}.mp4"
+                        ),
+                        fourcc=cv2.VideoWriter_fourcc(*"mp4v"),  # type: ignore
+                        fps=cap.get(cv2.CAP_PROP_FPS),  # type: ignore
+                        frameSize=(viz_arrs[0].shape[1], viz_arrs[0].shape[0]),  # type: ignore
+                    )  # type: ignore
+                for viz_arr in viz_arrs:
+                    out_writers[name].write(viz_arr)
+
+        progress_bar.update(1)
+
+    if save_dir is not None:
+        for writer in out_writers.values():
+            writer.release()
+
+    cap.release()
+    progress_bar.close()
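For completeness, a hedged sketch of driving the new helper directly (model, video path and output directory are placeholders; with save_dir set, one MP4 per visualization stream is written at the input video's FPS):

from pathlib import Path

from luxonis_train.core import LuxonisModel
from luxonis_train.core.utils.infer_utils import process_video

model = LuxonisModel("configs/example.yaml")  # placeholder config
# Writes e.g. video_out/<node_name>-<viz_name>.mp4 for each visualization.
process_video(model, Path("clip.mp4"), view="val", save_dir="video_out")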


 def process_images(
     model, img_paths: list[Path], view: str, save_dir: str | Path | None
@@ -44,21 +44,8 @@
     )
     labels = create_dummy_labels(model, view, first_image.shape)
     for img_path in img_paths:
-        img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
-        img, _ = (
-            model.train_augmentations([(img, {})])
-            if view == "train"
-            else model.val_augmentations([(img, {})])
-        )
-
-        inputs = {
-            "image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
-        }
-        images = get_unnormalized_images(model.cfg, inputs)
-
-        outputs = model.lightning_module.forward(
-            inputs, labels, images=images, compute_visualizations=True
-        )
+        img = cv2.imread(str(img_path))
+        outputs = prepare_and_infer_image(model, img, labels, view)

         render_visualizations(outputs.visualizations, save_dir)


@@ -80,12 +167,12 @@
     tasks = list(model.loaders["train"].get_classes().keys())
     h, w, _ = img_shape
     labels = {}
-    nk = model.loaders[view].get_n_keypoints()["keypoints"]

     for task in tasks:
         if task == "classification":
             labels[task] = [-1, TaskType.CLASSIFICATION]
         elif task == "keypoints":
+            nk = model.loaders[view].get_n_keypoints()["keypoints"]
             labels[task] = [torch.zeros((1, nk * 3 + 2)), TaskType.KEYPOINTS]
         elif task == "segmentation":
             labels[task] = [torch.zeros((1, h, w)), TaskType.SEGMENTATION]