diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py
index 45926858..d01c82fd 100644
--- a/luxonis_train/attached_modules/base_attached_module.py
+++ b/luxonis_train/attached_modules/base_attached_module.py
@@ -328,7 +328,10 @@ def prepare(
         # NOTE: Check the logic below, if x needs to be modified withoud fulfilling the condition
         if labels is not None:
             label, label_type = self.get_label(labels)
-            if label_type in [LabelType.CLASSIFICATION, LabelType.SEGMENTATION]:
+            if label_type in [
+                LabelType.CLASSIFICATION,
+                LabelType.SEGMENTATION,
+            ]:
                 if isinstance(x, list):
                     if len(x) == 1:
                         x = x[0]
diff --git a/luxonis_train/attached_modules/visualizers/bbox_visualizer.py b/luxonis_train/attached_modules/visualizers/bbox_visualizer.py
index 6453ea83..f9c28cca 100644
--- a/luxonis_train/attached_modules/visualizers/bbox_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/bbox_visualizer.py
@@ -174,7 +174,8 @@ def forward(
         predictions: list[Tensor],
         targets: Tensor | None,
     ) -> tuple[Tensor, Tensor] | Tensor:
-        """Creates a visualization of the bounding box predictions and labels.
+        """Creates a visualization of the bounding box predictions and
+        labels.

         @type label_canvas: Tensor
         @param label_canvas: The canvas containing the labels.
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index b6620cf0..13d106cb 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -19,10 +19,6 @@
 from luxonis_ml.nn_archive.config import CONFIG_VERSION
 from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
 from torch.utils.data import DataLoader
-
-from luxonis_train.attached_modules.visualizers import get_unnormalized_images
-from luxonis_train.callbacks import LuxonisRichProgressBar, LuxonisTQDMProgressBar
-from luxonis_train.core.utils.infer_utils import InferAugmentations, InferDataset
 from typeguard import typechecked

 from luxonis_train.attached_modules.visualizers import get_unnormalized_images
@@ -30,6 +26,10 @@
     LuxonisRichProgressBar,
     LuxonisTQDMProgressBar,
 )
+from luxonis_train.core.utils.infer_utils import (
+    InferAugmentations,
+    InferDataset,
+)
 from luxonis_train.loaders import BaseLoaderTorch, collate_fn
 from luxonis_train.models import LuxonisLightningModule
 from luxonis_train.utils import Config, DatasetMetadata, LuxonisTrackerPL
@@ -142,7 +142,6 @@ def __init__(
             only_normalize=True,
         )

-
         self.loaders: dict[str, BaseLoaderTorch] = {}
         for view in ["train", "val", "test"]:
             loader_name = self.cfg.loader.name
@@ -444,14 +443,15 @@ def infer(
         """Runs inference.

         @type view: str | None
-        @param view: Which split to run the inference on. Valid values are: 'train',
-            'val', 'test'. Defaults to "val".
+        @param view: Which split to run the inference on. Valid values
+            are: 'train', 'val', 'test'. Defaults to "val".
         @type img_src: str | None
-        @param img_src: Path to an image or a dir with images (.pnd, .jpg, .jpeg, .bmp,
-            .tiff)
+        @param img_src: Path to an image or a dir with images (.png,
+            .jpg, .jpeg, .bmp, .tiff)
         @type save_dir: str | Path | None
-        @param save_dir: Directory where to save the visualizations. If not specified,
-            visualizations will be rendered on the screen.
+        @param save_dir: Directory where to save the visualizations. If
+            not specified, visualizations will be rendered on the
+            screen.
         @type batch_size: int
         @param batch_size: batch size to use for inference.
""" @@ -459,7 +459,9 @@ def infer( if img_src_path is not None: if Path(img_src_path).is_file(): - img = cv2.cvtColor(cv2.imread(str(img_src_path)), cv2.COLOR_BGR2RGB) + img = cv2.cvtColor( + cv2.imread(str(img_src_path)), cv2.COLOR_BGR2RGB + ) img_aug = self.infer_augmentations([img]) img_aug = np.transpose(img_aug, (2, 0, 1)) # HWC to CHW img_tensor = torch.Tensor(img_aug) @@ -473,14 +475,18 @@ def infer( return elif Path(img_src_path).is_dir(): - infer_dataset = InferDataset(img_src_path, self.infer_augmentations) + infer_dataset = InferDataset( + img_src_path, self.infer_augmentations + ) infer_loader = DataLoader(infer_dataset, batch_size=batch_size) for idx, inputs in enumerate(infer_loader): images = get_unnormalized_images(self.cfg, inputs) outputs = self.lightning_module.forward( inputs, images=images, compute_visualizations=True ) - render_visualizations(outputs.visualizations, save_dir, img_idx=idx) + render_visualizations( + outputs.visualizations, save_dir, img_idx=idx + ) return else: @@ -500,7 +506,9 @@ def infer( ) render_visualizations(outputs.visualizations, save_dir) else: - raise ValueError("Either 'veiw' or 'img_src_path' has to be defined.") + raise ValueError( + "Either 'veiw' or 'img_src_path' has to be defined." + ) def tune(self) -> None: """Runs Optuna tunning of hyperparameters.""" diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 2925e7bb..38b4e534 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -33,7 +33,8 @@ def render_visualizations( name = name.replace("/", "_") if img_idx is not None: cv2.imwrite( - str(save_dir / f"{name}_{i}_{img_idx}.png"), viz_arr + str(save_dir / f"{name}_{i}_{img_idx}.png"), + viz_arr, ) # img_idx the number of an image in a dir for the inference method i += 1 else: @@ -48,13 +49,16 @@ def render_visualizations( class InferDataset(Dataset): - def __init__(self, image_dir: str, augmentations: Optional[Callable] = None): + def __init__( + self, image_dir: str, augmentations: Optional[Callable] = None + ): """Dataset for using with the infernce method. @type image_dir: str @param image_dir: Path to the directory with images. @type augmentations: Callable | Optional - @param augmentations: Optional transform to be applied on a sample image. + @param augmentations: Optional transform to be applied on a + sample image. """ self.image_dir = image_dir self.image_filenames = [ @@ -97,7 +101,11 @@ def __init__( only_normalize: bool = True, ): super().__init__( - image_size, augmentations, train_rgb, keep_aspect_ratio, only_normalize + image_size, + augmentations, + train_rgb, + keep_aspect_ratio, + only_normalize, ) ( @@ -107,7 +115,9 @@ def __init__( self.resize_transform, ) = self._parse_cfg( image_size=image_size, - augmentations=[a for a in augmentations if a["name"] == "Normalize"] + augmentations=[ + a for a in augmentations if a["name"] == "Normalize" + ] if only_normalize else augmentations, keep_aspect_ratio=keep_aspect_ratio, @@ -119,18 +129,20 @@ def _parse_cfg( augmentations: list[dict[str, Any]], keep_aspect_ratio: bool = True, ) -> tuple[BatchCompose, A.Compose, A.Compose, A.Compose]: - """Parses provided config and returns Albumentations BatchedCompose object and - Compose object for default transforms. + """Parses provided config and returns Albumentations + BatchedCompose object and Compose object for default transforms. 

         @type image_size: List[int]
         @param image_size: Desired image size [H,W]
         @type augmentations: List[Dict[str, Any]]
-        @param augmentations: List of augmentations to use and their params
+        @param augmentations: List of augmentations to use and their
+            params
         @type keep_aspect_ratio: bool
-        @param keep_aspect_ratio: Whether should use resize that keeps aspect ratio of
-            original image.
+        @param keep_aspect_ratio: Whether to use a resize that keeps
+            the aspect ratio of the original image.
         @rtype: Tuple[BatchCompose, A.Compose, A.Compose, A.Compose]
-        @return: Objects for batched, spatial, pixel and resize transforms
+        @return: Objects for batched, spatial, pixel and resize
+            transforms
         """

         # NOTE: Always perform Resize
@@ -146,14 +158,18 @@ def _parse_cfg(
         batched_augs = []
         if augmentations:
             for aug in augmentations:
-                curr_aug = AUGMENTATIONS.get(aug["name"])(**aug.get("params", {}))
+                curr_aug = AUGMENTATIONS.get(aug["name"])(
+                    **aug.get("params", {})
+                )
                 if isinstance(curr_aug, A.ImageOnlyTransform):
                     pixel_augs.append(curr_aug)
                 elif isinstance(curr_aug, A.DualTransform):
                     spatial_augs.append(curr_aug)
                 elif isinstance(curr_aug, BatchBasedTransform):
                     self.is_batched = True
-                    self.aug_batch_size = max(self.aug_batch_size, curr_aug.batch_size)
+                    self.aug_batch_size = max(
+                        self.aug_batch_size, curr_aug.batch_size
+                    )
                     batched_augs.append(curr_aug)

         batch_transform = BatchCompose(
@@ -174,7 +190,12 @@ def _parse_cfg(
             [resize],
         )

-        return batch_transform, spatial_transform, pixel_transform, resize_transform
+        return (
+            batch_transform,
+            spatial_transform,
+            pixel_transform,
+            resize_transform,
+        )

     def __call__(
         self,
@@ -183,7 +204,8 @@ def __call__(
         """Performs augmentations on provided data.

         @type data: np.ndarray
-        @param data: Data with list of input images and their annotations
+        @param data: Data with a list of input images and their
+            annotations
         @rtype: np.ndarray
         @return: Output image
         """
@@ -199,7 +221,9 @@ def __call__(
         }

         transformed = self.batch_transform(force_apply=False, **transform_args)
-        transformed = {key: np.array(value[0]) for key, value in transformed.items()}
+        transformed = {
+            key: np.array(value[0]) for key, value in transformed.items()
+        }

         arg_names = [
             "image",