Skip to content

Commit

Permalink
WIP Batch volumes
Browse files Browse the repository at this point in the history
  • Loading branch information
viklofg committed May 7, 2024
1 parent a873b80 commit 199ddd1
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 43 deletions.
20 changes: 3 additions & 17 deletions src/htrflow_core/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from pathlib import Path
from typing import List

import cowsay
import typer
Expand Down Expand Up @@ -35,16 +34,6 @@ def check_file_exists(file_path_str: str):
return file_path_str


def check_folder_exists(folder_paths_str: List[str]):
"""Check each path exists and is a folder."""
for folder_path_str in folder_paths_str:
folder_path = Path(folder_path_str)
if not folder_path.exists() or not folder_path.is_dir():
typer.echo(f"The path {folder_path} does not exist or is not a folder.")
raise typer.Exit(code=1)
return folder_paths_str


def validate_logfile_extension(logfile: str):
"""Ensure the logfile string has a .log extension."""
if logfile and not logfile.endswith(".log"):
Expand All @@ -58,9 +47,7 @@ def main(
pipeline: Annotated[
str, typer.Argument(..., help="Path to the pipeline config YAML file", callback=check_file_exists)
],
input_dirs: Annotated[
List[str], typer.Argument(..., help="Input directory or directories", callback=check_folder_exists)
],
inputs: Annotated[list[str], typer.Argument(..., help="Input paths")],
logfile: Annotated[
str,
typer.Option(help="Log file path", rich_help_panel="Secondary Arguments", callback=validate_logfile_extension),
Expand All @@ -78,10 +65,9 @@ def main(
hf_utils.HF_CONFIG |= config.get("huggingface_config", {})
pipe = Pipeline.from_config(config)

volume = auto_import(input_dirs)

typer.echo("Running Pipeline")
volume = pipe.run(volume)
for volume in auto_import(inputs):
volume = pipe.run(volume)
except Exception as e:
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(code=1)
Expand Down
3 changes: 1 addition & 2 deletions src/htrflow_core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Sequence

from htrflow_core.pipeline.steps import PipelineStep, auto_import, init_step
from htrflow_core.pipeline.steps import PipelineStep, init_step


logger = logging.getLogger(__name__)
Expand All @@ -20,7 +20,6 @@ def from_config(self, config: dict[str, str]):

def run(self, volume, start=0):
"""Run pipeline on volume"""
volume = auto_import(volume)
for i, step in enumerate(self.steps[start:]):
step_name = f"{step} (step {start+i+1} / {len(self.steps)})"
logger.info("Running step %s", step_name)
Expand Down
37 changes: 14 additions & 23 deletions src/htrflow_core/pipeline/steps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from typing import Generator

from htrflow_core.dummies.dummy_models import simple_word_segmentation
from htrflow_core.models.importer import all_models
Expand Down Expand Up @@ -148,7 +149,7 @@ def run(self, volume):
raise Exception


def auto_import(source: Volume | list[str] | str) -> Volume:
def auto_import(source: Volume | list[str] | str, max_size: int = 100) -> Generator[Volume, None, None]:
"""Import volume from `source`
Automatically detects import type from the input. Supported types
Expand All @@ -159,36 +160,26 @@ def auto_import(source: Volume | list[str] | str) -> Volume:
- A volume instance (returns itself)
"""
if isinstance(source, Volume):
return source
yield source

# If source is a single string, treat it as a single-item list
# and continue
if isinstance(source, str):
source = [source]

if isinstance(source, list):
# Input is a single directory
if len(source) == 1:
if os.path.isdir(source[0]):
logger.info("Loading volume from directory %s", source[0])
return Volume.from_directory(source[0])
if source[0].endswith("pickle"):
return Volume.from_pickle(source[0])

# Input is a list of (potential) file paths, check each and
# keep only the ones that refers to files
paths = []
for path in source:
if not os.path.isfile(path):
logger.info("Skipping %s, not a regular file", path)
continue
paths.append(path)
all_paths = []
for path in source:
if os.path.isfile(path):
all_paths.append(path)
continue

if paths:
logger.info("Loading volume from %d file(s)", len(paths))
return Volume(paths)
for parent, _, files in os.walk(path):
for file in files:
all_paths.append(os.path.join(parent, file))

raise ValueError(f"Could not infer import type for '{source}'")
logger.info("Found %d files in %d input directories and/or files", len(all_paths), len(source))
for i in range(0, len(all_paths), max_size):
yield Volume(all_paths[i : i + max_size])


def all_subclasses(cls):
Expand Down
6 changes: 5 additions & 1 deletion src/htrflow_core/volume/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class Volume(BaseDocumentNode):
"""

def __init__(self, paths: Iterable[str], label: str = "untitled_volume", label_format={}):
def __init__(self, paths: Iterable[str], label: str = None, label_format={}):
"""Initialize volume
Arguments:
Expand All @@ -228,6 +228,10 @@ def __init__(self, paths: Iterable[str], label: str = "untitled_volume", label_f
continue
self.children.append(page)

if label is None:
label = os.path.basename(os.path.commonpath(paths) if len(paths) > 1 else paths[0])
logger.info("No label provided, naming volume %s", label)

self._label_format = label_format
self.add_data(label=label)
logger.info("Initialized volume '%s' with %d pages", label, len(self.children))
Expand Down

0 comments on commit 199ddd1

Please sign in to comment.