Skip to content

Commit

Permalink
feat: Abstracted write_rendered_output from the main dewret CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
KamenDimitrov97 committed Sep 30, 2024
1 parent 3b9ccfe commit f519c47
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 24 deletions.
53 changes: 30 additions & 23 deletions src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@
import click
import json

from .core import set_configuration, set_render_configuration, RawRenderModule, StructuredRenderModule
from .core import (
set_configuration,
set_render_configuration,
RawRenderModule,
StructuredRenderModule,
)
from .utils import load_module_or_package
from .render import get_render_method
from .render import get_render_method, write_rendered_output
from .tasks import Backend, construct


Expand Down Expand Up @@ -72,7 +77,15 @@
@click.argument("task")
@click.argument("arguments", nargs=-1)
def render(
workflow_py: Path, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str
workflow_py: Path,
task: str,
arguments: list[str],
pretty: bool,
backend: Backend,
construct_args: str,
renderer: str,
renderer_args: str,
output: str,
) -> None:
"""Render a workflow.
Expand All @@ -91,7 +104,7 @@ def render(
kwargs[key] = json.loads(val)

render_module: Path | ModuleType
if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)):
if mtch := re.match(r"^([a-z_0-9-.]+)$", renderer):
render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}")
if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule):
raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.")
Expand All @@ -118,18 +131,22 @@ def render(
renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(","))

if output == "-":

@contextmanager
def _opener(key: str, _: str) -> Generator[IO[Any], None, None]:
print(" ------ ", key, " ------ ")
yield sys.stdout
print()

opener = _opener
else:

@contextmanager
def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
output_file = output.replace("%", key)
with Path(output_file).open(mode) as output_f:
yield output_f

opener = _opener

render = get_render_method(render_module, pretty=pretty)
Expand All @@ -138,30 +155,20 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
task_fn = getattr(workflow, task)

try:
with set_configuration(**construct_kwargs), set_render_configuration(renderer_kwargs):
rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs)
with (
set_configuration(**construct_kwargs),
set_render_configuration(renderer_kwargs),
):
rendered = render(
construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs
)
except Exception as exc:
import traceback

print(exc, exc.__cause__, exc.__context__)
traceback.print_exc()
else:
if len(rendered) == 1:
with opener("", "w") as output_f:
output_f.write(rendered["__root__"])
elif "%" in output:
for key, value in rendered.items():
if key == "__root__":
key = "ROOT"
with opener(key, "w") as output_f:
output_f.write(value)
else:
with opener("ROOT", "w") as output_f:
output_f.write(rendered["__root__"])
del rendered["__root__"]
for key, value in rendered.items():
with opener(key, "a") as output_f:
output_f.write("\n---\n")
output_f.write(value)
write_rendered_output(rendered, output, opener)


render()
27 changes: 26 additions & 1 deletion src/dewret/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import sys
from pathlib import Path
from functools import partial
from typing import TypeVar, Callable, cast
from typing import TypeVar, Callable, ContextManager, IO, Any, cast
import yaml

from .workflow import Workflow, NestedStep
Expand Down Expand Up @@ -100,6 +100,31 @@ def _render(
)


def write_rendered_output(
rendered: dict[str, str] | dict[str, RawType],
output: str,
opener: Callable[[str, str], ContextManager[IO[Any]]],
) -> None:
"""Utility function to handle writing rendered output to file or stdout."""
if len(rendered) == 1:
with opener("", "w") as output_f:
output_f.write(rendered["__root__"])
elif "%" in output:
for key, value in rendered.items():
if key == "__root__":
key = "ROOT"
with opener(key, "w") as output_f:
output_f.write(value)
else:
with opener("ROOT", "w") as output_f:
output_f.write(rendered["__root__"])
del rendered["__root__"]
for key, value in rendered.items():
with opener(key, "a") as output_f:
output_f.write("\n---\n")
output_f.write(value)


def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]:
"""Render to a dict-like structure.
Expand Down

0 comments on commit f519c47

Please sign in to comment.