Skip to content

Commit

Permalink
Moved alphabet resolving for OCR model (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Feb 7, 2025
1 parent 9948244 commit cb62818
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 46 deletions.
29 changes: 28 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,36 @@ jobs:
- name: Pyleft Check
run: pyleft luxonis_train

tests:
# Verifies that `luxonis_train.config.Config` can be imported and
# can load a configuration file even when the dependencies of
# `luxonis-train` are not installed.
config-test:
runs-on: ubuntu-latest
needs:
- type-check

steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip

- name: Test config without dependencies
run: |
pip install luxonis-ml[utils]
pip install -e . --no-deps
python -c 'from luxonis_train.config import Config; \
Config.get_config("configs/complex_model.yaml")'
tests:
needs:
- config-test
strategy:
fail-fast: false
matrix:
Expand Down
47 changes: 45 additions & 2 deletions luxonis_train/config/predefined_models/ocr_recognition_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Literal, TypeAlias

from loguru import logger
from pydantic import BaseModel

from luxonis_train.config import (
Expand All @@ -9,10 +10,20 @@
ModelNodeConfig,
Params,
)
from luxonis_train.utils.ocr import AlphabetName

from .base_predefined_model import BasePredefinedModel

# Names of the predefined alphabets that can be given in place of an
# explicit list of characters. Each name listed here has a matching key
# in the lookup table built by `_generate_alphabet`; keep the two in sync
# when adding or removing an alphabet.
# NOTE(review): presumably declared in the config package so that the
# config can be imported without the training dependencies -- confirm.
AlphabetName: TypeAlias = Literal[
"english",
"english_lowercase",
"numeric",
"alphanumeric",
"alphanumeric_lowercase",
"punctuation",
"ascii",
]


VariantLiteral: TypeAlias = Literal["light", "heavy"]


Expand Down Expand Up @@ -72,7 +83,8 @@ def __init__(
self.head_params = head_params or var_config.head_params

self.backbone_params["max_text_len"] = max_text_len
self.head_params["alphabet"] = alphabet

self.head_params["alphabet"] = self._generate_alphabet(alphabet)
self.head_params["ignore_unknown"] = ignore_unknown
self.loss_params = loss_params or {}
self.visualizer_params = visualizer_params or {}
Expand Down Expand Up @@ -139,3 +151,34 @@ def visualizers(self) -> list[AttachedModuleConfig]:
params=self.visualizer_params,
)
]

def _generate_alphabet(
    self, alphabet: list[str] | AlphabetName
) -> list[str]:
    """Resolves an alphabet specification into a list of characters.

    @type alphabet: list[str] | AlphabetName
    @param alphabet: Either a custom alphabet given as a list of
        characters, or the name of one of the predefined alphabets.
    @rtype: list[str]
    @return: The custom alphabet exactly as it was passed in, or the
        characters of the predefined alphabet selected by name.
    @raise ValueError: If C{alphabet} is a string that does not name
        a predefined alphabet.
    """
    # Anything that is not a string is a custom alphabet and is
    # handed back untouched.
    if not isinstance(alphabet, str):
        return alphabet

    lowercase = "abcdefghijklmnopqrstuvwxyz"
    uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    digits = "0123456789"

    # Insertion order matters: it is the order in which the options
    # are listed in the error message below.
    predefined = {
        "english": list(lowercase + uppercase),
        "english_lowercase": list(lowercase),
        "numeric": list(digits),
        "alphanumeric": list(lowercase + uppercase + digits),
        "alphanumeric_lowercase": list(lowercase + digits),
        "punctuation": list(" !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"),
        # Printable ASCII: space (32) through tilde (126).
        "ascii": [chr(code) for code in range(32, 127)],
    }

    characters = predefined.get(alphabet)
    if characters is None:
        raise ValueError(
            f"Invalid alphabet name '{alphabet}'. "
            f"Available options are: {list(predefined.keys())}. "
            f"Alternatively, you can provide a custom alphabet as a list of characters."
        )

    # Log the name of the alphabet, not its expanded characters.
    logger.info(f"Using predefined alphabet '{alphabet}'.")
    return characters
7 changes: 3 additions & 4 deletions luxonis_train/nodes/heads/ocr_ctc_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from luxonis_train.nodes.heads import BaseHead
from luxonis_train.tasks import Tasks
from luxonis_train.utils import OCRDecoder, OCREncoder
from luxonis_train.utils.ocr import AlphabetName


class OCRCTCHead(BaseHead[Tensor, Tensor]):
Expand All @@ -18,7 +17,7 @@ class OCRCTCHead(BaseHead[Tensor, Tensor]):

def __init__(
self,
alphabet: list[str] | AlphabetName = "english",
alphabet: list[str],
ignore_unknown: bool = True,
fc_decay: float = 0.0004,
mid_channels: int | None = None,
Expand All @@ -34,8 +33,8 @@ def __init__(
@license: U{Apache License, Version 2.0
<https://github.com/PaddlePaddle/PaddleOCR/blob/main/LICENSE
>}
@type alphabet: list[str] | AlphabetName
@param alphabet: List of characters or a name of the alphabet.
@type alphabet: list[str]
@param alphabet: List of characters.
@type ignore_unknown: bool
@param ignore_unknown: Whether to ignore unknown characters.
Defaults to True.
Expand Down
42 changes: 3 additions & 39 deletions luxonis_train/utils/ocr.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
from typing import Literal, TypeAlias

import numpy as np
import torch
from loguru import logger
from torch import Tensor

ALPHABETS = {
"english": list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"),
"english_lowercase": list("abcdefghijklmnopqrstuvwxyz"),
"numeric": list("0123456789"),
"alphanumeric": list(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
),
"alphanumeric_lowercase": list("abcdefghijklmnopqrstuvwxyz0123456789"),
"punctuation": list(" !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"),
"ascii": list("".join(chr(i) for i in range(32, 127))),
}

AlphabetName: TypeAlias = Literal[
"english",
"english_lowercase",
"numeric",
"alphanumeric",
"alphanumeric_lowercase",
"punctuation",
"ascii",
]


class OCRDecoder:
"""OCR decoder for converting model predictions to text."""
Expand Down Expand Up @@ -104,29 +79,18 @@ class OCREncoder:

def __init__(
self,
alphabet: list[str] | AlphabetName = "english",
alphabet: list[str],
ignore_unknown: bool = True,
):
"""Initializes the OCR encoder.
@type alphabet: list[str] | AlphabetName
@param alphabet: A list of characters in the alphabet or a name
of a predefined alphabet. Defaults to "english".
@type alphabet: list[str]
@param alphabet: A list of characters in the alphabet.
@type ignore_unknown: bool
@param ignore_unknown: Whether to ignore unknown characters.
Defaults to True.
"""

if isinstance(alphabet, str):
if alphabet not in ALPHABETS:
raise ValueError(
f"Invalid alphabet name '{alphabet}'. "
f"Available options are: {list(ALPHABETS.keys())}. "
f"Alternatively, you can provide a custom alphabet as a list of characters."
)
alphabet = ALPHABETS[alphabet]
logger.info(f"Using predefined alphabet '{alphabet}'.")

self._alphabet = [""] + list(np.unique(alphabet))
self.char_to_int = {char: i for i, char in enumerate(self._alphabet)}

Expand Down

0 comments on commit cb62818

Please sign in to comment.