Skip to content

Commit

Permalink
Improve resolving template vars
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Dec 24, 2024
1 parent ecebb9c commit d1a4fc0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 47 deletions.
43 changes: 32 additions & 11 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os

import jinja2
import yaml

from ..api_client import OpenRouterAPIClient
from ..resolve_vars import resolve_vars


class ChatExecutor:
Expand All @@ -28,23 +29,43 @@ async def execute(self):
return completion

def _build_messages(self):
vars = resolve_vars(
{
**self.context.vars,
**self._extra_template_vars(),
},
self.context.dir,
)
vars["messages"] = self.context.conversation_messages
resolved_vars = self._resolve_vars()
resolved_vars["messages"] = self.context.conversation_messages
with open(self.TEMPLATE_PATH) as file:
template = jinja2.Template(
file.read(), trim_blocks=True, lstrip_blocks=True
)
rendered_str = template.render(vars)
rendered_str = template.render(resolved_vars)
return yaml.safe_load(rendered_str)

def _extra_template_vars(self):
def _resolve_vars(self):
template_vars = self._template_vars()
env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True, autoescape=False)

def load_file(filepath):
full_path = os.path.abspath(os.path.join(self.context.dir, filepath))
with open(full_path) as f:
content = f.read()
template = env.from_string(content)
return template.render(**template_vars)

env.globals["load"] = load_file

def resolve_recursive(obj):
if isinstance(obj, dict):
return {k: resolve_recursive(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [resolve_recursive(i) for i in obj]
elif isinstance(obj, str):
template = env.from_string(obj)
return template.render(**template_vars)
return obj

return resolve_recursive(template_vars)

def _template_vars(self):
return {
**self.context.vars,
"facts": self.context.conversation_facts,
"model": self.context.model,
}
36 changes: 0 additions & 36 deletions src/resolve_vars.py

This file was deleted.

0 comments on commit d1a4fc0

Please sign in to comment.