Skip to content

Commit

Permalink
Refactor resolve_vars out of chat_executor
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Jan 27, 2024
1 parent fecdd4d commit ff0da8f
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def image_prompts(self):
return self.vars.get("image_prompts", [])

@property
def file_dir(self):
def dir(self):
return os.path.dirname(self.context_filepath)

def _initialize_conversation_data(self):
Expand Down
68 changes: 13 additions & 55 deletions src/executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os

import jinja2
import yaml

from ..resolve_vars import resolve_vars
from .executor import Executor


class ChatExecutor(Executor):
TEMPLATE_PATH = "src/executors/chat_executor_template.yml"

async def execute(self):
return await self._generate_chat_completion(
self._build_chat_messages(),
Expand All @@ -16,61 +17,18 @@ async def execute(self):
},
)

def _load_chat_template(self):
with open("src/executors/chat_template.yml") as file:
return jinja2.Template(file.read(), trim_blocks=True, lstrip_blocks=True)

def _build_chat_messages(self):
template = self._load_chat_template()
dereferenced_vars = self._dereference_vars(self.context.vars)
rendered_vars = self._render_vars(
vars = resolve_vars(
{
**dereferenced_vars,
"messages": self.context.current_messages,
**self.context.vars,
"facts": self.context.current_conversation_facts,
}
},
self.context.dir,
)
rendered_str = template.render(rendered_vars)
return yaml.safe_load(rendered_str)

def _dereference_vars(self, vars):
dereferenced_vars = vars.copy()
for key, value in vars.items():
if isinstance(value, str) and value.startswith("file:"):
file_path = value.split("file:", 1)[1]
full_path = os.path.join(self.context.file_dir, file_path)
with open(full_path) as file:
dereferenced_vars[key] = file.read()
return dereferenced_vars

def _render_vars(self, vars):
MAX_ITERATIONS = 10

for _ in range(MAX_ITERATIONS):
rendered_vars = self._render_vars_recursive(vars)
if rendered_vars == vars:
break
vars = rendered_vars
else:
raise RuntimeError(
"Too many iterations resolving vars. Circular reference?"
)
return rendered_vars

def _render_vars_recursive(self, vars, obj=None):
if obj is None:
obj = vars

if isinstance(obj, dict):
return {
key: self._render_vars_recursive(vars, value)
for key, value in obj.items()
}
elif isinstance(obj, list):
return [self._render_vars_recursive(vars, item) for item in obj]
elif isinstance(obj, str):
return jinja2.Template(obj, trim_blocks=True, lstrip_blocks=True).render(
vars
vars["messages"] = self.context.current_messages
with open(self.TEMPLATE_PATH) as file:
template = jinja2.Template(
file.read(), trim_blocks=True, lstrip_blocks=True
)
else:
return obj
rendered_str = template.render(vars)
return yaml.safe_load(rendered_str)
File renamed without changes.
46 changes: 46 additions & 0 deletions src/resolve_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os

import jinja2


def resolve_vars(vars, base_path):
vars = _dereference_vars(vars, base_path)
return _render_vars(vars)


def _dereference_vars(vars, base_path):
dereferenced_vars = vars.copy()
for key, value in vars.items():
if isinstance(value, str) and value.startswith("file:"):
file_path = value.split("file:", 1)[1]
full_path = os.path.join(base_path, file_path)
with open(full_path) as file:
dereferenced_vars[key] = file.read()
return dereferenced_vars


def _render_vars(vars):
MAX_ITERATIONS = 10

for _ in range(MAX_ITERATIONS):
rendered_vars = _render_vars_recursive(vars)
if rendered_vars == vars:
break
vars = rendered_vars
else:
raise RuntimeError("Too many iterations resolving vars. Circular reference?")
return rendered_vars


def _render_vars_recursive(vars, obj=None):
if obj is None:
obj = vars

if isinstance(obj, dict):
return {key: _render_vars_recursive(vars, value) for key, value in obj.items()}
elif isinstance(obj, list):
return [_render_vars_recursive(vars, item) for item in obj]
elif isinstance(obj, str):
return jinja2.Template(obj, trim_blocks=True, lstrip_blocks=True).render(vars)
else:
return obj
53 changes: 11 additions & 42 deletions tests/test_chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from textwrap import dedent

import pytest

from src.executors.chat_executor import ChatExecutor
Expand All @@ -20,24 +18,11 @@ def current_messages():
@pytest.fixture
def vars():
return {
"list_of_things": ["thing1", "thing2"],
"chat_prompt": "file:chat_prompt.j2",
"chat_prompt": "You are Alice, speaking to Bob.",
"reinforcement_chat_prompt": "Remember, you are Alice.",
}


@pytest.fixture
def chat_prompt_file_content():
return dedent(
"""
You are Alice, speaking to Bob.
{% for thing in list_of_things %}
- {{ thing }}
{% endfor %}
"""
).strip()


@pytest.fixture
def mock_context(mocker, current_messages, vars):
mock_instance = mocker.patch("src.context.Context").return_value
Expand All @@ -47,19 +32,6 @@ def mock_context(mocker, current_messages, vars):
return mock_instance


@pytest.fixture
def mock_read_chat_prompt(mocker, chat_prompt_file_content):
original_open = open

def mock_open(file, mode="r", *args, **kwargs):
if file.endswith("chat_prompt.j2"):
return mocker.mock_open(read_data=chat_prompt_file_content).return_value
else:
return original_open(file, mode, *args, **kwargs)

mocker.patch("builtins.open", mock_open)


@pytest.fixture
def mock_generate_chat_completion(mocker):
completion_mock = mocker.Mock()
Expand All @@ -69,23 +41,20 @@ def mock_generate_chat_completion(mocker):
return async_mock


@pytest.fixture
def mock_resolve_vars(mocker):
return mocker.patch(
"src.executors.chat_executor.resolve_vars", side_effect=lambda vars, _: vars
)


@pytest.mark.asyncio
async def test_execute(
mock_context, mock_generate_chat_completion, mock_read_chat_prompt
):
async def test_execute(mock_context, mock_generate_chat_completion, mock_resolve_vars):
await ChatExecutor(mock_context).execute()
mock_resolve_vars.assert_called_once()
mock_generate_chat_completion.assert_called_once_with(
[
{
"role": "system",
"content": dedent(
"""
You are Alice, speaking to Bob.
- thing1
- thing2
"""
).strip(),
},
{"role": "system", "content": "You are Alice, speaking to Bob."},
{"role": "assistant", "content": "Hello, Bob."},
{
"role": "user",
Expand Down
45 changes: 45 additions & 0 deletions tests/test_resolve_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from textwrap import dedent

import pytest

from src.resolve_vars import resolve_vars


@pytest.fixture
def vars():
return {
"list_of_things": ["thing1", "{{ jinja_thing }}"],
"chat_prompt": "file:chat_prompt.j2",
"reinforcement_chat_prompt": "Remember, you are Alice.",
"jinja_thing": "thing2",
}


@pytest.fixture
def chat_prompt_file_content():
return dedent(
"""
You are Alice, speaking to Bob.
{% for thing in list_of_things %}
- {{ thing }}
{% endfor %}
"""
).strip()


@pytest.fixture
def mock_read_chat_prompt(mocker, chat_prompt_file_content):
mock_open = mocker.mock_open(read_data=chat_prompt_file_content)
return mocker.patch("builtins.open", mock_open)


def test_resolve(mock_read_chat_prompt, vars):
resolved_vars = resolve_vars(vars, "base_path")

mock_read_chat_prompt.assert_called_once_with("base_path/chat_prompt.j2")
assert resolved_vars["list_of_things"] == ["thing1", "thing2"]
assert resolved_vars["reinforcement_chat_prompt"] == "Remember, you are Alice."
assert (
resolved_vars["chat_prompt"]
== "You are Alice, speaking to Bob.\n - thing1\n - thing2"
)

0 comments on commit ff0da8f

Please sign in to comment.