Skip to content

Commit

Permalink
Add pipeline skeleton
Browse files Browse the repository at this point in the history
Added Pipeline and PipelineStep classes, an entrypoint (main.py) and a
demo pipeline.yaml. Works only with dummy models for now.
  • Loading branch information
viklofg committed Apr 10, 2024
1 parent c87b878 commit 44015ae
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
3 changes: 3 additions & 0 deletions data/pipelines/demo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Demo pipeline: the steps are run in the order they are listed.
# Each `step` names a pipeline step class (matched case-insensitively).
steps:
- step: Segmentation
- step: TextRecognition
39 changes: 39 additions & 0 deletions src/htrflow_core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
from typing import Sequence

from htrflow_core.pipeline.steps import PipelineStep, auto_import, init_step


logger = logging.getLogger(__name__)


class Pipeline:
    """An ordered sequence of pipeline steps that is run on a volume.

    The step order is validated on construction: every class listed in a
    step's `requires` attribute must appear earlier in the pipeline.
    """

    def __init__(self, steps: Sequence[PipelineStep]):
        """Create a pipeline from already initialized steps

        Arguments:
            steps: The steps, in the order they are to be run.

        Raises:
            RuntimeError: If a step is not preceded by a step it requires.
        """
        self.steps = steps
        validate(self)

    @classmethod
    def from_config(cls, config: dict[str, list[dict]]):
        """Init pipeline from config

        Arguments:
            config: A mapping whose "steps" entry holds one dict per
                step; each dict is passed to `init_step`.

        Returns:
            A new instance of the class this method is called on.
        """
        # Instantiate `cls` rather than `Pipeline` so that subclasses get
        # an instance of their own type from this constructor.
        return cls([init_step(step) for step in config["steps"]])

    def run(self, volume):
        """Run pipeline on volume

        Arguments:
            volume: Anything accepted by `auto_import`.

        Returns:
            The value returned by the last step (or the imported volume
            if the pipeline has no steps).
        """
        volume = auto_import(volume)
        for i, step in enumerate(self.steps):
            logger.info("Running step %s (step %d/%d)", step, i+1, len(self.steps))
            # Each step receives the output of the previous step
            volume = step.run(volume)
        return volume

    def metadata(self):
        """Return the `metadata` attribute of every step, in pipeline order"""
        # NOTE(review): none of the step classes visible alongside this
        # module define `metadata` - confirm it exists before relying on this.
        return [step.metadata for step in self.steps]


def validate(pipeline: Pipeline):
    """Check that every step is preceded by the steps it requires

    Requirements are compared by exact class: a required class must be
    the class of some earlier step in `pipeline.steps`.

    Raises:
        RuntimeError: At the first step whose requirement is not met.
    """
    step_types = [instance.__class__ for instance in pipeline.steps]
    for position, step_type in enumerate(step_types):
        # Only steps that come before the current one can satisfy it
        preceding = step_types[:position]
        for required in step_type.requires:
            if required in preceding:
                logger.info("Validating pipeline: %s is preceded by %s - OK", step_type.__name__, required.__name__)
                continue
            raise RuntimeError(f"Not valid pipeline: {step_type.__name__} must be preceded by {required.__name__}")
    logger.info("Pipeline passed validation")
95 changes: 95 additions & 0 deletions src/htrflow_core/pipeline/steps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import logging
import os

from htrflow_core.dummies.dummy_models import RecognitionModel, SegmentationModel, simple_word_segmentation
from htrflow_core.volume.volume import Volume


logger = logging.getLogger(__name__)


class PipelineStep:
    """Base class for pipeline steps

    Class attributes:
        requires: A list of step classes that need to precede this step
            in a pipeline.
    """

    # Shared by every subclass that does not define its own `requires`
    requires = []

    def __str__(self):
        # A step is displayed as its class name
        return self.__class__.__name__

    def run(self, volume: Volume) -> Volume:
        """Run step on `volume`

        The base implementation has no body and returns None; the step
        classes in this module override it.
        """


class Inference(PipelineStep):
    """Pipeline step that applies a model to a volume's segments"""

    def __init__(self, model):
        """Store `model`, a callable that is given `volume.segments()`"""
        self.model = model

    def run(self, volume):
        """Call the model on the volume's segments and update the volume

        Returns:
            The same volume object, after `volume.update` has been called
            with the model's result.
        """
        volume.update(self.model(volume.segments()))
        return volume


class Segmentation(Inference):
    """Inference step that uses a `SegmentationModel`"""

    def __init__(self):
        # The model takes no settings; it is created with its defaults
        model = SegmentationModel()
        super().__init__(model)


class TextRecognition(Inference):
    """Inference step that uses a `RecognitionModel`"""

    def __init__(self):
        # The model takes no settings; it is created with its defaults
        model = RecognitionModel()
        super().__init__(model)


class WordSegmentation(PipelineStep):
    """Step that runs `simple_word_segmentation` on the volume's leaves"""

    # A TextRecognition step must come earlier in the pipeline
    requires = [TextRecognition]

    def run(self, volume):
        """Segment `volume.leaves()` and update the volume with the result

        Returns:
            The same volume object, after `volume.update` has been called.
        """
        volume.update(simple_word_segmentation(volume.leaves()))
        return volume


class ExportStep(PipelineStep):
    """Step that saves the volume using a given destination and serializer"""

    def __init__(self, dest, serializer):
        """Store the arguments later passed on to `volume.save`"""
        self.dest, self.serializer = dest, serializer

    def run(self, volume):
        """Call `volume.save(dest, serializer)` and return the same volume"""
        volume.save(self.dest, self.serializer)
        return volume


def auto_import(source) -> Volume:
    """Import volume from `source`

    Automatically detects import type from the input. Supported types
    are:
        - `Volume` instances, which are returned as they are
        - directories with images, given as a `str` or `os.PathLike` path

    Arguments:
        source: The object to import a volume from.

    Returns:
        The imported volume.

    Raises:
        ValueError: If no import type could be inferred for `source`,
            e.g. when it is a path that is not an existing directory.
    """
    if isinstance(source, Volume):
        return source

    # Accept pathlib.Path and other path-like objects, not only plain
    # strings; they are converted so the code below sees a `str` path.
    if isinstance(source, os.PathLike):
        source = os.fspath(source)

    if isinstance(source, str) and os.path.isdir(source):
        logger.info("Loading volume from directory %s", source)
        return Volume.from_directory(source)
    raise ValueError(f"Could not infer import type for '{source}'")


def all_subclasses(cls):
    """Return the set of all direct and indirect subclasses of `cls`"""
    direct = cls.__subclasses__()
    found = set(direct)
    for child in direct:
        # Recurse to pick up subclasses of subclasses
        found |= all_subclasses(child)
    return found

# Registry of the available pipeline steps: lower-cased class name -> class
# Ex. {"segmentation": Segmentation}
STEPS = {
    step_class.__name__.lower(): step_class
    for step_class in all_subclasses(PipelineStep)
}


def init_step(step):
    """Initialize a pipeline step from its config entry

    Arguments:
        step: A dict with the name of the step class under "step"
            (matched case-insensitively against `STEPS`) and, optionally,
            keyword arguments for the step's constructor under "settings".

    Returns:
        An instance of the requested step class.

    Raises:
        KeyError: If "step" is missing, or does not name a known step.
    """
    name = step["step"].lower()
    # A bare `settings:` key in YAML loads as None; treat it as no settings
    kwargs = step.get("settings") or {}
    try:
        step_class = STEPS[name]
    except KeyError:
        # Same exception type as a plain lookup, but tell the user which
        # names are valid instead of only echoing the bad key.
        known = ", ".join(sorted(STEPS))
        raise KeyError(f"Unknown pipeline step '{step['step']}'. Available steps: {known}") from None
    return step_class(**kwargs)
25 changes: 25 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Entrypoint for HTRFLOW
Usage:
> python src/main.py <pipeline.yaml> <input_directory>
Example:
> python src/main.py data/pipelines/demo.yaml data/demo_images/A0068699
"""
import logging
import sys

import yaml

from htrflow_core.pipeline.pipeline import Pipeline


# Log to a file that is overwritten on every run
logging.basicConfig(filename='htrflow.log', level=logging.INFO, filemode='w')

if __name__ == '__main__':
    # Without a pipeline file there is nothing to run: print the usage
    # text (the module docstring) instead of failing with an IndexError.
    if len(sys.argv) < 2:
        sys.exit(__doc__)

    with open(sys.argv[1], 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Use the <input_directory> argument described in the usage text. The
    # demo directory is kept as fallback so that invocations with only a
    # pipeline file behave as before.
    source = sys.argv[2] if len(sys.argv) > 2 else "data/demo_images/A0068699"

    pipe = Pipeline.from_config(config)
    pipe.run(source)

0 comments on commit 44015ae

Please sign in to comment.