diff --git a/tests/test_det_infer.py b/tests/test_det_infer.py index c068e2a..d4f616e 100644 --- a/tests/test_det_infer.py +++ b/tests/test_det_infer.py @@ -48,6 +48,8 @@ def sample_image(): # # Check device placement # assert next(predictor.model.parameters()).device == predictor.device # + + def test_predictor_preprocess(predictor, sample_image): """Test image preprocessing""" # Test single image @@ -66,6 +68,30 @@ def test_predictor_preprocess(predictor, sample_image): assert output_batch.shape[1:] == output.shape[1:] +def test_predictor_inference_np_array(predictor, sample_image): + """Test end-to-end inference""" + sample_image = np.array(sample_image) + with torch.no_grad(): + result = predictor.predict(sample_image) + + # Check output format + assert isinstance(result, list) + assert len(result) == 1 # Single image input = single result + result = result[0] # Get first prediction + + assert isinstance(result, dict) + assert all(k in result for k in ["boxes", "scores", "labels"]) + + # Check output shapes + assert len(result["boxes"].shape) == 2 # (num_boxes, 4) + assert len(result["scores"].shape) == 1 # (num_boxes,) + assert len(result["labels"].shape) == 1 # (num_boxes,) + + # Check value ranges + assert result["scores"].min() >= 0 and result["scores"].max() <= 1 + assert result["labels"].min() >= 0 and result["labels"].max() < 80 # COCO classes + + def test_predictor_inference(predictor, sample_image): """Test end-to-end inference""" with torch.no_grad(): diff --git a/trolo/inference/detection.py b/trolo/inference/detection.py index 3dc2bb7..580597f 100644 --- a/trolo/inference/detection.py +++ b/trolo/inference/detection.py @@ -3,6 +3,7 @@ import torch from PIL import Image +import numpy as np import torchvision.transforms as T import supervision as sv @@ -141,7 +142,7 @@ def postprocess(self, outputs: torch.Tensor, letterbox_sizes: List[Tuple[int, in def predict( self, - inputs: Optional[Union[str, List[str], Image.Image, List[Image.Image]]] = None, + inputs: Optional[Union[str, List[str], np.ndarray, List[np.ndarray], Image.Image, List[Image.Image]]] = None, conf_threshold: float = 0.5, return_inputs: bool = False, batch_size: int = 1, @@ -192,7 +193,7 @@ def predict( if input_type == "folder": inputs = get_images_from_folder(inputs) # Handle image inputs - if isinstance(inputs, (str, Image.Image)): + if isinstance(inputs, (str, np.ndarray, Image.Image)): inputs = [inputs] size = tuple(self.config.yaml_cfg["eval_spatial_size"]) # [H, W] @@ -203,6 +204,8 @@ def predict( for img in inputs: if isinstance(img, str): img = Image.open(img).convert("RGB") + elif isinstance(img, np.ndarray): + img = sv.cv2_to_pillow(img).convert("RGB") original_images.append(img) original_sizes.append(img.size) letterbox_sizes.append(size) @@ -246,20 +249,17 @@ def _predict_video( all_frames = [] for batch in video_stream: - frames = batch["frames"] # List of RGB numpy arrays - - # Convert frames to PIL Images - pil_frames = [Image.fromarray(frame) for frame in frames] + frames = batch["frames"] # Get predictions for batch - predictions = self.predict(pil_frames, conf_threshold=conf_threshold, return_inputs=False) + predictions = self.predict(frames, conf_threshold=conf_threshold, return_inputs=False) if stream: - yield predictions, pil_frames if return_inputs else predictions + yield predictions, frames if return_inputs else predictions else: all_predictions.extend(predictions) if return_inputs: - all_frames.extend(pil_frames) + all_frames.extend(frames) if not stream: return all_predictions, all_frames if return_inputs else all_predictions diff --git a/trolo/inference/video.py b/trolo/inference/video.py index 51018e6..43f6a8a 100644 --- a/trolo/inference/video.py +++ b/trolo/inference/video.py @@ -27,7 +27,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: yield {"frames": frames, "frame_ids": frame_ids} break - frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) frame_ids.append(int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))) if len(frames) == self.batch_size: