Skip to content

Commit

Permalink
Allow setting export mode in config
Browse files Browse the repository at this point in the history
main.py now reads the input, runs the pipeline, and exports the result
  • Loading branch information
viklofg committed Apr 10, 2024
1 parent c6ef377 commit 03c9087
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
5 changes: 5 additions & 0 deletions data/pipelines/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Steps that make up the pipeline.
steps:
- step: Segmentation
- step: TextRecognition

# Export configuration, read by main.py as config["export"].
export:
  # Serializer format name; get_serializer compares it case-insensitively.
  format: Json
  # Optional keyword arguments forwarded to the serializer's constructor.
  settings:
    one_file: True
10 changes: 0 additions & 10 deletions src/htrflow_core/pipeline/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,6 @@ def run(self, volume):
return volume


class ExportStep(PipelineStep):
    """Pipeline step that saves the volume it receives.

    The volume is saved as a side effect and then returned unchanged.
    """

    def __init__(self, dest, serializer):
        """Store the export settings.

        Arguments:
            dest: Destination forwarded to `volume.save`.
            serializer: Serializer forwarded to `volume.save`.
        """
        self.dest = dest
        self.serializer = serializer

    def run(self, volume):
        """Save `volume` using the stored settings and return it."""
        volume.save(self.dest, self.serializer)
        return volume


def auto_import(source) -> Volume:
"""Import volume from `source`
Expand Down
10 changes: 5 additions & 5 deletions src/htrflow_core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ def supported_formats():
return [cls.format_name for cls in Serializer.__subclasses__()]


def _get_serializer(format_name):
def get_serializer(serializer_name, **serializer_args):
for cls in Serializer.__subclasses__():
if cls.format_name.lower() == format_name.lower():
return cls()
msg = f"Format '{format_name}' is not among the supported formats: {supported_formats()}"
if cls.format_name.lower() == serializer_name.lower():
return cls(**serializer_args)
msg = f"Format '{serializer_name}' is not among the supported formats: {supported_formats()}"
raise ValueError(msg)


Expand All @@ -216,7 +216,7 @@ def save_volume(volume: Volume, serializer: str | Serializer, dest: str) -> Iter
"""

if isinstance(serializer, str):
serializer = _get_serializer(serializer)
serializer = get_serializer(serializer)
logger.info("Using %s serializer with default settings", serializer.__class__.__name__)

for doc, filename in serializer.serialize_volume(volume):
Expand Down
6 changes: 5 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import yaml

from htrflow_core.pipeline.pipeline import Pipeline
from htrflow_core.pipeline.steps import auto_import
from htrflow_core.serialization import get_serializer


logging.basicConfig(filename='htrflow.log', level=logging.INFO, filemode='w')
Expand All @@ -22,4 +24,6 @@
config = yaml.safe_load(f)

pipe = Pipeline.from_config(config)
pipe.run(sys.argv[2])
volume = auto_import(sys.argv[2])
volume = pipe.run(volume)
volume.save(serializer=get_serializer(config["export"]["format"], **config["export"].get("settings", {})))

0 comments on commit 03c9087

Please sign in to comment.