Skip to content

Commit

Permalink
Merged PR 8: fixes gh issue #26
Browse files Browse the repository at this point in the history
Related work items: #1303
  • Loading branch information
Gabriel Borg committed Feb 5, 2025
1 parent 11ab576 commit 14b64fe
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "htrflow"
version = "0.2.1"
version = "0.2.2"
description = "htrflow is developed at Riksarkivet's AI-lab as an open-source package to simplify HTR"
readme = "README.md"
license = {file = "LICENSE"}
Expand Down
33 changes: 14 additions & 19 deletions src/htrflow/models/teklia/pylaia.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
super().__init__(**kwargs)

model_info_dict: PyLaiaModelInfo = get_pylaia_model(model, revision=revision, use_binary_lm=use_binary_lm)
self.model = model_info_dict
self.model_dir = model_info_dict.model_dir
model_version = model_info_dict.model_version
self.use_language_model = model_info_dict.use_language_model
Expand All @@ -99,6 +100,9 @@ def _predict(self, images: list[np.ndarray], **decode_kwargs) -> list[Result]:
Batch size for decoding. Defaults to 1.
reading_order (str, optional):
Reading order for text recognition. Defaults to "LTR".
resize_input_height (int, optional):
If set, resizes input images to the specified height,
while maintaining aspect ratio. If `-1`, resizing is skipped. Defaults to 128.
num_workers (int, optional):
Number of workers for parallel processing. Defaults to `multiprocessing.cpu_count()`.
Expand All @@ -111,6 +115,7 @@ def _predict(self, images: list[np.ndarray], **decode_kwargs) -> list[Result]:
temperature = decode_kwargs.get("temperature", 1.0)
batch_size = decode_kwargs.get("batch_size", 1)
reading_order = decode_kwargs.get("reading_order", "LTR")
resize_input_height = decode_kwargs.get("resize_input_height", 128)
num_workers = decode_kwargs.get("num_workers", multiprocessing.cpu_count())

common_args = CommonArgs(
Expand Down Expand Up @@ -146,8 +151,8 @@ def _predict(self, images: list[np.ndarray], **decode_kwargs) -> list[Result]:
image_ids = [str(uuid4()) for _ in images]

for img_id, np_img in zip(image_ids, images):
padded_img = _ensure_min_height(np_img, 128) # Just to fix the min pixel height (defaults to 128)
cv2.imwrite(str(tmp_images_dir / f"{img_id}.jpg"), padded_img)
rezied_img = _ensure_fixed_height(np_img, resize_input_height)
cv2.imwrite(str(tmp_images_dir / f"{img_id}.jpg"), rezied_img)

with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
Path(img_list.name).write_text("\n".join(image_ids))
Expand Down Expand Up @@ -340,26 +345,16 @@ def _detect_language_model(model_dir: Path, use_binary_lm: bool) -> tuple[bool,
return use_language_model, language_model_params


def _ensure_min_height(img: np.ndarray, min_height: int) -> np.ndarray:
    """
    Ensures an image meets a minimum height by resizing it if necessary.

    Previous implementation (removed by this commit in favour of
    `_ensure_fixed_height`). It is specifically designed to ensure compatibility
    with PyLaia models, which require images to have a height of at least 128 pixels.

    Args:
        img (np.ndarray): Input image as a NumPy array; `img.shape[0]` is read as
            the height and `img.shape[1]` as the width.
        min_height (int): Minimum height in pixels required for the image.

    Returns:
        np.ndarray: The resized image if the original height is less than `min_height`.
            Otherwise, the original image is returned unchanged.
    """
    if img.shape[0] < min_height:
        # Width-over-height ratio, used to scale the width along with the height.
        aspect_ratio = img.shape[1] / img.shape[0]
        new_width = int(min_height * aspect_ratio)
        # The target shape is passed as (height, width).
        new_shape = (min_height, new_width)
        return imgproc.resize(img, new_shape)

    return img


def _ensure_fixed_height(img: np.ndarray, target_height: int = 128) -> np.ndarray:
    """Ensures an image is always resized to a fixed height, maintaining aspect ratio.

    Unlike `_ensure_min_height`, images taller than the target are also resized
    (scaled down), not only the shorter ones.

    If target_height is -1, the function returns the original image without resizing.
    Any other non-positive value is treated the same way.

    Args:
        img (np.ndarray): Input image as a NumPy array; `img.shape[0]` is read as
            the height and `img.shape[1]` as the width.
        target_height (int): Height in pixels to resize to. Defaults to 128.

    Returns:
        np.ndarray: The resized image, or the original image object when
            `target_height` is not positive.
    """
    if target_height <= 0:
        return img

    # Width-over-height ratio, used to scale the width along with the height.
    aspect_ratio = img.shape[1] / img.shape[0]
    # Clamp to one pixel so that extremely narrow inputs never yield a zero-width
    # target, which an image resize cannot produce.
    new_width = max(1, int(target_height * aspect_ratio))
    # The target shape is passed as (height, width).
    new_shape = (target_height, new_width)
    return imgproc.resize(img, new_shape)

0 comments on commit 14b64fe

Please sign in to comment.