From f519c47ae4b6bb6d78e5610da83f570f89bf838f Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 30 Sep 2024 17:49:19 +0300 Subject: [PATCH] feat: Abstracted write_rendered_output from the main dewret CLI --- src/dewret/__main__.py | 53 ++++++++++++++++++++++++------------------ src/dewret/render.py | 27 ++++++++++++++++++++- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index ced66e51..92d947d2 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -31,9 +31,14 @@ import click import json -from .core import set_configuration, set_render_configuration, RawRenderModule, StructuredRenderModule +from .core import ( + set_configuration, + set_render_configuration, + RawRenderModule, + StructuredRenderModule, +) from .utils import load_module_or_package -from .render import get_render_method +from .render import get_render_method, write_rendered_output from .tasks import Backend, construct @@ -72,7 +77,15 @@ @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: Path, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str + workflow_py: Path, + task: str, + arguments: list[str], + pretty: bool, + backend: Backend, + construct_args: str, + renderer: str, + renderer_args: str, + output: str, ) -> None: """Render a workflow. @@ -91,7 +104,7 @@ def render( kwargs[key] = json.loads(val) render_module: Path | ModuleType - if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)): + if mtch := re.match(r"^([a-z_0-9-.]+)$", renderer): render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule): raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.") @@ -118,18 +131,22 @@ def render( renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(",")) if output == "-": + @contextmanager def _opener(key: str, _: str) -> Generator[IO[Any], None, None]: print(" ------ ", key, " ------ ") yield sys.stdout print() + opener = _opener else: + @contextmanager def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: output_file = output.replace("%", key) with Path(output_file).open(mode) as output_f: yield output_f + opener = _opener render = get_render_method(render_module, pretty=pretty) @@ -138,30 +155,20 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: task_fn = getattr(workflow, task) try: - with set_configuration(**construct_kwargs), set_render_configuration(renderer_kwargs): - rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) + with ( + set_configuration(**construct_kwargs), + set_render_configuration(renderer_kwargs), + ): + rendered = render( + construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs + ) except Exception as exc: import traceback print(exc, exc.__cause__, exc.__context__) traceback.print_exc() else: - if len(rendered) == 1: - with opener("", "w") as output_f: - output_f.write(rendered["__root__"]) - elif "%" in output: - for key, value in rendered.items(): - if key == "__root__": - key = "ROOT" - with opener(key, "w") as output_f: - output_f.write(value) - else: - with opener("ROOT", "w") as output_f: - output_f.write(rendered["__root__"]) - del rendered["__root__"] - for key, value in rendered.items(): - with opener(key, "a") as output_f: - output_f.write("\n---\n") - output_f.write(value) + write_rendered_output(rendered, output, opener) + render() diff --git a/src/dewret/render.py b/src/dewret/render.py index cf2bc95d..d702dcad 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -21,7 +21,7 @@ import sys from pathlib import Path from functools import partial -from typing import TypeVar, Callable, cast +from typing import TypeVar, Callable, ContextManager, IO, Any, cast import yaml from .workflow import Workflow, NestedStep @@ -100,6 +100,31 @@ def _render( ) +def write_rendered_output( + rendered: dict[str, str] | dict[str, RawType], + output: str, + opener: Callable[[str, str], ContextManager[IO[Any]]], +) -> None: + """Utility function to handle writing rendered output to file or stdout.""" + if len(rendered) == 1: + with opener("", "w") as output_f: + output_f.write(rendered["__root__"]) + elif "%" in output: + for key, value in rendered.items(): + if key == "__root__": + key = "ROOT" + with opener(key, "w") as output_f: + output_f.write(value) + else: + with opener("ROOT", "w") as output_f: + output_f.write(rendered["__root__"]) + del rendered["__root__"] + for key, value in rendered.items(): + with opener(key, "a") as output_f: + output_f.write("\n---\n") + output_f.write(value) + + def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]: """Render to a dict-like structure.