From d1a4fc01c31bca99ddcf756995423eaf0bdd21f1 Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Tue, 24 Dec 2024 14:31:00 -0800 Subject: [PATCH] Improve resolving template vars --- src/lm_executors/chat_executor.py | 43 +++++++++++++++++++++++-------- src/resolve_vars.py | 36 -------------------------- 2 files changed, 32 insertions(+), 47 deletions(-) delete mode 100644 src/resolve_vars.py diff --git a/src/lm_executors/chat_executor.py b/src/lm_executors/chat_executor.py index ca16759..8f9d77f 100644 --- a/src/lm_executors/chat_executor.py +++ b/src/lm_executors/chat_executor.py @@ -1,8 +1,9 @@ +import os + import jinja2 import yaml from ..api_client import OpenRouterAPIClient -from ..resolve_vars import resolve_vars class ChatExecutor: @@ -28,23 +29,43 @@ async def execute(self): return completion def _build_messages(self): - vars = resolve_vars( - { - **self.context.vars, - **self._extra_template_vars(), - }, - self.context.dir, - ) - vars["messages"] = self.context.conversation_messages + resolved_vars = self._resolve_vars() + resolved_vars["messages"] = self.context.conversation_messages with open(self.TEMPLATE_PATH) as file: template = jinja2.Template( file.read(), trim_blocks=True, lstrip_blocks=True ) - rendered_str = template.render(vars) + rendered_str = template.render(resolved_vars) return yaml.safe_load(rendered_str) - def _extra_template_vars(self): + def _resolve_vars(self): + template_vars = self._template_vars() + env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True, autoescape=False) + + def load_file(filepath): + full_path = os.path.abspath(os.path.join(self.context.dir, filepath)) + with open(full_path) as f: + content = f.read() + template = env.from_string(content) + return template.render(**template_vars) + + env.globals["load"] = load_file + + def resolve_recursive(obj): + if isinstance(obj, dict): + return {k: resolve_recursive(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve_recursive(i) for i in obj] + elif isinstance(obj, str): + template = env.from_string(obj) + return template.render(**template_vars) + return obj + + return resolve_recursive(template_vars) + + def _template_vars(self): return { + **self.context.vars, "facts": self.context.conversation_facts, "model": self.context.model, } diff --git a/src/resolve_vars.py b/src/resolve_vars.py deleted file mode 100644 index 2612a8e..0000000 --- a/src/resolve_vars.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -import jinja2 - - -def resolve_vars(vars, base_path): - MAX_ITERATIONS = 10 - - for _ in range(MAX_ITERATIONS): - resolved_vars = _resolve_vars_recursive(vars, base_path, vars) - if resolved_vars == vars: - return resolved_vars - vars = resolved_vars - - raise RuntimeError("Too many iterations resolving vars. Circular reference?") - - -def _resolve_vars_recursive(obj, base_path, context): - if isinstance(obj, dict): - return { - key: _resolve_vars_recursive(value, base_path, context) - for key, value in obj.items() - } - elif isinstance(obj, list): - return [_resolve_vars_recursive(item, base_path, context) for item in obj] - elif isinstance(obj, str): - if obj.startswith("file:"): - file_path = obj.split("file:", 1)[1] - full_path = os.path.join(base_path, file_path) - with open(full_path) as file: - return file.read() - return jinja2.Template(obj, trim_blocks=True, lstrip_blocks=True).render( - context - ) - else: - return obj