Skip to content

Commit

Permalink
Added URL support to imgproc.read
Browse files Browse the repository at this point in the history
  • Loading branch information
Borg93 committed Apr 10, 2024
1 parent 7f03d7c commit c87b878
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 47 deletions.
51 changes: 25 additions & 26 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ accelerate = {version = "^0.28.0", optional = true}
bitsandbytes = {version = "^0.43.0", optional = true}
# ultralytics
ultralytics = {version = "^8.0.225", optional = true}
requests = "^2.31.0"
pillow = "^10.3.0"

[tool.poetry.extras]
huggingface = ["transformers", "huggingface-hub", "datasets", "torch"]
Expand Down
4 changes: 3 additions & 1 deletion src/htrflow_core/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tqdm import tqdm

from htrflow_core.results import Result
from htrflow_core.utils import imgproc


class BaseModel(ABC):
Expand Down Expand Up @@ -60,4 +61,5 @@ def __call__(
**kwargs,
) -> Iterable[Result]:
"""Alias for BaseModel.predict(...)"""
return self.predict(images, batch_size, *args, **kwargs)
img_array = [imgproc.read(img) for img in images]
return self.predict(img_array, batch_size, *args, **kwargs)
4 changes: 2 additions & 2 deletions src/htrflow_core/models/huggingface/trocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def _predict(self, images: list[np.ndarray], **generation_kwargs) -> list[Result
scores = model_outputs.sequences_scores.tolist()
step = generation_kwargs["num_return_sequences"]

return self._create_text_result(images, texts, scores, metadata, step)
return self._create_text_results(images, texts, scores, metadata, step)

def _create_text_result(
def _create_text_results(
self, images: list[np.ndarray], texts: list[str], scores: list[float], metadata: dict, step: int
) -> list[Result]:
"""Assemble and return a list of Result objects from the prediction outputs.
Expand Down
8 changes: 1 addition & 7 deletions src/htrflow_core/models/ultralytics/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,10 @@ def _create_segmentation_result(self, image: np.ndarray, output: UltralyticsResu


if __name__ == "__main__":
    # Manual demo: run the YOLO wrapper on a sample image and print the
    # first segment of the first result.
    import requests
    from PIL import Image

    from htrflow_core.utils.imgproc import pillow2opencv

    url = "https://github.com/Swedish-National-Archives-AI-lab/htrflow_core/blob/a1b4b31f9a8b7c658a26e0e665eb536a0d757c45/data/demo_image.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    model = YOLO(model="/home/[email protected]/repo/htrflow_core/.cache/yolov8n-seg.pt")

    # NOTE(review): the model is invoked twice -- once with the decoded
    # OpenCV array and once with the raw URL string. Only the second result
    # is kept and printed; confirm whether the first call is still needed.
    results = model([pillow2opencv(image)])
    results = model([url])

    print(results[0].segments[0])
2 changes: 1 addition & 1 deletion src/htrflow_core/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class Result:
image: The original image
metadata: Metadata associated with the result
segments: Segments (may be empty)
texts: Texts (may be empty)
data: Dict (may be empty)
"""

image: np.ndarray
Expand Down
75 changes: 65 additions & 10 deletions src/htrflow_core/utils/imgproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
Image processing utilities
"""

import re
from pathlib import Path

import cv2
import numpy as np
import requests
from PIL import Image

from htrflow_core.utils.geometry import Bbox, Mask

Expand Down Expand Up @@ -52,32 +57,82 @@ def binarize(image: np.ndarray) -> np.ndarray:
return img_binarized


def url2pillow(url: str) -> "Image.Image":
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to download.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        ValueError: If the URL fails the `_is_valid_url` reachability check.
    """
    # The validator must actually be called: a bare function object is always
    # truthy, which would let every URL through unchecked.
    if not _is_valid_url(url):
        raise ValueError("Input is not a valid URL")

    # `convert` makes Pillow read the whole stream, so the connection can be
    # released as soon as the `with` block exits.
    with requests.get(url, stream=True) as response:
        return Image.open(response.raw).convert("RGB")


def pillow2opencv(image: "Image.Image") -> np.ndarray:
    """Convert a PIL image to an OpenCV array.

    The channels are reordered with ``cv2.COLOR_RGB2BGR``, so the input is
    expected to be an RGB image.

    Args:
        image: PIL image to convert.

    Returns:
        The image as a numpy array in OpenCV channel order.

    Raises:
        ValueError: If `image` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image")
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def opencv2pillow(image: np.ndarray) -> "Image.Image":
    """Convert an OpenCV array to a PIL image.

    The channels are reordered with ``cv2.COLOR_BGR2RGB`` before the array is
    handed to Pillow.

    Args:
        image: Image as a numpy array in OpenCV channel order.

    Returns:
        The image as a PIL image.

    Raises:
        ValueError: If `image` is not a numpy array.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("Input must be an OpenCV image")
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))


def read(source: str) -> np.ndarray:
    """Load the image stored at `source` with OpenCV.

    Raises:
        RuntimeError: If OpenCV could not load the image.
    """
    loaded = cv2.imread(source)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Could not load {source}")
def is_http_url(string: str) -> bool:
    """Tell whether `string` starts with an ``http://`` or ``https://`` scheme.

    Only the prefix is inspected, case-insensitively; the rest of the string
    is not validated.
    """
    scheme_match = re.match(r"^https?://", string, flags=re.IGNORECASE)
    return scheme_match is not None


def _is_valid_url(url: str) -> bool:
    """Report whether a HEAD request to `url` succeeds.

    Redirects are followed and the request times out after 5 seconds. Any
    `requests.RequestException` (including an error status raised by
    `raise_for_status`) yields False.
    """
    try:
        requests.head(url, timeout=5, allow_redirects=True).raise_for_status()
    except requests.RequestException:
        return False
    else:
        return True


def read(source: str | np.ndarray | Path) -> np.ndarray:
"""Read an image from a URL, a local path, or directly use a numpy array as an OpenCV image.
Args:
source (Union[str, np.ndarray, Path]): The source can be a URL, a local filesystem path,
or a numpy array representing an image.
Returns:
np.ndarray: Image in OpenCV format.
Raises:
RuntimeError: If the image cannot be loaded from the given source.
ValueError: If the source type is unsupported.
"""
if isinstance(source, (np.ndarray)):
return source
elif isinstance(source, str):
if isinstance(source, str) and is_http_url(source):
pil_img = url2pillow(source)
return pillow2opencv(pil_img)
else:
img = cv2.imread(str(source))
if img is not None:
return img
else:
raise RuntimeError(f"Could not load the image from {source}")

else:
raise ValueError("Source must be a string URL, np.ndarray or a filesystem path")


def write(dest: str, image: np.ndarray) -> None:
    """Write an image to disk with OpenCV.

    Args:
        dest: Destination file path.
        image: Image in OpenCV format.

    Raises:
        RuntimeError: If OpenCV reports that the image was not written.
    """
    # cv2.imwrite reports failure through its boolean return value; raise so
    # a failed write is not silently ignored (mirrors the behaviour of `read`).
    if not cv2.imwrite(dest, image):
        raise RuntimeError(f"Could not write the image to {dest}")


if __name__ == "__main__":
    # Quick manual check of the URL reachability helper.
    url = "http://www.example.com"
    message = (
        "This URL is reachable and valid."
        if _is_valid_url(url)
        else "This URL is not valid or reachable."
    )
    print(message)

0 comments on commit c87b878

Please sign in to comment.