Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference Fix #100

Merged
merged 43 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
95fab66
version bump
kozlov721 Oct 3, 2024
4791b78
updated readme
kozlov721 Oct 3, 2024
8316ea8
fixed grammar and typos
kozlov721 Oct 3, 2024
a8dfba9
updated readme
kozlov721 Oct 3, 2024
a16b7f2
Merge branch 'dev' into docs/readme-updates
kozlov721 Oct 3, 2024
004d5bf
updated readme
kozlov721 Oct 6, 2024
3d32cfc
updated docstring
kozlov721 Oct 6, 2024
78b8a1c
Merge branch 'docs/readme-updates' of github.com:luxonis/luxonis-trai…
kozlov721 Oct 6, 2024
d77afcb
simplified credentials
kozlov721 Oct 6, 2024
737c2cb
changed types
kozlov721 Oct 6, 2024
e4f2850
removed old badges
kozlov721 Oct 6, 2024
0e6cd7f
unified readme styles
kozlov721 Oct 7, 2024
d60361b
updated pyproject.toml
kozlov721 Oct 7, 2024
8c1089c
formatted configs
kozlov721 Oct 7, 2024
70dcf6e
config examples
kozlov721 Oct 7, 2024
260e020
tutorials
kozlov721 Oct 7, 2024
f743a40
small updates
kozlov721 Oct 7, 2024
a0b2c41
updated readme
kozlov721 Oct 7, 2024
e395a66
updated complex example
kozlov721 Oct 7, 2024
64c7a8a
fixed inference command
kozlov721 Oct 7, 2024
ac1e29d
Merge branch 'dev' into fix/inference
kozlov721 Oct 7, 2024
6e30483
fixed test
kozlov721 Oct 7, 2024
852c6c6
fixed pre-commit
kozlov721 Oct 7, 2024
1d78824
Merge branch 'dev' into docs/readme-updates
kozlov721 Oct 7, 2024
725fcc7
Update README.md
kozlov721 Oct 7, 2024
1547744
updated readme
kozlov721 Oct 7, 2024
e9e84e4
changed predefined model for example tuning
kozlov721 Oct 7, 2024
c4aa467
unified command arguments
kozlov721 Oct 7, 2024
0568a3d
Update README.md
kozlov721 Oct 7, 2024
396be8a
updated study name
kozlov721 Oct 8, 2024
98a2610
updated readme
kozlov721 Oct 8, 2024
f04e9a8
Update README.md
kozlov721 Oct 8, 2024
81581ca
fixed toc
kozlov721 Oct 8, 2024
66f086c
Update README.md
kozlov721 Oct 8, 2024
dd22395
updated command
kozlov721 Oct 8, 2024
74a9205
added weights to test
kozlov721 Oct 8, 2024
bf803cc
Merge branch 'docs/readme-updates' into fix/inference
kozlov721 Oct 8, 2024
a35ce57
updated readmes
kozlov721 Oct 8, 2024
c4b2894
changed deprecated
kozlov721 Oct 8, 2024
86cc2c1
updated tests
kozlov721 Oct 8, 2024
9b8dfb7
Merge branch 'dev' into fix/inference
kozlov721 Oct 8, 2024
4c8dd3a
fixed docstrings
kozlov721 Oct 8, 2024
8295f72
fixed merge error
kozlov721 Oct 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ jobs:
with:
ref: ${{ github.head_ref }}

- name: Install pre-commit
run: python -m pip install 'pre-commit<4.0.0'

- name: Run pre-commit
uses: pre-commit/[email protected]

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,5 @@ mlruns
wandb
tests/_data
tests/integration/save-directory
tests/integration/infer-save-directory
data
9 changes: 6 additions & 3 deletions luxonis_train/attached_modules/base_attached_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_input_tensors(
return inputs[self.node_tasks[self.required_labels[0]]]

def prepare(
self, inputs: Packet[Tensor], labels: Labels
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Unpack[Ts]]:
"""Prepares node outputs for the forward pass of the module.

Expand All @@ -287,8 +287,9 @@ def prepare(

@type inputs: L{Packet}[Tensor]
@param inputs: Output from the node, inputs to the attached module.
@type labels: L{Labels}
@param labels: Labels from the dataset.
@type labels: L{Labels} | None
@param labels: Labels from the dataset. If not provided, empty labels are used.
This is useful in visualizers for working with standalone images.

@rtype: tuple[Unpack[Ts]]
@return: Prepared inputs. Should allow the following usage with the
Expand Down Expand Up @@ -325,6 +326,8 @@ def prepare(
set(self.supported_tasks) & set(self.node_tasks)
)
x = self.get_input_tensors(inputs)
if labels is None:
return x, None # type: ignore
label, task_type = self._get_label(labels)
if task_type in [TaskType.CLASSIFICATION, TaskType.SEGMENTATION]:
if len(x) == 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class ObjectKeypointSimilarity(

def __init__(
self,
n_keypoints: int | None = None,
sigmas: list[float] | None = None,
area_factor: float | None = None,
use_cocoeval_oks: bool = True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(
label_canvas: Tensor,
prediction_canvas: Tensor,
inputs: Packet[Tensor],
labels: Labels,
labels: Labels | None,
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, list[Tensor]]:
return self(
label_canvas, prediction_canvas, *self.prepare(inputs, labels)
Expand Down
25 changes: 14 additions & 11 deletions luxonis_train/attached_modules/visualizers/bbox_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: list[Tensor],
targets: Tensor,
) -> tuple[Tensor, Tensor]:
targets: Tensor | None,
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the bounding box predictions and
labels.

Expand All @@ -189,26 +189,29 @@ def forward(
@type targets: Tensor
@param targets: The target bounding boxes.
"""
targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
color_dict=self.colors,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
if targets is None:
return predictions_viz

targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
label_dict=self.bbox_labels,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
return targets_viz, predictions_viz.to(targets_viz.device)
return targets_viz, predictions_viz
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import Tensor

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_visualizer import BaseVisualizer
from .utils import figure_to_torch, numpy_to_torch_img, torch_img_to_numpy
Expand Down Expand Up @@ -56,29 +57,38 @@
ax.grid(True)
return figure_to_torch(fig, width, height)

def prepare(
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Tensor, Tensor]:
predictions, targets = super().prepare(inputs, labels)
if isinstance(predictions, list):
predictions = predictions[0]

Check warning on line 65 in luxonis_train/attached_modules/visualizers/classification_visualizer.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/visualizers/classification_visualizer.py#L65

Added line #L65 was not covered by tests
return predictions, targets

def forward(
self,
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: Tensor,
labels: Tensor,
targets: Tensor | None,
) -> Tensor | tuple[Tensor, Tensor]:
overlay = torch.zeros_like(label_canvas)
plots = torch.zeros_like(prediction_canvas)
for i in range(len(overlay)):
prediction = predictions[i]
gt = self._get_class_name(labels[i])
arr = torch_img_to_numpy(label_canvas[i].clone())
curr_class = self._get_class_name(prediction)
arr = cv2.putText(
arr,
f"GT: {gt}",
(5, 10),
cv2.FONT_HERSHEY_SIMPLEX,
self.font_scale,
self.color,
self.thickness,
)
if targets is not None:
gt = self._get_class_name(targets[i])
arr = cv2.putText(
arr,
f"GT: {gt}",
(5, 10),
cv2.FONT_HERSHEY_SIMPLEX,
self.font_scale,
self.color,
self.thickness,
)
arr = cv2.putText(
arr,
f"Pred: {curr_class}",
Expand Down
21 changes: 12 additions & 9 deletions luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,9 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: list[Tensor],
targets: Tensor,
targets: Tensor | None,
**kwargs,
) -> tuple[Tensor, Tensor]:
target_viz = self.draw_targets(
label_canvas,
targets,
colors=self.visible_color,
connectivity=self.connectivity,
**kwargs,
)
) -> tuple[Tensor, Tensor] | Tensor:
pred_viz = self.draw_predictions(
prediction_canvas,
predictions,
Expand All @@ -113,4 +106,14 @@ def forward(
visibility_threshold=self.visibility_threshold,
**kwargs,
)
if targets is None:
return pred_viz

target_viz = self.draw_targets(
label_canvas,
targets,
colors=self.visible_color,
connectivity=self.connectivity,
**kwargs,
)
return target_viz, pred_viz
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
outputs: Packet[Tensor],
labels: Labels,
) -> tuple[Tensor, Tensor]:
labels: Labels | None,
) -> tuple[Tensor, Tensor] | Tensor:
for visualizer in self.visualizers:
match visualizer.run(
label_canvas, prediction_canvas, outputs, labels
Expand All @@ -57,4 +57,6 @@ def forward(
raise NotImplementedError(
"Unexpected return type from visualizer."
)
if labels is None:
return prediction_canvas
return label_canvas, prediction_canvas
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_visualizer import BaseVisualizer
from .utils import (
Expand Down Expand Up @@ -95,14 +96,22 @@ def draw_targets(

return viz

def prepare(
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Tensor, Tensor]:
predictions, targets = super().prepare(inputs, labels)
if isinstance(predictions, list):
predictions = predictions[0]
return predictions, targets

def forward(
self,
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: Tensor,
targets: Tensor,
targets: Tensor | None,
**kwargs,
) -> tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the segmentation predictions and
labels.

Expand All @@ -118,18 +127,21 @@ def forward(
@return: A tuple of the label and prediction visualizations.
"""

targets_vis = self.draw_targets(
label_canvas,
targets,
predictions_vis = self.draw_predictions(
prediction_canvas,
predictions,
colors=self.colors,
alpha=self.alpha,
background_class=self.background_class,
background_color=self.background_color,
**kwargs,
)
predictions_vis = self.draw_predictions(
prediction_canvas,
predictions,
if targets is None:
return predictions_vis

targets_vis = self.draw_targets(
label_canvas,
targets,
colors=self.colors,
alpha=self.alpha,
background_class=self.background_class,
Expand Down
6 changes: 1 addition & 5 deletions luxonis_train/attached_modules/visualizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,7 @@ def get_unnormalized_images(cfg: Config, inputs: dict[str, Tensor]) -> Tensor:
if cfg.trainer.preprocessing.normalize.active:
mean = normalize_params.get("mean", [0.485, 0.456, 0.406])
std = normalize_params.get("std", [0.229, 0.224, 0.225])
return preprocess_images(
images,
mean=mean,
std=std,
)
return preprocess_images(images, mean=mean, std=std)


def number_to_hsl(seed: int) -> tuple[float, float, float]:
Expand Down
35 changes: 20 additions & 15 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
from .utils.infer_utils import (
IMAGE_FORMATS,
VIDEO_FORMATS,
process_dataset_images,
process_images,
process_video,
infer_from_dataset,
infer_from_directory,
infer_from_video,
)
from .utils.train_utils import create_trainer

Expand Down Expand Up @@ -466,25 +466,30 @@ def infer(
weights = weights or self.cfg.model.weights

with replace_weights(self.lightning_module, weights):
if source_path:
source_path_obj = Path(source_path)
if source_path_obj.suffix.lower() in VIDEO_FORMATS:
process_video(self, source_path_obj, view, save_dir)
elif source_path_obj.is_file():
process_images(self, [source_path_obj], view, save_dir)
elif source_path_obj.is_dir():
image_files = [
if save_dir is not None:
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
if source_path is not None:
source_path = Path(source_path)
if source_path.suffix.lower() in VIDEO_FORMATS:
infer_from_video(
self, video_path=source_path, save_dir=save_dir
)
elif source_path.is_file():
infer_from_directory(self, [source_path], save_dir)
elif source_path.is_dir():
image_files = (
f
for f in source_path_obj.iterdir()
for f in source_path.iterdir()
if f.suffix.lower() in IMAGE_FORMATS
]
process_images(self, image_files, view, save_dir)
)
infer_from_directory(self, image_files, save_dir)
else:
raise ValueError(
f"Source path {source_path} is not a valid file or directory."
)
else:
process_dataset_images(self, view, save_dir)
infer_from_dataset(self, view, save_dir)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
Expand Down
Loading