Skip to content

Commit

Permalink
added dataset inspect command
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Jan 16, 2024
1 parent 8d63bff commit e8c34f8
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 138 deletions.
3 changes: 1 addition & 2 deletions luxonis_train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .attached_modules import *
from .models import *
from .tools import *
from .utils import *

__version__ = "0.1.0"
__version__ = "0.0.1"
120 changes: 119 additions & 1 deletion luxonis_train/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from enum import Enum
from importlib.metadata import version
from pathlib import Path
from typing import Annotated, Optional

import cv2
import torch
import typer

app = typer.Typer(help="Luxonis Train CLI", add_completion=False)
Expand Down Expand Up @@ -86,6 +89,117 @@ def infer(
Inferer(str(config), opts, view=view.name, save_dir=save_dir).infer()


@app.command()
def inspect(
config: ConfigType = None,
view: ViewType = View.val,
save_dir: SaveDirType = None,
opts: OptsType = None,
):
"""Inspect dataset."""
from luxonis_ml.data import (
LuxonisDataset,
TrainAugmentations,
ValAugmentations,
)

from luxonis_train.attached_modules.visualizers.utils import (
draw_bounding_box_labels,
draw_keypoint_labels,
draw_segmentation_labels,
get_unnormalized_images,
)
from luxonis_train.utils.config import Config
from luxonis_train.utils.loaders import LuxonisLoaderTorch, collate_fn
from luxonis_train.utils.types import LabelType

overrides = {}
if opts:
if len(opts) % 2 != 0:
raise ValueError("Override options should be a list of key-value pairs")

for i in range(0, len(opts), 2):
overrides[opts[i]] = opts[i + 1]

cfg = Config.get_config(str(config), overrides)

image_size = cfg.trainer.preprocessing.train_image_size

dataset = LuxonisDataset(
dataset_name=cfg.dataset.dataset_name,
team_id=cfg.dataset.team_id,
dataset_id=cfg.dataset.dataset_id,
bucket_type=cfg.dataset.bucket_type,
bucket_storage=cfg.dataset.bucket_storage,
)
augmentations = (
TrainAugmentations(
image_size=image_size,
augmentations=[
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
)
if view == "train"
else ValAugmentations(
image_size=image_size,
augmentations=[
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
)
)

loader_train = LuxonisLoaderTorch(
dataset,
view=view,
augmentations=augmentations,
)

pytorch_loader_train = torch.utils.data.DataLoader(
loader_train,
batch_size=4,
num_workers=1,
collate_fn=collate_fn,
)

if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)

counter = 0
for data in pytorch_loader_train:
imgs, label_dict = data
images = get_unnormalized_images(cfg, imgs)
for i, img in enumerate(images):
for label_type, labels in label_dict.items():
if label_type == LabelType.CLASSIFICATION:
continue
elif label_type == LabelType.BOUNDINGBOX:
img = draw_bounding_box_labels(
img, labels[labels[:, 0] == i][:, 2:], colors="yellow", width=1
)
elif label_type == LabelType.KEYPOINT:
img = draw_keypoint_labels(
img, labels[labels[:, 0] == i][:, 1:], colors="red"
)
elif label_type == LabelType.SEGMENTATION:
img = draw_segmentation_labels(
img, labels[i], alpha=0.8, colors="#5050FF"
)

img_arr = img.permute(1, 2, 0).numpy()
img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
if save_dir is not None:
counter += 1
cv2.imwrite(os.path.join(save_dir, f"{counter}.png"), img_arr)
else:
cv2.imshow("img", img_arr)
if cv2.waitKey() == ord("q"):
exit()


def version_callback(value: bool):
if value:
typer.echo(f"LuxonisTrain Version: {version(__package__)}")
Expand All @@ -104,5 +218,9 @@ def common(
...


if __name__ == "__main__":
def main():
app()


if __name__ == "__main__":
main()
Empty file removed luxonis_train/tools/__init__.py
Empty file.
135 changes: 0 additions & 135 deletions luxonis_train/tools/test_dataset.py

This file was deleted.

Loading

0 comments on commit e8c34f8

Please sign in to comment.