-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Pipeline and PipelineStep classes, an entrypoint (main.py) and a demo pipeline.yaml. Works only with dummy models for now.
- Loading branch information
Showing 4 changed files with 162 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
steps: | ||
- step: Segmentation | ||
- step: TextRecognition |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
from typing import Sequence | ||
|
||
from htrflow_core.pipeline.steps import PipelineStep, auto_import, init_step | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Pipeline:
    """An ordered sequence of pipeline steps that is applied to a volume.

    The step order is checked with `validate` when the pipeline is
    created.
    """

    def __init__(self, steps: Sequence["PipelineStep"]):
        """Create a pipeline from already initialized steps.

        Arguments:
            steps: The steps to run, in execution order.

        Raises:
            RuntimeError: If `validate` rejects the step order.
        """
        self.steps = steps
        validate(self)

    @classmethod
    def from_config(cls, config: dict):
        """Init pipeline from config

        Arguments:
            config: A mapping with a "steps" entry that holds one step
                description per step. Each description is passed to
                `init_step`.

        Returns:
            A new instance of the class the method was called on.
        """
        # Build through `cls` instead of a hard-coded `Pipeline` so that
        # subclasses get instances of their own type.
        return cls([init_step(step) for step in config["steps"]])

    def run(self, volume):
        """Run pipeline on volume

        Arguments:
            volume: Anything accepted by `auto_import`.

        Returns:
            The value returned by the last step, or the imported volume
            if the pipeline has no steps.
        """
        volume = auto_import(volume)
        total = len(self.steps)
        for number, step in enumerate(self.steps, start=1):
            logger.info("Running step %s (step %d/%d)", step, number, total)
            volume = step.run(volume)
        return volume

    def metadata(self):
        """Return the `metadata` attribute of every step, in step order.

        Every step is expected to provide a `metadata` attribute; a step
        without one makes this method raise AttributeError.
        """
        return [step.metadata for step in self.steps]
|
||
|
||
def validate(pipeline: "Pipeline"):
    """Check that every step is preceded by the steps it requires.

    A requirement is satisfied by an earlier step of the required class
    or of any subclass of it.

    Arguments:
        pipeline: The pipeline whose `steps` are checked.

    Raises:
        RuntimeError: If a step is not preceded by one of the steps
            listed in its `requires` attribute.
    """
    preceding = []  # classes of the steps placed before the current one
    for step in pipeline.steps:
        step_cls = step.__class__
        for req_step in step_cls.requires:
            # issubclass() instead of an exact match: a specialised
            # version of a required step is an acceptable predecessor.
            if not any(issubclass(prev, req_step) for prev in preceding):
                raise RuntimeError(
                    f"Not valid pipeline: {step_cls.__name__} must be preceded by {req_step.__name__}"
                )
            logger.info(
                "Validating pipeline: %s is preceded by %s - OK",
                step_cls.__name__,
                req_step.__name__,
            )
        preceding.append(step_cls)
    logger.info("Pipeline passed validation")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import logging | ||
import os | ||
|
||
from htrflow_core.dummies.dummy_models import RecognitionModel, SegmentationModel, simple_word_segmentation | ||
from htrflow_core.volume.volume import Volume | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class PipelineStep:
    """Pipeline step base class

    Class attributes:
        requires: A list of steps that need to precede this step.
    """

    # This list object is shared by every subclass that does not define
    # its own `requires`; override the attribute instead of mutating it.
    requires = []

    def run(self, volume: "Volume") -> "Volume":
        """Run step

        Arguments:
            volume: The volume to process.

        Returns:
            The processed volume.

        Raises:
            NotImplementedError: Always; subclasses must override this
                method. An empty body would return None and hand that
                to the next step of the pipeline.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not implement run()")

    def __str__(self):
        """Return the name of the step's class."""
        return self.__class__.__name__
|
||
|
||
class Inference(PipelineStep):
    """Step that applies a model to the segments of a volume."""

    def __init__(self, model):
        """Store `model`; it is called with the volume's segments in `run`."""
        self.model = model

    def run(self, volume):
        """Call the model on `volume.segments()` and update the volume.

        The model output is passed to `volume.update()` and the same
        volume object is returned.
        """
        volume.update(self.model(volume.segments()))
        return volume
|
||
|
||
class Segmentation(Inference):
    """Inference step that uses a `SegmentationModel` instance."""

    def __init__(self):
        """Create the step with a freshly constructed `SegmentationModel`."""
        segmenter = SegmentationModel()
        super().__init__(segmenter)
|
||
|
||
class TextRecognition(Inference):
    """Inference step that uses a `RecognitionModel` instance."""

    def __init__(self):
        """Create the step with a freshly constructed `RecognitionModel`."""
        recognizer = RecognitionModel()
        super().__init__(recognizer)
|
||
|
||
class WordSegmentation(PipelineStep):
    """Step that runs `simple_word_segmentation` on the leaves of a volume."""

    # A text recognition step has to come earlier in the pipeline.
    requires = [TextRecognition]

    def run(self, volume):
        """Segment `volume.leaves()` and update the volume with the result.

        Returns the same volume object.
        """
        volume.update(simple_word_segmentation(volume.leaves()))
        return volume
|
||
|
||
class ExportStep(PipelineStep):
    """Step that saves the volume and hands the same object on."""

    def __init__(self, dest, serializer):
        """Remember the `dest` and `serializer` passed to `volume.save()`."""
        self.dest, self.serializer = dest, serializer

    def run(self, volume):
        """Call `volume.save(dest, serializer)` and return the volume."""
        destination, serializer = self.dest, self.serializer
        volume.save(destination, serializer)
        return volume
|
||
|
||
def auto_import(source) -> "Volume":
    """Import volume from `source`

    Automatically detects import type from the input. Supported types
    are:
        - `Volume` instances, which are returned as they are
        - directories with images, given as a string or as a path-like
          object such as `pathlib.Path`

    Arguments:
        source: The object to import a volume from.

    Returns:
        The imported volume.

    Raises:
        ValueError: If `source` is none of the supported types, or is a
            path that does not point to an existing directory.
    """
    if isinstance(source, Volume):
        return source
    if isinstance(source, (str, os.PathLike)):
        # Normalise to `str` so that `Volume.from_directory` receives
        # the same argument type it always has.
        directory = os.fsdecode(source)
        if os.path.isdir(directory):
            logger.info("Loading volume from directory %s", directory)
            return Volume.from_directory(directory)
    raise ValueError(f"Could not infer import type for '{source}'")
|
||
|
||
def all_subclasses(cls):
    """Return the set of all direct and indirect subclasses of `cls`."""
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # Already reached through another base class.
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found
|
||
# Registry of the pipeline steps known when this module is imported:
# lower-cased class name -> class, e.g. {"segmentation": Segmentation}.
STEPS = {
    step_cls.__name__.lower(): step_cls
    for step_cls in all_subclasses(PipelineStep)
}
|
||
|
||
def init_step(step):
    """Create a pipeline step from its config description.

    Arguments:
        step: A mapping with a "step" entry that names the step class
            (case-insensitive) and an optional "settings" entry with
            keyword arguments for the step's constructor.

    Returns:
        A new instance of the step class registered in `STEPS`.

    Raises:
        KeyError: If the "step" entry is missing or names a step that
            is not registered in `STEPS`.
    """
    name = step["step"].lower()
    # A "settings" entry that is present but empty (e.g. `settings:`
    # without a value in YAML) is None; treat it like a missing entry.
    kwargs = step.get("settings") or {}
    try:
        step_cls = STEPS[name]
    except KeyError:
        available = ", ".join(sorted(STEPS))
        raise KeyError(
            f"Unknown pipeline step '{step['step']}'. Available steps: {available}"
        ) from None
    return step_cls(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
Entrypoint for HTRFLOW | ||
Usage: | ||
> python src/main.py <pipeline.yaml> <input_directory> | ||
Example: | ||
> python src/main.py data/pipelines/demo.yaml data/demo_images/A0068699 | ||
""" | ||
import logging | ||
import sys | ||
|
||
import yaml | ||
|
||
from htrflow_core.pipeline.pipeline import Pipeline | ||
|
||
|
||
logging.basicConfig(filename='htrflow.log', level=logging.INFO, filemode='w')

# Input used when no <input_directory> argument is given on the command line.
DEFAULT_INPUT = "data/demo_images/A0068699"


def main(argv=None):
    """Run the pipeline described by a YAML file on an input directory.

    Arguments:
        argv: Command line arguments without the program name:
            `<pipeline.yaml> [<input_directory>]`. Defaults to
            `sys.argv[1:]`. When the input directory is left out,
            `DEFAULT_INPUT` is used.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        sys.exit("Usage: python src/main.py <pipeline.yaml> [<input_directory>]")

    config_path = args[0]
    source = args[1] if len(args) > 1 else DEFAULT_INPUT

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    pipe = Pipeline.from_config(config)
    pipe.run(source)


if __name__ == '__main__':
    main()