diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fa3b9324..e9ef4e9f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -76,9 +76,36 @@ jobs: - name: Pyleft Check run: pyleft luxonis_train - tests: + # Tests that the `luxonis-train.config.Config` works + # even when the dependencies of `luxonis-train` are + # not installed. + config-test: + runs-on: ubuntu-latest needs: - type-check + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.head_ref }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: pip + + - name: Test config without dependencies + run: | + pip install luxonis-ml[utils] + pip install -e . --no-deps + python -c 'from luxonis_train.config import Config; \ + Config.get_config("configs/complex_model.yaml")' + + tests: + needs: + - config-test strategy: fail-fast: false matrix: diff --git a/luxonis_train/config/predefined_models/ocr_recognition_model.py b/luxonis_train/config/predefined_models/ocr_recognition_model.py index ffbf8e07..ba88d7f2 100644 --- a/luxonis_train/config/predefined_models/ocr_recognition_model.py +++ b/luxonis_train/config/predefined_models/ocr_recognition_model.py @@ -1,5 +1,6 @@ from typing import Literal, TypeAlias +from loguru import logger from pydantic import BaseModel from luxonis_train.config import ( @@ -9,10 +10,20 @@ ModelNodeConfig, Params, ) -from luxonis_train.utils.ocr import AlphabetName from .base_predefined_model import BasePredefinedModel +AlphabetName: TypeAlias = Literal[ + "english", + "english_lowercase", + "numeric", + "alphanumeric", + "alphanumeric_lowercase", + "punctuation", + "ascii", +] + + VariantLiteral: TypeAlias = Literal["light", "heavy"] @@ -72,7 +83,8 @@ def __init__( self.head_params = head_params or var_config.head_params self.backbone_params["max_text_len"] = max_text_len - self.head_params["alphabet"] = alphabet + + self.head_params["alphabet"] = self._generate_alphabet(alphabet) self.head_params["ignore_unknown"] = ignore_unknown self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} @@ -139,3 +151,34 @@ def visualizers(self) -> list[AttachedModuleConfig]: params=self.visualizer_params, ) ] + + def _generate_alphabet( + self, alphabet: list[str] | AlphabetName + ) -> list[str]: + alphabets = { + "english": list( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + ), + "english_lowercase": list("abcdefghijklmnopqrstuvwxyz"), + "numeric": list("0123456789"), + "alphanumeric": list( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + ), + "alphanumeric_lowercase": list( + "abcdefghijklmnopqrstuvwxyz0123456789" + ), + "punctuation": list(" !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"), + "ascii": list("".join(chr(i) for i in range(32, 127))), + } + + if isinstance(alphabet, str): + if alphabet not in alphabets: + raise ValueError( + f"Invalid alphabet name '{alphabet}'. " + f"Available options are: {list(alphabets.keys())}. " + f"Alternatively, you can provide a custom alphabet as a list of characters." + ) + logger.info(f"Using predefined alphabet '{alphabet}'.") + alphabet = alphabets[alphabet] + + return alphabet diff --git a/luxonis_train/nodes/heads/ocr_ctc_head.py b/luxonis_train/nodes/heads/ocr_ctc_head.py index 5b07e42c..0876b6f2 100644 --- a/luxonis_train/nodes/heads/ocr_ctc_head.py +++ b/luxonis_train/nodes/heads/ocr_ctc_head.py @@ -7,7 +7,6 @@ from luxonis_train.nodes.heads import BaseHead from luxonis_train.tasks import Tasks from luxonis_train.utils import OCRDecoder, OCREncoder -from luxonis_train.utils.ocr import AlphabetName class OCRCTCHead(BaseHead[Tensor, Tensor]): @@ -18,7 +17,7 @@ class OCRCTCHead(BaseHead[Tensor, Tensor]): def __init__( self, - alphabet: list[str] | AlphabetName = "english", + alphabet: list[str], ignore_unknown: bool = True, fc_decay: float = 0.0004, mid_channels: int | None = None, @@ -34,8 +33,8 @@ def __init__( @license: U{Apache License, Version 2.0 } - @type alphabet: list[str] | AlphabetName - @param alphabet: List of characters or a name of the alphabet. + @type alphabet: list[str] + @param alphabet: List of characters. @type ignore_unknown: bool @param ignore_unknown: Whether to ignore unknown characters. Defaults to True. diff --git a/luxonis_train/utils/ocr.py b/luxonis_train/utils/ocr.py index 8f9d18ed..5c4806fa 100644 --- a/luxonis_train/utils/ocr.py +++ b/luxonis_train/utils/ocr.py @@ -1,32 +1,7 @@ -from typing import Literal, TypeAlias - import numpy as np import torch -from loguru import logger from torch import Tensor -ALPHABETS = { - "english": list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"), - "english_lowercase": list("abcdefghijklmnopqrstuvwxyz"), - "numeric": list("0123456789"), - "alphanumeric": list( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - ), - "alphanumeric_lowercase": list("abcdefghijklmnopqrstuvwxyz0123456789"), - "punctuation": list(" !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"), - "ascii": list("".join(chr(i) for i in range(32, 127))), -} - -AlphabetName: TypeAlias = Literal[ - "english", - "english_lowercase", - "numeric", - "alphanumeric", - "alphanumeric_lowercase", - "punctuation", - "ascii", -] - class OCRDecoder: """OCR decoder for converting model predictions to text.""" @@ -104,29 +79,18 @@ class OCREncoder: def __init__( self, - alphabet: list[str] | AlphabetName = "english", + alphabet: list[str], ignore_unknown: bool = True, ): """Initializes the OCR encoder. - @type alphabet: list[str] | AlphabetName - @param alphabet: A list of characters in the alphabet or a name - of a predefined alphabet. Defaults to "english". + @type alphabet: list[str] + @param alphabet: A list of characters in the alphabet. @type ignore_unknown: bool @param ignore_unknown: Whether to ignore unknown characters. Defaults to True. """ - if isinstance(alphabet, str): - if alphabet not in ALPHABETS: - raise ValueError( - f"Invalid alphabet name '{alphabet}'. " - f"Available options are: {list(ALPHABETS.keys())}. " - f"Alternatively, you can provide a custom alphabet as a list of characters." - ) - alphabet = ALPHABETS[alphabet] - logger.info(f"Using predefined alphabet '{alphabet}'.") - self._alphabet = [""] + list(np.unique(alphabet)) self.char_to_int = {char: i for i, char in enumerate(self._alphabet)}