diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index c21d5230..4267cced 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -96,7 +96,7 @@ def __init__( height=self.height, width=self.width, keep_aspect_ratio=self.keep_aspect_ratio, - out_image_format=self.color_space, + color_space=self.color_space, ) @override diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index 10b4d17a..9c9e1d45 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -37,12 +37,15 @@ def collate_fn( if task_type in {"keypoints", "boundingbox"}: label_box: list[Tensor] = [] - for i, box in enumerate(annos): - l_box = torch.zeros((box.shape[0], box.shape[1] + 1)) - l_box[:, 0] = i # add target image index for build_targets() - l_box[:, 1:] = box - label_box.append(l_box) + for i, ann in enumerate(annos): + new_ann = torch.zeros((ann.shape[0], ann.shape[1] + 1)) + # add target image index for build_targets() + new_ann[:, 0] = i + new_ann[:, 1:] = ann + label_box.append(new_ann) out_labels[task] = torch.cat(label_box, 0) + elif task_type == "instance_segmentation": + out_labels[task] = torch.cat(annos, 0) else: out_labels[task] = torch.stack(annos, 0)