-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor resolve_vars out of chat_executor
- Loading branch information
Showing
6 changed files
with
116 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
|
||
import jinja2 | ||
|
||
|
||
def resolve_vars(vars, base_path): | ||
vars = _dereference_vars(vars, base_path) | ||
return _render_vars(vars) | ||
|
||
|
||
def _dereference_vars(vars, base_path): | ||
dereferenced_vars = vars.copy() | ||
for key, value in vars.items(): | ||
if isinstance(value, str) and value.startswith("file:"): | ||
file_path = value.split("file:", 1)[1] | ||
full_path = os.path.join(base_path, file_path) | ||
with open(full_path) as file: | ||
dereferenced_vars[key] = file.read() | ||
return dereferenced_vars | ||
|
||
|
||
def _render_vars(vars): | ||
MAX_ITERATIONS = 10 | ||
|
||
for _ in range(MAX_ITERATIONS): | ||
rendered_vars = _render_vars_recursive(vars) | ||
if rendered_vars == vars: | ||
break | ||
vars = rendered_vars | ||
else: | ||
raise RuntimeError("Too many iterations resolving vars. Circular reference?") | ||
return rendered_vars | ||
|
||
|
||
def _render_vars_recursive(vars, obj=None): | ||
if obj is None: | ||
obj = vars | ||
|
||
if isinstance(obj, dict): | ||
return {key: _render_vars_recursive(vars, value) for key, value in obj.items()} | ||
elif isinstance(obj, list): | ||
return [_render_vars_recursive(vars, item) for item in obj] | ||
elif isinstance(obj, str): | ||
return jinja2.Template(obj, trim_blocks=True, lstrip_blocks=True).render(vars) | ||
else: | ||
return obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from textwrap import dedent | ||
|
||
import pytest | ||
|
||
from src.resolve_vars import resolve_vars | ||
|
||
|
||
@pytest.fixture | ||
def vars(): | ||
return { | ||
"list_of_things": ["thing1", "{{ jinja_thing }}"], | ||
"chat_prompt": "file:chat_prompt.j2", | ||
"reinforcement_chat_prompt": "Remember, you are Alice.", | ||
"jinja_thing": "thing2", | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def chat_prompt_file_content(): | ||
return dedent( | ||
""" | ||
You are Alice, speaking to Bob. | ||
{% for thing in list_of_things %} | ||
- {{ thing }} | ||
{% endfor %} | ||
""" | ||
).strip() | ||
|
||
|
||
@pytest.fixture | ||
def mock_read_chat_prompt(mocker, chat_prompt_file_content): | ||
mock_open = mocker.mock_open(read_data=chat_prompt_file_content) | ||
return mocker.patch("builtins.open", mock_open) | ||
|
||
|
||
def test_resolve(mock_read_chat_prompt, vars): | ||
resolved_vars = resolve_vars(vars, "base_path") | ||
|
||
mock_read_chat_prompt.assert_called_once_with("base_path/chat_prompt.j2") | ||
assert resolved_vars["list_of_things"] == ["thing1", "thing2"] | ||
assert resolved_vars["reinforcement_chat_prompt"] == "Remember, you are Alice." | ||
assert ( | ||
resolved_vars["chat_prompt"] | ||
== "You are Alice, speaking to Bob.\n - thing1\n - thing2" | ||
) |