Add converse #31

Merged · 6 commits · Feb 5, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion src/chatdbg/chatdbg.py
@@ -56,4 +56,4 @@ def main():
sys.exit(2)

pdb.Pdb = ChatDBG
-    pdb.main()
+    pdb.main()
77 changes: 73 additions & 4 deletions src/chatdbg/chatdbg_lldb.py
@@ -7,9 +7,14 @@
import json

import llm_utils
import subprocess
import openai


sys.path.append(os.path.abspath(pathlib.Path(__file__).parent.resolve()))
import chatdbg_utils
import conversation


# The file produced by the panic handler if the Rust program is using the chatdbg crate.
rust_panic_log_filename = "panic_log.txt"
@@ -295,7 +300,7 @@ def print_test(
addresses = {}
for var in frame.get_all_variables():
        # Returns a Python dictionary for each variable, then converts it to JSON
-        variable = helper_result(
+        variable = _val_to_json(
debugger, command, result, internal_dict, var, recurse_max, addresses
)
js = json.dumps(variable, indent=4)
@@ -308,7 +313,7 @@ def helper_result(
return


-def helper_result(
+def _val_to_json(
debugger: lldb.SBDebugger,
command: str,
result: str,
@@ -339,7 +344,7 @@ def helper_result(
value += "->"
deref_val = deref_val.Dereference()
elif len(deref_val.GetType().get_fields_array()) > 0:
-        value = helper_result(
+        value = _val_to_json(
debugger,
command,
result,
@@ -368,7 +373,7 @@ def helper_result(
for i in range(var.GetNumChildren()):
f = var.GetChildAtIndex(i)
fields.append(
-            helper_result(
+            _val_to_json(
debugger,
command,
result,
@@ -382,3 +387,67 @@
else:
json["value"] = str(var)[str(var).find("= ") + 2 :]
return json


_DEFAULT_FALLBACK_MODELS = ["gpt-4", "gpt-3.5-turbo"]


@lldb.command("converse")
def converse(
debugger: lldb.SBDebugger,
command: str,
result: str,
internal_dict: dict,
) -> None:
# Perform typical "why" checks
# Check if there is debug info.
if not is_debug_build(debugger, command, result, internal_dict):
print(
"Your program must be compiled with debug information (`-g`) to use `converse`."
)
return
# Check if program is attached to a debugger.
if not get_target():
print("Must be attached to a program to ask `converse`.")
return
    # Check if code has been run before executing the `converse` command.
thread = get_thread()
if not thread:
print("Must run the code first to ask `converse`.")
return
# Check why code stopped running.
if thread.GetStopReason() == lldb.eStopReasonBreakpoint:
# Check if execution stopped at a breakpoint or an error.
print("Execution stopped at a breakpoint, not an error.")
return

args = chatdbg_utils.use_argparse(command.split())
print(args)

try:
client = openai.OpenAI(timeout=args.timeout)
except openai.OpenAIError:
print("You need an OpenAI key to use this tool.")
print("You can get a key here: https://platform.openai.com/api-keys")
print("Set the environment variable OPENAI_API_KEY to your key value.")
sys.exit(1)

the_prompt = buildPrompt(debugger)

# print("Prompt 1:", the_prompt[0])
# print("Prompt 2:", the_prompt[1])
# print("Prompt 3:", the_prompt[2])

if args.show_prompt:
print("===================== Prompt =====================")
conversation.converse(client, args, the_prompt[1])
print("==================================================")
sys.exit(0)

print("==================================================")
print("ChatDBG")
print("==================================================")
print(conversation.converse(client, args, the_prompt[1]))
print("==================================================")

sys.exit(0)
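
A note on argument handling: lldb passes everything after the command name to `converse` as a single string, so `command.split()` feeds the raw flags straight into the new argparse front end. A minimal sketch of that hand-off (the flag values here are illustrative):

    import chatdbg_utils

    # Typing "converse --llm gpt-4 -p" at the (lldb) prompt arrives as:
    command = "--llm gpt-4 -p"
    args = chatdbg_utils.use_argparse(command.split())
    assert args.llm == "gpt-4" and args.show_prompt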
108 changes: 106 additions & 2 deletions src/chatdbg/chatdbg_utils.py
@@ -5,12 +5,116 @@

import llm_utils

import argparse
from typing import Any, Optional
from rich.console import Console

class RichArgParser(argparse.ArgumentParser):
def __init__(self, *args: Any, **kwargs: Any):
self.console = Console(highlight=False)
super().__init__(*args, **kwargs)

def _print_message(self, message: Optional[str], file: Any = None) -> None:
if message:
self.console.print(message)

class ChatDBGArgumentFormatter(argparse.HelpFormatter):
# RawDescriptionHelpFormatter.
def _fill_text(self, text, width, indent):
return "".join(indent + line for line in text.splitlines(keepends=True))

# RawTextHelpFormatter.
def _split_lines(self, text, width):
return text.splitlines()

# ArgumentDefaultsHelpFormatter.
# Ignore if help message is multiline.
    def _get_help_string(self, action):
        help_text = action.help
        if "\n" not in help_text and "%(default)" not in help_text:
            if action.default is not argparse.SUPPRESS:
                defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]
                if action.option_strings or action.nargs in defaulting_nargs:
                    help_text += " (default: %(default)s)"
        return help_text
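
The formatter's default-appending rule, illustrated with a hypothetical option built on the classes above (single-line help strings gain a "(default: ...)" suffix; multiline help is printed verbatim):

    parser = RichArgParser(prog="demo", usage=argparse.SUPPRESS,
                           formatter_class=ChatDBGArgumentFormatter)
    parser.add_argument("--retries", type=int, default=3, help="number of retries")
    parser.print_help()  # ... --retries RETRIES  number of retries (default: 3)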

def use_argparse(full_command):
description = textwrap.dedent(
rf"""
[b]ChatDBG[/b]: A Python debugger that uses AI to tell you `why`.
[blue][link=https://github.com/plasma-umass/ChatDBG]https://github.com/plasma-umass/ChatDBG[/link][/blue]

usage:
[b]chatdbg [-c command] ... [-m module | pyfile] [arg] ...[/b]

Debug the Python program given by pyfile. Alternatively,
an executable module or package to debug can be specified using
the -m switch.

Initial commands are read from .pdbrc files in your home directory
and in the current directory, if they exist. Commands supplied with
-c are executed after commands from .pdbrc files.

To let the script run until an exception occurs, use "-c continue".
You can then type `why` to get an explanation of the root cause of
the exception, along with a suggested fix. NOTE: you must have an
OpenAI key saved as the environment variable OPENAI_API_KEY.
You can get a key here: https://openai.com/api/

To let the script run up to a given line X in the debugged file, use
"-c 'until X'".
"""
).strip()
parser = RichArgParser(
prog="chatdbg",
usage=argparse.SUPPRESS,
description=description,
formatter_class=ChatDBGArgumentFormatter
)
parser.add_argument(
"--llm",
type=str,
default="gpt-4-turbo-preview",
help=textwrap.dedent(
"""
the language model to use, e.g., 'gpt-3.5-turbo' or 'gpt-4'
            the default tries gpt-4-turbo-preview and falls back to gpt-4
"""
).strip(),
)
parser.add_argument(
"-p",
"--show-prompt",
action="store_true",
help="when enabled, only print prompt and exit (for debugging purposes)",
)
parser.add_argument(
"--timeout",
type=int,
default=60,
help="the timeout for API calls in seconds",
)
parser.add_argument(
"--max-error-tokens",
type=int,
default=1920,
help="the maximum number of tokens from the error message to send in the prompt",
)
parser.add_argument(
"--max-code-tokens",
type=int,
default=1920,
help="the maximum number of code locations tokens to send in the prompt",
)

args = parser.parse_args(full_command)
return args

def get_model() -> str:
all_models = ["gpt-4", "gpt-3.5-turbo"]
all_models = ["gpt-4-turbo-preview", "gpt-4", "gpt-3.5-turbo"]

if not "OPENAI_API_MODEL" in os.environ:
model = "gpt-4"
model = "gpt-4-turbo-preview"
else:
model = os.environ["OPENAI_API_MODEL"]
if model not in all_models:
[...]
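
The intended effect of the model-selection change, sketched below; this assumes the elided tail of get_model returns the override when it is one of the known models:

    import os

    os.environ.pop("OPENAI_API_MODEL", None)
    assert get_model() == "gpt-4-turbo-preview"  # new default

    os.environ["OPENAI_API_MODEL"] = "gpt-3.5-turbo"
    assert get_model() == "gpt-3.5-turbo"  # recognized override wins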
82 changes: 82 additions & 0 deletions src/chatdbg/conversation/__init__.py
@@ -0,0 +1,82 @@
import textwrap

import llm_utils

from . import functions


def get_truncated_error_message(args, diagnostic) -> str:
"""
    Alternately take lines from the front and the back until the token budget is reached.
"""
front: list[str] = []
back: list[str] = []
diagnostic_lines = diagnostic.splitlines()
n = len(diagnostic_lines)

def build_diagnostic_string():
return "\n".join(front) + "\n\n[...]\n\n" + "\n".join(reversed(back)) + "\n"

    for i in range(n):
        if i % 2 == 0:
            line = diagnostic_lines[i // 2]
            target = front
        else:
            line = diagnostic_lines[n - i // 2 - 1]
            target = back
        target.append(line)
        count = llm_utils.count_tokens(args.llm, build_diagnostic_string())
        if count > args.max_error_tokens:
            target.pop()
            break
return build_diagnostic_string()
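
To make the interleaving concrete, here is a self-contained sketch of the same front/back alternation, with a simple line budget standing in for llm_utils.count_tokens (names here are illustrative):

    def truncate_alternating(lines, max_lines):
        # Even steps take the next line from the front, odd steps from the
        # back, until the budget is exceeded; the last addition is undone.
        front, back = [], []
        for i in range(len(lines)):
            if i % 2 == 0:
                target, line = front, lines[i // 2]
            else:
                target, line = back, lines[len(lines) - i // 2 - 1]
            target.append(line)
            if len(front) + len(back) > max_lines:
                target.pop()
                break
        return "\n".join(front) + "\n\n[...]\n\n" + "\n".join(reversed(back))

    print(truncate_alternating([f"line {i}" for i in range(10)], 4))
    # line 0
    # line 1
    #
    # [...]
    #
    # line 8
    # line 9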


def converse(client, args, diagnostic):
fns = functions.Functions(args)
available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()]
system_message = textwrap.dedent(
f"""
You are an assistant debugger. The user is having an issue with their code, and you are trying to help them.
A few functions exist to help with this process, namely: {", ".join(available_functions_names)}.
Don't hesitate to call as many functions as needed to give the best possible answer.
Once you have identified the problem, explain the diagnostic and provide a way to fix the issue if you can.
"""
).strip()
user_message = f"Here is my error message:\n\n```\n{get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?"
conversation = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]

if args.show_prompt:
print("System message:", system_message)
print("User message:", user_message)
return

while True:
completion = client.chat.completions.create(
model=args.llm,
messages=conversation,
tools=fns.as_tools(),
)

choice = completion.choices[0]
if choice.finish_reason == "tool_calls":
for tool_call in choice.message.tool_calls:
function_response = fns.dispatch(tool_call.function)
if function_response:
conversation.append(choice.message)
conversation.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": function_response,
}
)
elif choice.finish_reason == "stop":
text = completion.choices[0].message.content
return llm_utils.word_wrap_except_code_blocks(text)
else:
print(f"Not found: {choice.finish_reason}.")
56 changes: 56 additions & 0 deletions src/chatdbg/conversation/functions.py
@@ -0,0 +1,56 @@
import json
import os
from typing import Optional

import llm_utils


class Functions:
def __init__(self, args):
self.args = args

def as_tools(self):
return [
{"type": "function", "function": schema}
for schema in [self.get_code_surrounding_schema()]
]

def dispatch(self, function_call) -> Optional[str]:
arguments = json.loads(function_call.arguments)
print(
f"Calling: {function_call.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})"
)
try:
if function_call.name == "get_code_surrounding":
return self.get_code_surrounding(
arguments["filename"], arguments["lineno"]
)
else:
raise ValueError("No such function.")
except Exception as e:
print(e)
return None

def get_code_surrounding_schema(self):
return {
"name": "get_code_surrounding",
"description": "Returns the code in the given file surrounding and including the provided line number.",
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The filename to read from.",
},
"lineno": {
"type": "integer",
"description": "The line number to focus on. Some context before and after that line will be provided.",
},
},
"required": ["filename", "lineno"],
},
}

def get_code_surrounding(self, filename: str, lineno: int) -> str:
(lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
return llm_utils.number_group_of_lines(lines, first)
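
A quick sketch of how dispatch consumes a model tool call; the SimpleNamespace stands in for the OpenAI function-call object, which exposes .name and .arguments (a JSON string):

    from types import SimpleNamespace

    fns = Functions(args)  # args: the namespace produced by use_argparse
    call = SimpleNamespace(
        name="get_code_surrounding",
        arguments='{"filename": "src/chatdbg/chatdbg.py", "lineno": 56}',
    )
    print(fns.dispatch(call))  # prints lines 49-59 of the file, numbered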
Binary file added test/a.out