Skip to content

Commit

Permalink
support for np array added
Browse files Browse the repository at this point in the history
  • Loading branch information
hardikdava committed Dec 6, 2024
1 parent b5f6eac commit 672baf4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
26 changes: 26 additions & 0 deletions tests/test_det_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def sample_image():
# # Check device placement
# assert next(predictor.model.parameters()).device == predictor.device
#


def test_predictor_preprocess(predictor, sample_image):
"""Test image preprocessing"""
# Test single image
Expand All @@ -66,6 +68,30 @@ def test_predictor_preprocess(predictor, sample_image):
assert output_batch.shape[1:] == output.shape[1:]


def test_predictor_inference_np_array(predictor, sample_image):
    """Test end-to-end inference when the input is a NumPy array.

    Converts the ``sample_image`` fixture to an ``np.ndarray``, runs it
    through ``predictor.predict`` and checks that exactly one well-formed
    prediction dict (``boxes``, ``scores``, ``labels``) comes back.
    """
    # NOTE(review): if sample_image is a PIL image, np.array() yields RGB
    # channel order, while predict() converts arrays via sv.cv2_to_pillow,
    # which expects BGR input -- confirm which channel order is intended.
    image_array = np.array(sample_image)
    with torch.no_grad():
        predictions = predictor.predict(image_array)

    # Check output format
    assert isinstance(predictions, list)
    assert len(predictions) == 1  # Single image input = single result
    prediction = predictions[0]  # Get first prediction

    assert isinstance(prediction, dict)
    missing_keys = [
        key for key in ("boxes", "scores", "labels") if key not in prediction
    ]
    assert not missing_keys, f"prediction is missing keys: {missing_keys}"

    boxes = prediction["boxes"]
    scores = prediction["scores"]
    labels = prediction["labels"]

    # Check output shapes
    assert len(boxes.shape) == 2  # (num_boxes, 4)
    assert len(scores.shape) == 1  # (num_boxes,)
    assert len(labels.shape) == 1  # (num_boxes,)

    # Check value ranges. min()/max() raise on empty inputs, so the range
    # checks only run when at least one detection survived the threshold;
    # an image with zero detections is still a valid result.
    if len(scores) > 0:
        assert scores.min() >= 0
        assert scores.max() <= 1
    if len(labels) > 0:
        assert labels.min() >= 0
        assert labels.max() < 80  # COCO classes


def test_predictor_inference(predictor, sample_image):
"""Test end-to-end inference"""
with torch.no_grad():
Expand Down
18 changes: 9 additions & 9 deletions trolo/inference/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
from PIL import Image
import numpy as np
import torchvision.transforms as T
import supervision as sv

Expand Down Expand Up @@ -141,7 +142,7 @@ def postprocess(self, outputs: torch.Tensor, letterbox_sizes: List[Tuple[int, in

def predict(
self,
inputs: Optional[Union[str, List[str], Image.Image, List[Image.Image]]] = None,
inputs: Optional[Union[str, List[str], np.ndarray, List[np.ndarray], Image.Image, List[Image.Image]]] = None,
conf_threshold: float = 0.5,
return_inputs: bool = False,
batch_size: int = 1,
Expand Down Expand Up @@ -192,7 +193,7 @@ def predict(
if input_type == "folder":
inputs = get_images_from_folder(inputs)
# Handle image inputs
if isinstance(inputs, (str, Image.Image)):
if isinstance(inputs, (str, np.ndarray, Image.Image)):
inputs = [inputs]

size = tuple(self.config.yaml_cfg["eval_spatial_size"]) # [H, W]
Expand All @@ -203,6 +204,8 @@ def predict(
for img in inputs:
if isinstance(img, str):
img = Image.open(img).convert("RGB")
elif isinstance(img, np.ndarray):
img = sv.cv2_to_pillow(img).convert("RGB")
original_images.append(img)
original_sizes.append(img.size)
letterbox_sizes.append(size)
Expand Down Expand Up @@ -246,20 +249,17 @@ def _predict_video(
all_frames = []

for batch in video_stream:
frames = batch["frames"] # List of RGB numpy arrays

# Convert frames to PIL Images
pil_frames = [Image.fromarray(frame) for frame in frames]
frames = batch["frames"]

# Get predictions for batch
predictions = self.predict(pil_frames, conf_threshold=conf_threshold, return_inputs=False)
predictions = self.predict(frames, conf_threshold=conf_threshold, return_inputs=False)

if stream:
yield predictions, pil_frames if return_inputs else predictions
yield predictions, frames if return_inputs else predictions
else:
all_predictions.extend(predictions)
if return_inputs:
all_frames.extend(pil_frames)
all_frames.extend(frames)

if not stream:
return all_predictions, all_frames if return_inputs else all_predictions
2 changes: 1 addition & 1 deletion trolo/inference/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
yield {"frames": frames, "frame_ids": frame_ids}
break

frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
frames.append(frame)
frame_ids.append(int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)))

if len(frames) == self.batch_size:
Expand Down

0 comments on commit 672baf4

Please sign in to comment.