WIP: Introduce a new ell.interactive API and example. #254

Draft · wants to merge 7 commits into main
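This PR adds an `ell.interactive` context manager for append-mode chat sessions that resolve tool calls automatically, plus an example that iteratively generates and dry-run-validates unified diffs. A minimal usage sketch, assuming only the API introduced in this diff (`set_system_prompt`, `send`, and keyword arguments forwarded through to `ell.complex`):

import ell

ell.init(verbose=True)

with ell.interactive(model="claude-3-5-sonnet-20240620", max_tokens=512) as session:
    # Configure the conversation without making a request.
    session.set_system_prompt("You are a terse assistant.")
    # Each send() appends to the same running conversation; any tool calls the
    # model makes are resolved and fed back automatically before it returns.
    session.send("Summarize what a unified diff is in one sentence.")
    session.send("Now show a two-line example of one.")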
97 changes: 97 additions & 0 deletions examples/interactive_tool_diff.py
@@ -0,0 +1,97 @@
import logging
import subprocess
from pathlib import Path
from textwrap import dedent
from typing import List

import anthropic
import ell
from pydantic import Field


logger = logging.getLogger(__name__)


def _validate_diff(diff: str) -> subprocess.CompletedProcess:
    logger.info(f"Validating diff: {diff}")
    return subprocess.run(
        ["patch", "-p1", "--dry-run"],
        input=diff.encode("utf-8"),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=False,
    )


@ell.tool()
def test_diff(
    diff: str = Field(description="The unified diff to test."),
) -> str:
    """Tests application of a unified diff to the local workspace using `patch -p1 --dry-run` and returns a natural language result."""
    logger.info("Tool call: test_diff")
    result = _validate_diff(diff)
    # TODO(kwlzn): Can we send a structured output to the LLM with e.g. tool_call_result.exit_code and stdout/stderr for it to natively interpret?
    if result.returncode == 0:
        logger.info("Tool call: test_diff succeeded")
        return f"Patch applied successfully: {result.stdout.decode()}"
    else:
        logger.warning("Tool call: test_diff failed")
        # Provide context to the LLM on the failure by proxying the output of `patch -p1`.
        return f"That patch is invalid, can you try again with the correct diff syntax? Here's the output of `patch -p1`:\n{result.stdout.decode()}"


def diff_loop(prompts: List[str], glob: str, repo: str = ".", max_loops: int = 3):
    repo_path = Path(repo)
    # Read the file before relativizing its path, so this also works when repo != cwd.
    code_path = next(repo_path.glob(glob))
    code_file = code_path.relative_to(repo_path)
    code = f"<file:{code_file}>\n{code_path.read_text()}\n</file:{code_file}>"

    client = anthropic.Anthropic()

    system_prompt = dedent("""\
        You are a helpful, expert-level programmer that generates Python code changes to an existing codebase given a request.
        Your changes will be written to the filesystem using relative paths. You are in the root directory of the repository.
        Test application of the changes by calling the `test_diff` tool with a valid unified diff (like `diff` or `git diff` would generate).
        This will store the patch, but won't apply it to the local filesystem - so always generate a completely new patch for every request.
        Use chain-of-thought reasoning to generate the code and explain your work in your response.
    """)

    with ell.interactive(
        model="claude-3-5-sonnet-20240620",
        client=client,
        tools=[test_diff],
        max_tokens=1024,
        temperature=0.7,
    ) as session:
        # Set the system prompt without making a request.
        session.set_system_prompt(system_prompt)

        for i, prompt in enumerate(prompts):
            # Send the code context with the first message, but not subsequent ones.
            if i == 0:
                prompt = f"{code}\n\n{prompt}"
            session.send(prompt)


def main():
    logging.basicConfig(
        format='[%(asctime)s %(levelname)-8s] %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    ell.init(verbose=True, store="./ell_logs")

    diff_loop(
        prompts=[
            "Add a simple argument parsing routine to interactive_tool_diff.py that provides a --help argument.",
            "Now extend the argument parsing so the user can specify a model name that will be printed when the file is invoked. Make it default to gpt4o-mini.",
            "Now modify the diff_loop function in interactive_tool_diff.py to accept a model parameter that is passed via this CLI arg.",
            "Now make the default argument for the model name be: claude-3-5-sonnet-20240620."
        ],
        glob="**/interactive_tool_diff.py",
        max_loops=3
    )


if __name__ == "__main__":
    main()
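The TODO in `test_diff` above asks whether a structured result (exit code plus stdout/stderr) could be sent to the LLM instead of a natural-language sentence. One hypothetical shape for that, sketched with pydantic and reusing the example's `_validate_diff` helper; nothing in this diff implements it, and whether ell tools may return models rather than strings is left open:

from pydantic import BaseModel


class PatchResult(BaseModel):
    """Hypothetical structured tool result, per the TODO in test_diff."""
    exit_code: int
    output: str


def test_diff_structured(diff: str) -> PatchResult:
    # Same dry-run validation as test_diff, but preserving the raw fields
    # instead of flattening them into prose.
    result = _validate_diff(diff)
    return PatchResult(exit_code=result.returncode, output=result.stdout.decode())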
1 change: 1 addition & 0 deletions src/ell/__init__.py
@@ -7,6 +7,7 @@
from ell.lmp.simple import simple
from ell.lmp.tool import tool
from ell.lmp.complex import complex
from ell.lmp.interactive import interactive
from ell.types.message import system, user, assistant, Message, ContentBlock
from ell.__version__ import __version__

42 changes: 42 additions & 0 deletions src/ell/lmp/interactive.py
@@ -0,0 +1,42 @@
from contextlib import contextmanager

from .complex import complex as ell_complex
from ..types.message import system as ell_system, user as ell_user


@contextmanager
def interactive(*args, **kwargs):
    """A contextmanager that creates an interactive, append-mode session using an inline LMP function."""

    # TODO(kwlzn): Should this be specified/impl'd a different way for better viz/tracking in ell studio?
    @ell_complex(*args, **kwargs)
    def interactive(messages):
        return messages

    class _InteractiveSession:
        def __init__(self):
            self._system_prompt = None
            self._messages = []

        def set_system_prompt(self, prompt):
            self._system_prompt = ell_system(prompt)

        def send(self, message=None):
            if message:
                self._messages.append(ell_user(message))

            # Invoke the LMP function, omitting the system prompt if one was never set.
            system = [self._system_prompt] if self._system_prompt else []
            response = interactive(system + self._messages)

            # Append the role="assistant" response to the messages.
            self._messages.append(response)

            # If we have tool calls, invoke them, append the tool call result as a user message and send it back to the LLM.
            if response.tool_calls:
                tool_call_message = response.call_tools_and_collect_as_message()
                self._messages.append(tool_call_message)
                return self.send()

            return response

    yield _InteractiveSession()
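`send()` above recurses once per round of tool calls, so a model that chains many tool invocations deepens the Python call stack each time. An equivalent iterative form of `_InteractiveSession.send`, shown as a sketch rather than as part of this diff:

def send(self, message=None):
    if message:
        self._messages.append(ell_user(message))
    while True:
        system = [self._system_prompt] if self._system_prompt else []
        response = interactive(system + self._messages)
        self._messages.append(response)
        if not response.tool_calls:
            return response
        # Feed the tool results back and loop for the model's next turn.
        self._messages.append(response.call_tools_and_collect_as_message())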