diff --git a/data/pipelines/demo.yaml b/data/pipelines/demo.yaml
new file mode 100644
index 0000000..67a24b5
--- /dev/null
+++ b/data/pipelines/demo.yaml
@@ -0,0 +1,3 @@
+steps:
+- step: Segmentation
+- step: TextRecognition
diff --git a/src/htrflow_core/pipeline/pipeline.py b/src/htrflow_core/pipeline/pipeline.py
new file mode 100644
index 0000000..10a01e4
--- /dev/null
+++ b/src/htrflow_core/pipeline/pipeline.py
@@ -0,0 +1,51 @@
+import logging
+from typing import Sequence
+
+from htrflow_core.pipeline.steps import PipelineStep, auto_import, init_step
+
+
+logger = logging.getLogger(__name__)
+
+
+class Pipeline:
+    """An ordered sequence of pipeline steps"""
+
+    def __init__(self, steps: Sequence[PipelineStep]):
+        self.steps = steps
+        validate(self)
+
+    @classmethod
+    def from_config(cls, config: dict):
+        """Init pipeline from config
+
+        `config` must hold a "steps" sequence; every entry is passed to
+        `init_step`.
+        """
+        return cls([init_step(step) for step in config["steps"]])
+
+    def run(self, volume):
+        """Run pipeline on volume"""
+        volume = auto_import(volume)
+        for i, step in enumerate(self.steps):
+            logger.info("Running step %s (step %d/%d)", step, i+1, len(self.steps))
+            volume = step.run(volume)
+        return volume
+
+    def metadata(self):
+        """Return each step's `metadata` attribute (None when absent)"""
+        return [getattr(step, "metadata", None) for step in self.steps]
+
+
+def validate(pipeline: Pipeline):
+    """Check that every step is preceded by the steps it requires
+
+    Raises:
+        RuntimeError: If a required step does not occur earlier.
+    """
+    steps = [step.__class__ for step in pipeline.steps]
+    for i, step in enumerate(steps):
+        for req_step in step.requires:
+            if req_step not in steps[:i]:
+                raise RuntimeError(f"Not valid pipeline: {step.__name__} must be preceded by {req_step.__name__}")
+            logger.info("Validating pipeline: %s is preceded by %s - OK", step.__name__, req_step.__name__)
+    logger.info("Pipeline passed validation")
diff --git a/src/htrflow_core/pipeline/steps.py b/src/htrflow_core/pipeline/steps.py
new file mode 100644
index 0000000..7c25b96
--- /dev/null
+++ b/src/htrflow_core/pipeline/steps.py
@@ -0,0 +1,120 @@
+import logging
+import os
+
+from htrflow_core.dummies.dummy_models import RecognitionModel, SegmentationModel, simple_word_segmentation
+from htrflow_core.volume.volume import Volume
+
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineStep:
+    """Pipeline step base class
+
+    Class attributes:
+        requires: Step classes that need to precede this step. Kept as
+            a tuple so subclasses cannot mutate a shared default.
+    """
+
+    requires = ()
+
+    def run(self, volume: Volume) -> Volume:
+        """Run step on `volume` and return the resulting volume"""
+        # Returning None here would silently break Pipeline.run, which
+        # feeds each step's return value into the next step.
+        raise NotImplementedError(f"{self.__class__.__name__} must implement run()")
+
+    def __str__(self):
+        return self.__class__.__name__
+
+
+class Inference(PipelineStep):
+    """Step that applies a model to the volume's segments"""
+
+    def __init__(self, model):
+        self.model = model
+
+    def run(self, volume):
+        result = self.model(volume.segments())
+        volume.update(result)
+        return volume
+
+
+class Segmentation(Inference):
+    def __init__(self):
+        super().__init__(SegmentationModel())
+
+
+class TextRecognition(Inference):
+    def __init__(self):
+        super().__init__(RecognitionModel())
+
+
+class WordSegmentation(PipelineStep):
+    """Step that runs simple word segmentation on the volume's leaves"""
+
+    requires = (TextRecognition,)
+
+    def run(self, volume):
+        results = simple_word_segmentation(volume.leaves())
+        volume.update(results)
+        return volume
+
+
+class ExportStep(PipelineStep):
+    """Step that saves the volume via `volume.save(dest, serializer)`"""
+
+    def __init__(self, dest, serializer):
+        self.dest = dest
+        self.serializer = serializer
+
+    def run(self, volume):
+        volume.save(self.dest, self.serializer)
+        return volume
+
+
+def auto_import(source) -> Volume:
+    """Import volume from `source`
+
+    Automatically detects import type from the input. Supported types
+    are:
+        - directories with images
+
+    Raises:
+        ValueError: If no import type can be inferred from `source`.
+    """
+    if isinstance(source, Volume):
+        return source
+    if isinstance(source, str) and os.path.isdir(source):
+        logger.info("Loading volume from directory %s", source)
+        return Volume.from_directory(source)
+    raise ValueError(f"Could not infer import type for '{source}'")
+
+
+def all_subclasses(cls):
+    """Return the set of all direct and indirect subclasses of `cls`"""
+    return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in all_subclasses(c)])
+
+
+# Mapping class name -> class
+# Ex. {segmentation: `steps.Segmentation`}
+STEPS = {cls_.__name__.lower(): cls_ for cls_ in all_subclasses(PipelineStep)}
+
+
+def init_step(step):
+    """Instantiate the step described by the config entry `step`
+
+    `step["step"]` is the case-insensitive class name and the optional
+    `step["settings"]` mapping is passed as keyword arguments.
+
+    Raises:
+        KeyError: If no step class matches the given name.
+    """
+    name = step["step"].lower()
+    # A bare `settings:` key in YAML loads as None, so fall back to {}
+    kwargs = step.get("settings") or {}
+    try:
+        step_cls = STEPS[name]
+    except KeyError:
+        raise KeyError(f"Unknown pipeline step '{step['step']}', expected one of {sorted(STEPS)}") from None
+    return step_cls(**kwargs)
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..7ec819d
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,33 @@
+"""
+Entrypoint for HTRFLOW
+
+Usage:
+> python src/main.py PIPELINE [INPUT]
+
+Example:
+> python src/main.py data/pipelines/demo.yaml data/demo_images/A0068699
+"""
+import logging
+import sys
+
+import yaml
+
+from htrflow_core.pipeline.pipeline import Pipeline
+
+
+# Input used when no INPUT argument is given
+DEFAULT_INPUT = "data/demo_images/A0068699"
+
+logging.basicConfig(filename='htrflow.log', level=logging.INFO, filemode='w')
+
+if __name__ == '__main__':
+    if len(sys.argv) < 2:
+        sys.exit(__doc__)
+
+    with open(sys.argv[1], 'r', encoding='utf-8') as f:
+        config = yaml.safe_load(f)
+
+    pipe = Pipeline.from_config(config)
+    # Honour the input path from the command line instead of always
+    # running on the demo images
+    pipe.run(sys.argv[2] if len(sys.argv) > 2 else DEFAULT_INPUT)