Skip to content

Commit

Permalink
TLDR-716 columns orientation classifier update (#477)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Golodkov <[email protected]>
Co-authored-by: Nasty <[email protected]>
  • Loading branch information
3 people authored Aug 5, 2024
1 parent df0c03d commit 5c597a0
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_labeling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
python-version: '3.9'
- name: Run tests for labeling
run: |
test="true" docker-compose -f labeling/docker-compose.yml up --build --exit-code-from test
test="true" docker compose -f labeling/docker-compose.yml up --build --exit-code-from test
2 changes: 1 addition & 1 deletion .github/workflows/test_on_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ jobs:
flake8 .
- name: Run tests
run: |
test="true" docker-compose up --build --exit-code-from test
test="true" docker compose up --build --exit-code-from test
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ cd dedoc

### 3. Build the image and run the application
```shell
docker-compose up --build
docker compose up --build
```

### 4. Run container with tests
```shell
test="true" docker-compose up --build
test="true" docker compose up --build
```

If you need to change some application settings, you may update `config.py` according to your needs and re-build the image.
Expand Down
2 changes: 1 addition & 1 deletion dedoc/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
model_hash_dict = dict(
txtlayer_classifier="9ca1de749d8d37147b00a3a228e03ee1776c695f",
scan_orientation_efficient_net_b0="9ea283f3d346ae4fdd82463a9f60b5369a3ffb58",
scan_orientation_efficient_net_b0="c60812552a1be624476c1e5b58599867b36f8d4e",
font_classifier="db4481ad60ab050cbb42079b64f97f9e431feb07",
paragraph_classifier="c26a10193499d3cbc77ffec9842bece24fa8950b",
line_type_classifiers="0568c6e1f49612c0c351f10b80a26dc05f796683",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import warnings
from os import path
from typing import Optional, Tuple
Expand Down Expand Up @@ -30,11 +31,9 @@ def __init__(self, on_gpu: bool, checkpoint_path: Optional[str], *, config: dict
@property
def net(self) -> ClassificationModelTorch:
if self._net is None:
net = ClassificationModelTorch(self.checkpoint_path)
if self.checkpoint_path is not None:
net = ClassificationModelTorch(path.join(self.checkpoint_path, "scan_orientation_efficient_net_b0.pth"))
self._load_weights(net)
else:
net = ClassificationModelTorch(None)
self._net = net
self._net.to(self.device)
return self._net
Expand All @@ -61,17 +60,18 @@ def _set_device(self, on_gpu: bool) -> None:
self.logger.warning(f"Classifier is set to device {self.device}")

def _load_weights(self, net: ClassificationModelTorch) -> None:
path_checkpoint = path.join(self.checkpoint_path, "scan_orientation_efficient_net_b0.pth")
if not path.isfile(path_checkpoint):
download_from_hub(out_dir=self.checkpoint_path,
if not path.isfile(self.checkpoint_path):
from dedoc.config import get_config
self.checkpoint_path = os.path.join(get_config()["resources_path"], "scan_orientation_efficient_net_b0.pth")
download_from_hub(out_dir=os.path.dirname(os.path.abspath(self.checkpoint_path)),
out_name="scan_orientation_efficient_net_b0.pth",
repo_name="scan_orientation_efficient_net_b0",
hub_name="model.pth")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
net.load_state_dict(torch.load(path_checkpoint, map_location=self.location))
self.logger.info(f"Weights were loaded from {path_checkpoint}")
net.load_state_dict(torch.load(self.checkpoint_path, map_location=self.location))
self.logger.info(f"Weights were loaded from {self.checkpoint_path}")

def save_weights(self, path_checkpoint: str) -> None:
torch.save(self.net.state_dict(), path_checkpoint)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple

from numpy import ndarray
Expand Down Expand Up @@ -46,7 +47,9 @@ def __init__(self, *, config: Optional[dict] = None) -> None:
)
self.skew_corrector = SkewCorrector()
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False),
checkpoint_path=get_config()["resources_path"], config=self.config)
checkpoint_path=os.path.join(get_config()["resources_path"],
"scan_orientation_efficient_net_b0.pth"),
config=self.config)
self.binarizer = AdaptiveBinarizer()
self.ocr = OCRLineExtractor(config=self.config)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ still, the docker application should be installed and configured properly.

.. code-block:: bash
docker-compose up --build
docker compose up --build
If you need to change some application settings, you may update ``config.py`` according to your needs and re-build the image.

Expand Down
12 changes: 6 additions & 6 deletions resources/benchmarks/orient_classifier_scores.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ Orientation predictions:
+-------+-----------+--------+-------+-------+
| Class | Precision | Recall | F1 | Count |
+=======+===========+========+=======+=======+
| 0 | 0.998 | 1 | 0.999 | 537 |
| 0 | 0.998 | 1 | 0.999 | 825 |
+-------+-----------+--------+-------+-------+
| 90 | 1 | 0.998 | 0.999 | 537 |
| 90 | 1 | 0.999 | 0.999 | 825 |
+-------+-----------+--------+-------+-------+
| 180 | 1 | 0.998 | 0.999 | 537 |
| 180 | 1 | 0.998 | 0.999 | 825 |
+-------+-----------+--------+-------+-------+
| 270 | 0.998 | 1 | 0.999 | 537 |
| 270 | 0.999 | 1 | 0.999 | 825 |
+-------+-----------+--------+-------+-------+
| AVG | 0.999 | 0.999 | 0.999 | None |
+-------+-----------+--------+-------+-------+
Column predictions:
+-------+-----------+--------+-------+-------+
| Class | Precision | Recall | F1 | Count |
+=======+===========+========+=======+=======+
| 1 | 1 | 0.999 | 0.999 | 1692 |
| 1 | 0.999 | 1 | 0.999 | 1944 |
+-------+-----------+--------+-------+-------+
| 2 | 0.996 | 1 | 0.998 | 456 |
| 2 | 1 | 0.999 | 0.999 | 1356 |
+-------+-----------+--------+-------+-------+
| AVG | 0.999 | 0.999 | 0.999 | None |
+-------+-----------+--------+-------+-------+
42 changes: 37 additions & 5 deletions scripts/train/train_eval_orientation_classifier.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import argparse
import os
import shutil
import zipfile
from time import time
from typing import List

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from sklearn.metrics import precision_recall_fscore_support
from texttable import Texttable
from torch import nn
Expand All @@ -19,17 +22,18 @@
parser = argparse.ArgumentParser()
checkpoint_path_save = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "efficient_net_b0_fixed.pth"))
checkpoint_path_load = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "efficient_net_b0_fixed.pth"))
checkpoint_path = "../../resources"
output_dir = os.path.abspath(os.path.join(checkpoint_path, "benchmarks"))
output_dir = os.path.abspath(os.path.join("..", "..", "resources", "benchmarks"))

parser.add_argument("-t", "--train", type=bool, help="run for train model", default=False)
parser.add_argument("-s", "--checkpoint_save", help="Path to checkpoint for save or load", default=checkpoint_path_save)
parser.add_argument("-l", "--checkpoint_load", help="Path to checkpoint for load", default=checkpoint_path_load)
parser.add_argument("-f", "--from_checkpoint", type=bool, help="run for train model", default=True)
parser.add_argument("-d", "--input_data_folder", help="Path to data with folders train or test")
parser.add_argument("-d", "--input_data_folder", help="Path to data with folders train or test",
default=os.path.join(get_config()["intermediate_data_path"], "orientation_columns_dataset"))
parser.add_argument("-b", "--batch_size", type=int, help="Batch size", default=1)

args = parser.parse_args()
BATCH_SIZE = 1
BATCH_SIZE = args.batch_size
ON_GPU = True

"""
Expand Down Expand Up @@ -191,10 +195,38 @@ def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientat
train_model(trainloader, args.checkpoint_save, classifier)


def create_dataset() -> None:
if os.path.isdir(args.input_data_folder):
return
# download source files
datasets_path = os.path.join(get_config()["resources_path"], "datasets")
os.makedirs(datasets_path, exist_ok=True)
intermediate_path = os.path.realpath(hf_hub_download(repo_id="dedoc/orientation_columns_dataset",
filename="generate_dataset_orient_classifier.zip",
repo_type="dataset",
revision="902cc77dbd28e63dbb74dfc14a7a7b198e9d6f9d"))
source_dataset_archive = os.path.join(datasets_path, "generate_dataset_orient_classifier.zip")
shutil.move(intermediate_path, source_dataset_archive)

with zipfile.ZipFile(source_dataset_archive, "r") as zip_ref:
zip_ref.extractall(datasets_path)
os.remove(source_dataset_archive)

# rotate source files
src_pics_path = os.path.join(datasets_path, "generate_dataset_orient_classifier", "src")
scripts_path = os.path.join(datasets_path, "generate_dataset_orient_classifier", "scripts")
final_dataset_folder = os.path.join(get_config()["resources_path"], "datasets", "columns_orientation_dataset")
os.makedirs(final_dataset_folder, exist_ok=True)

os.system(f"python3 {os.path.join(scripts_path, 'gen_dataset.py')} -i {src_pics_path} -o {final_dataset_folder}")
setattr(args, "input_data_folder", final_dataset_folder) # noqa: B010


if __name__ == "__main__":
config = get_config()
data_executor = DataLoaderImageOrient()
net = ColumnsOrientationClassifier(on_gpu=ON_GPU, checkpoint_path=checkpoint_path if not args.train else "", config=config)
create_dataset()
net = ColumnsOrientationClassifier(on_gpu=ON_GPU, checkpoint_path=args.checkpoint_load if args.from_checkpoint else "", config=config)
if args.train:
train_step(data_executor, net)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_format_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class TestPDFReader(unittest.TestCase):
checkpoint_path = get_test_config()["resources_path"]
checkpoint_path = os.path.join(get_test_config()["resources_path"], "scan_orientation_efficient_net_b0.pth")
config = get_test_config()
orientation_classifier = ColumnsOrientationClassifier(on_gpu=False, checkpoint_path=checkpoint_path, config=config)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_misc_on_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_line_type_classifier(self) -> None:
self.assertListEqual(predictions, ["header", "header", "cellar"])

def test_orientation_classifier(self) -> None:
checkpoint_path = get_test_config()["resources_path"]
checkpoint_path = os.path.join(get_test_config()["resources_path"], "scan_orientation_efficient_net_b0.pth")
orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False), checkpoint_path=checkpoint_path, config=self.config)
imgs_path = [f"../data/skew_corrector/rotated_{i}.jpg" for i in range(1, 5)]

Expand Down

0 comments on commit 5c597a0

Please sign in to comment.