Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Sep 20, 2024
1 parent c15b010 commit 3b3a6d3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 33 deletions.
5 changes: 4 additions & 1 deletion luxonis_train/attached_modules/base_attached_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ def prepare(
# NOTE: Check the logic below, if x needs to be modified withoud fulfilling the condition
if labels is not None:
label, label_type = self.get_label(labels)
if label_type in [LabelType.CLASSIFICATION, LabelType.SEGMENTATION]:
if label_type in [
LabelType.CLASSIFICATION,
LabelType.SEGMENTATION,
]:
if isinstance(x, list):
if len(x) == 1:
x = x[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def forward(
predictions: list[Tensor],
targets: Tensor | None,
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the bounding box predictions and labels.
"""Creates a visualization of the bounding box predictions and
labels.
@type label_canvas: Tensor
@param label_canvas: The canvas containing the labels.
Expand Down
38 changes: 23 additions & 15 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
from luxonis_ml.nn_archive.config import CONFIG_VERSION
from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
from torch.utils.data import DataLoader

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.callbacks import LuxonisRichProgressBar, LuxonisTQDMProgressBar
from luxonis_train.core.utils.infer_utils import InferAugmentations, InferDataset
from typeguard import typechecked

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.callbacks import (
LuxonisRichProgressBar,
LuxonisTQDMProgressBar,
)
from luxonis_train.core.utils.infer_utils import (
InferAugmentations,
InferDataset,
)
from luxonis_train.loaders import BaseLoaderTorch, collate_fn
from luxonis_train.models import LuxonisLightningModule
from luxonis_train.utils import Config, DatasetMetadata, LuxonisTrackerPL
Expand Down Expand Up @@ -142,7 +142,6 @@ def __init__(
only_normalize=True,
)


self.loaders: dict[str, BaseLoaderTorch] = {}
for view in ["train", "val", "test"]:
loader_name = self.cfg.loader.name
Expand Down Expand Up @@ -444,22 +443,25 @@ def infer(
"""Runs inference.
@type view: str | None
@param view: Which split to run the inference on. Valid values are: 'train',
'val', 'test'. Defaults to "val".
@param view: Which split to run the inference on. Valid values
are: 'train', 'val', 'test'. Defaults to "val".
@type img_src: str | None
@param img_src: Path to an image or a dir with images (.pnd, .jpg, .jpeg, .bmp,
.tiff)
@param img_src: Path to an image or a dir with images (.pnd,
.jpg, .jpeg, .bmp, .tiff)
@type save_dir: str | Path | None
@param save_dir: Directory where to save the visualizations. If not specified,
visualizations will be rendered on the screen.
@param save_dir: Directory where to save the visualizations. If
not specified, visualizations will be rendered on the
screen.
@type batch_size: int
@param batch_size: batch size to use for inference.
"""
self.lightning_module.eval()

if img_src_path is not None:
if Path(img_src_path).is_file():
img = cv2.cvtColor(cv2.imread(str(img_src_path)), cv2.COLOR_BGR2RGB)
img = cv2.cvtColor(
cv2.imread(str(img_src_path)), cv2.COLOR_BGR2RGB
)
img_aug = self.infer_augmentations([img])
img_aug = np.transpose(img_aug, (2, 0, 1)) # HWC to CHW
img_tensor = torch.Tensor(img_aug)
Expand All @@ -473,14 +475,18 @@ def infer(
return

elif Path(img_src_path).is_dir():
infer_dataset = InferDataset(img_src_path, self.infer_augmentations)
infer_dataset = InferDataset(
img_src_path, self.infer_augmentations
)
infer_loader = DataLoader(infer_dataset, batch_size=batch_size)
for idx, inputs in enumerate(infer_loader):
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir, img_idx=idx)
render_visualizations(
outputs.visualizations, save_dir, img_idx=idx
)
return

else:
Expand All @@ -500,7 +506,9 @@ def infer(
)
render_visualizations(outputs.visualizations, save_dir)
else:
raise ValueError("Either 'veiw' or 'img_src_path' has to be defined.")
raise ValueError(
"Either 'veiw' or 'img_src_path' has to be defined."
)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
Expand Down
56 changes: 40 additions & 16 deletions luxonis_train/core/utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def render_visualizations(
name = name.replace("/", "_")
if img_idx is not None:
cv2.imwrite(
str(save_dir / f"{name}_{i}_{img_idx}.png"), viz_arr
str(save_dir / f"{name}_{i}_{img_idx}.png"),
viz_arr,
) # img_idx the number of an image in a dir for the inference method
i += 1
else:
Expand All @@ -48,13 +49,16 @@ def render_visualizations(


class InferDataset(Dataset):
def __init__(self, image_dir: str, augmentations: Optional[Callable] = None):
def __init__(
self, image_dir: str, augmentations: Optional[Callable] = None
):
"""Dataset for using with the infernce method.
@type image_dir: str
@param image_dir: Path to the directory with images.
@type augmentations: Callable | Optional
@param augmentations: Optional transform to be applied on a sample image.
@param augmentations: Optional transform to be applied on a
sample image.
"""
self.image_dir = image_dir
self.image_filenames = [
Expand Down Expand Up @@ -97,7 +101,11 @@ def __init__(
only_normalize: bool = True,
):
super().__init__(
image_size, augmentations, train_rgb, keep_aspect_ratio, only_normalize
image_size,
augmentations,
train_rgb,
keep_aspect_ratio,
only_normalize,
)

(
Expand All @@ -107,7 +115,9 @@ def __init__(
self.resize_transform,
) = self._parse_cfg(
image_size=image_size,
augmentations=[a for a in augmentations if a["name"] == "Normalize"]
augmentations=[
a for a in augmentations if a["name"] == "Normalize"
]
if only_normalize
else augmentations,
keep_aspect_ratio=keep_aspect_ratio,
Expand All @@ -119,18 +129,20 @@ def _parse_cfg(
augmentations: list[dict[str, Any]],
keep_aspect_ratio: bool = True,
) -> tuple[BatchCompose, A.Compose, A.Compose, A.Compose]:
"""Parses provided config and returns Albumentations BatchedCompose object and
Compose object for default transforms.
"""Parses provided config and returns Albumentations
BatchedCompose object and Compose object for default transforms.
@type image_size: List[int]
@param image_size: Desired image size [H,W]
@type augmentations: List[Dict[str, Any]]
@param augmentations: List of augmentations to use and their params
@param augmentations: List of augmentations to use and their
params
@type keep_aspect_ratio: bool
@param keep_aspect_ratio: Whether should use resize that keeps aspect ratio of
original image.
@param keep_aspect_ratio: Whether should use resize that keeps
aspect ratio of original image.
@rtype: Tuple[BatchCompose, A.Compose, A.Compose, A.Compose]
@return: Objects for batched, spatial, pixel and resize transforms
@return: Objects for batched, spatial, pixel and resize
transforms
"""

# NOTE: Always perform Resize
Expand All @@ -146,14 +158,18 @@ def _parse_cfg(
batched_augs = []
if augmentations:
for aug in augmentations:
curr_aug = AUGMENTATIONS.get(aug["name"])(**aug.get("params", {}))
curr_aug = AUGMENTATIONS.get(aug["name"])(
**aug.get("params", {})
)
if isinstance(curr_aug, A.ImageOnlyTransform):
pixel_augs.append(curr_aug)
elif isinstance(curr_aug, A.DualTransform):
spatial_augs.append(curr_aug)
elif isinstance(curr_aug, BatchBasedTransform):
self.is_batched = True
self.aug_batch_size = max(self.aug_batch_size, curr_aug.batch_size)
self.aug_batch_size = max(
self.aug_batch_size, curr_aug.batch_size
)
batched_augs.append(curr_aug)

batch_transform = BatchCompose(
Expand All @@ -174,7 +190,12 @@ def _parse_cfg(
[resize],
)

return batch_transform, spatial_transform, pixel_transform, resize_transform
return (
batch_transform,
spatial_transform,
pixel_transform,
resize_transform,
)

def __call__(
self,
Expand All @@ -183,7 +204,8 @@ def __call__(
"""Performs augmentations on provided data.
@type data: np.ndarray
@param data: Data with list of input images and their annotations
@param data: Data with list of input images and their
annotations
@rtype: np.ndarray
@return: Output image
"""
Expand All @@ -199,7 +221,9 @@ def __call__(
}

transformed = self.batch_transform(force_apply=False, **transform_args)
transformed = {key: np.array(value[0]) for key, value in transformed.items()}
transformed = {
key: np.array(value[0]) for key, value in transformed.items()
}

arg_names = [
"image",
Expand Down

0 comments on commit 3b3a6d3

Please sign in to comment.