Skip to content

Commit

Permalink
Add models to pipeline
Browse files Browse the repository at this point in the history
This lets the user specify the model type, together with its
initialization and inference arguments, in the pipeline config. See the
demo pipeline for example usage.
  • Loading branch information
viklofg committed Apr 10, 2024
1 parent 42980ac commit f135efe
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
9 changes: 9 additions & 0 deletions data/pipelines/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# Demo pipeline. Each step's `settings` selects a model by its lowercased
# class name; `generation_settings` are passed to the model at inference time.
steps:
- step: Segmentation
  settings:
    model: yolo
- step: Segmentation
  settings:
    model: segmentationmodel
- step: TextRecognition
  settings:
    model: trocr
    generation_settings:
      num_beams: 4

export:
  format: Json
Expand Down
35 changes: 26 additions & 9 deletions src/htrflow_core/pipeline/steps.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
import os

from htrflow_core.dummies.dummy_models import RecognitionModel, SegmentationModel, simple_word_segmentation
# The model imports below look unused, but they are required: a model class
# can only be found by `all_subclasses` once its module has been imported.
from htrflow_core.dummies.dummy_models import RecognitionModel, SegmentationModel, simple_word_segmentation # noqa: F401
from htrflow_core.volume.volume import Volume
from htrflow_core.models.base_model import BaseModel
from htrflow_core.models.huggingface.trocr import TrOCR # noqa: F401
from htrflow_core.models.ultralytics.yolo import YOLO # noqa: F401


logger = logging.getLogger(__name__)
Expand All @@ -17,6 +22,10 @@ class PipelineStep:

requires = []

@classmethod
def from_config(cls, config):
    """Create a step from its `settings` mapping in the pipeline config.

    Arguments:
        config: Mapping of keyword arguments for the step's constructor.
            May be empty or None (a bare `settings:` key in YAML loads
            as None), in which case the step is created without arguments.

    Returns:
        A new instance of `cls`.
    """
    return cls(**(config or {}))

def run(self, volume: Volume) -> Volume:
"""Run step"""

Expand All @@ -26,23 +35,30 @@ def __str__(self):

class Inference(PipelineStep):
    """Pipeline step that applies a model to the segments of a volume."""

    def __init__(self, model, generation_kwargs=None):
        """Create an inference step.

        Arguments:
            model: Callable invoked as `model(segments, **generation_kwargs)`.
            generation_kwargs: Optional mapping of keyword arguments passed
                to the model on every call. Defaults to no extra arguments,
                so `Inference(model)` keeps working.
        """
        self.model = model
        self.generation_kwargs = generation_kwargs or {}

    @classmethod
    def from_config(cls, config):
        """Create an inference step from its `settings` mapping.

        Arguments:
            config: Mapping with a required "model" key (matched
                case-insensitively against the keys of `MODELS`) and the
                optional keys "model_settings" (keyword arguments for the
                model's constructor) and "generation_settings" (keyword
                arguments passed to the model when the step runs).

        Returns:
            A new instance of `cls` wrapping the initialized model.

        Raises:
            KeyError: If "model" is missing or does not name a known model.
        """
        name = config["model"].lower()
        try:
            model_cls = MODELS[name]
        except KeyError:
            available = ", ".join(sorted(MODELS))
            raise KeyError(f"Unknown model '{name}'. Available models: {available}") from None
        # A key given without a value in YAML loads as None; treat it as empty.
        init_kwargs = config.get("model_settings") or {}
        generation_kwargs = config.get("generation_settings") or {}
        return cls(model_cls(**init_kwargs), generation_kwargs)

    def run(self, volume):
        """Run the model on `volume.segments()` and update the volume with the result.

        Returns:
            The same `volume` object, after `volume.update(result)`.
        """
        result = self.model(volume.segments(), **self.generation_kwargs)
        volume.update(result)
        return volume


class Segmentation(Inference):
    """Inference step selected with `step: Segmentation` in the pipeline config.

    Adds no behaviour of its own: the model is no longer hard-coded here
    but chosen through the step's `model` setting.
    """


class TextRecognition(Inference):
    """Inference step selected with `step: TextRecognition` in the pipeline config.

    Adds no behaviour of its own: the model is no longer hard-coded here
    but chosen through the step's `model` setting.
    """


class WordSegmentation(PipelineStep):
Expand Down Expand Up @@ -77,9 +93,10 @@ def all_subclasses(cls):
# Registries that resolve the names used in the pipeline config.
# Keys are lowercased class names, values are the classes themselves,
# e.g. {"segmentation": steps.Segmentation}.
STEPS = {step_cls.__name__.lower(): step_cls for step_cls in all_subclasses(PipelineStep)}
MODELS = {model_cls.__name__.lower(): model_cls for model_cls in all_subclasses(BaseModel)}


def init_step(step):
    """Create a pipeline step from one entry of the pipeline config.

    Arguments:
        step: Mapping with a "step" key naming the step class
            (case-insensitive) and an optional "settings" mapping that is
            handed to the step class's `from_config`.

    Returns:
        The object returned by the step class's `from_config`.

    Raises:
        KeyError: If "step" is missing or does not name a known step.
    """
    name = step["step"].lower()
    try:
        step_cls = STEPS[name]
    except KeyError:
        available = ", ".join(sorted(STEPS))
        raise KeyError(f"Unknown pipeline step '{name}'. Available steps: {available}") from None
    # A bare `settings:` key in YAML loads as None; treat it like a missing key.
    config = step.get("settings") or {}
    return step_cls.from_config(config)

0 comments on commit f135efe

Please sign in to comment.