diff --git a/pyproject.toml b/pyproject.toml index dee6cea..c9c27ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,14 +8,14 @@ version = "0.2.2" authors = [ { name="Emery Berger", email="emery.berger@gmail.com" }, ] -dependencies = ["llm_utils==0.2.4", "openai>=1.6.1"] +dependencies = ["llm_utils>=0.2.6", "openai>=1.6.1", "rich>=13.7.0"] description = "AI-assisted debugging. Uses AI to answer 'why'." readme = "README.md" requires-python = ">=3.7" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", ] [project.scripts] diff --git a/src/chatdbg/chatdbg.py b/src/chatdbg/chatdbg.py index 120b295..a507867 100644 --- a/src/chatdbg/chatdbg.py +++ b/src/chatdbg/chatdbg.py @@ -56,4 +56,4 @@ def main(): sys.exit(2) pdb.Pdb = ChatDBG - pdb.main() + pdb.main() \ No newline at end of file diff --git a/src/chatdbg/chatdbg_lldb.py b/src/chatdbg/chatdbg_lldb.py index 5275a4c..d014d24 100644 --- a/src/chatdbg/chatdbg_lldb.py +++ b/src/chatdbg/chatdbg_lldb.py @@ -7,9 +7,14 @@ import json import llm_utils +import subprocess +import openai + sys.path.append(os.path.abspath(pathlib.Path(__file__).parent.resolve())) import chatdbg_utils +import conversation + # The file produced by the panic handler if the Rust program is using the chatdbg crate. 
@lldb.command("converse")
def converse(
    debugger: lldb.SBDebugger,
    command: str,
    result: str,
    internal_dict: dict,
) -> None:
    """LLDB `converse` command: discuss the current failure with the language model.

    The command first verifies that the session is in a state worth explaining
    (debug build, attached target, program has run, stop was not a breakpoint),
    then parses `command` as command-line options, builds the prompt for the
    current debugger state and hands it to `conversation.converse`.

    Args:
        debugger: The debugger instance the command was invoked in.
        command: The raw argument string typed after `converse`.
        result: Passed through to `is_debug_build`; not used otherwise.
        internal_dict: Passed through to `is_debug_build`; not used otherwise.

    Every exit path returns to the debugger prompt. The function never calls
    `sys.exit()`: it runs inside the debugger's own Python interpreter, so a
    SystemExit would escape the command instead of merely ending it.
    """
    # Check if there is debug info.
    if not is_debug_build(debugger, command, result, internal_dict):
        print(
            "Your program must be compiled with debug information (`-g`) to use `converse`."
        )
        return
    # Check if program is attached to a debugger.
    if not get_target():
        print("Must be attached to a program to ask `converse`.")
        return
    # Check if code has been run before executing the `converse` command.
    thread = get_thread()
    if not thread:
        print("Must run the code first to ask `converse`.")
        return
    # Only explain errors: a stop at a breakpoint is not a failure.
    if thread.GetStopReason() == lldb.eStopReasonBreakpoint:
        print("Execution stopped at a breakpoint, not an error.")
        return

    try:
        args = chatdbg_utils.use_argparse(command.split())
    except SystemExit:
        # argparse exits after printing usage/help or an argument error;
        # the message has already been shown, so just end the command.
        return

    try:
        client = openai.OpenAI(timeout=args.timeout)
    except openai.OpenAIError:
        print("You need an OpenAI key to use this tool.")
        print("You can get a key here: https://platform.openai.com/api-keys")
        print("Set the environment variable OPENAI_API_KEY to your key value.")
        return

    # Only the second element of the prompt tuple is sent as the diagnostic.
    the_prompt = buildPrompt(debugger)

    if args.show_prompt:
        # In this mode `conversation.converse` prints the messages itself
        # and returns nothing, so its result is not printed here.
        print("===================== Prompt =====================")
        conversation.converse(client, args, the_prompt[1])
        print("==================================================")
        return

    print("==================================================")
    print("ChatDBG")
    print("==================================================")
    print(conversation.converse(client, args, the_prompt[1]))
    print("==================================================")
+ def _get_help_string(self, action): + help = action.help + if "\n" not in help and "%(default)" not in action.help: + if action.default is not argparse.SUPPRESS: + defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] + if action.option_strings or action.nargs in defaulting_nargs: + help += " (default: %(default)s)" + return help + +def use_argparse(full_command): + description = textwrap.dedent( + rf""" + [b]ChatDBG[/b]: A Python debugger that uses AI to tell you `why`. + [blue][link=https://github.com/plasma-umass/ChatDBG]https://github.com/plasma-umass/ChatDBG[/link][/blue] + + usage: + [b]chatdbg [-c command] ... [-m module | pyfile] [arg] ...[/b] + + Debug the Python program given by pyfile. Alternatively, + an executable module or package to debug can be specified using + the -m switch. + + Initial commands are read from .pdbrc files in your home directory + and in the current directory, if they exist. Commands supplied with + -c are executed after commands from .pdbrc files. + + To let the script run until an exception occurs, use "-c continue". + You can then type `why` to get an explanation of the root cause of + the exception, along with a suggested fix. NOTE: you must have an + OpenAI key saved as the environment variable OPENAI_API_KEY. + You can get a key here: https://openai.com/api/ + + To let the script run up to a given line X in the debugged file, use + "-c 'until X'". 
+ """ + ).strip() + parser = RichArgParser( + prog="chatdbg", + usage=argparse.SUPPRESS, + description=description, + formatter_class=ChatDBGArgumentFormatter + ) + parser.add_argument( + "--llm", + type=str, + default="gpt-4-turbo-preview", + help=textwrap.dedent( + """ + the language model to use, e.g., 'gpt-3.5-turbo' or 'gpt-4' + the default mode tries gpt-4-turbo-preview and falls back to gpt-4 + """ + ).strip(), + ) + parser.add_argument( + "-p", + "--show-prompt", + action="store_true", + help="when enabled, only print prompt and exit (for debugging purposes)", + ) + parser.add_argument( + "--timeout", + type=int, + default=60, + help="the timeout for API calls in seconds", + ) + + args = parser.parse_args(full_command) + return args def get_model() -> str: - all_models = ["gpt-4", "gpt-3.5-turbo"] + all_models = ["gpt-4-turbo-preview", "gpt-4", "gpt-3.5-turbo"] if not "OPENAI_API_MODEL" in os.environ: - model = "gpt-4" + model = "gpt-4-turbo-preview" else: model = os.environ["OPENAI_API_MODEL"] if model not in all_models: diff --git a/src/chatdbg/conversation/__init__.py b/src/chatdbg/conversation/__init__.py new file mode 100644 index 0000000..de19991 --- /dev/null +++ b/src/chatdbg/conversation/__init__.py @@ -0,0 +1,82 @@ +import textwrap + +import llm_utils + +from . import functions + + +def get_truncated_error_message(args, diagnostic) -> str: + """ + Alternate taking front and back lines until the maximum number of tokens. 
+ """ + front: list[str] = [] + back: list[str] = [] + diagnostic_lines = diagnostic.splitlines() + n = len(diagnostic_lines) + + def build_diagnostic_string(): + return "\n".join(front) + "\n\n[...]\n\n" + "\n".join(reversed(back)) + "\n" + + for i in range(n): + if i % 2 == 0: + line = diagnostic_lines[i // 2] + list = front + else: + line = diagnostic_lines[n - i // 2 - 1] + list = back + list.append(line) + count = llm_utils.count_tokens(args.llm, build_diagnostic_string()) + if count > args.max_error_tokens: + list.pop() + break + return build_diagnostic_string() + + +def converse(client, args, diagnostic): + fns = functions.Functions(args) + available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()] + system_message = textwrap.dedent( + f""" + You are an assistant debugger. The user is having an issue with their code, and you are trying to help them. + A few functions exist to help with this process, namely: {", ".join(available_functions_names)}. + Don't hesitate to call as many functions as needed to give the best possible answer. + Once you have identified the problem, explain the diagnostic and provide a way to fix the issue if you can. + """ + ).strip() + user_message = f"Here is my error message:\n\n```\n{get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?" 
class Functions:
    """The functions the language model may call during a conversation."""

    def __init__(self, args):
        # Parsed command-line options; stored but not read by this class.
        self.args = args

    def as_tools(self):
        """Return the function schemas wrapped in `{"type": "function", ...}` entries."""
        return [
            {"type": "function", "function": schema}
            for schema in [self.get_code_surrounding_schema()]
        ]

    def dispatch(self, function_call) -> Optional[str]:
        """Run the function requested by the model.

        Args:
            function_call: Object with a `name` attribute and an `arguments`
                attribute holding a JSON-encoded object of keyword arguments.

        Returns:
            The function's textual result, or None when the call could not
            be carried out (malformed or non-object arguments, missing
            argument, unknown function, or an error raised by the function).
            The reason is printed.
        """
        try:
            # The arguments are produced by the model and may be invalid, so
            # parsing happens inside the same best-effort boundary as the call.
            arguments = json.loads(function_call.arguments)
            if not isinstance(arguments, dict):
                raise ValueError("Function arguments must be a JSON object.")
            print(
                f"Calling: {function_call.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})"
            )
            if function_call.name == "get_code_surrounding":
                return self.get_code_surrounding(
                    arguments["filename"], arguments["lineno"]
                )
            raise ValueError("No such function.")
        except Exception as e:
            # Deliberately broad: a failed tool call must not abort the
            # conversation; the caller treats None as "no result".
            print(e)
            return None

    def get_code_surrounding_schema(self):
        """Return the schema describing `get_code_surrounding` to the model."""
        return {
            "name": "get_code_surrounding",
            "description": "Returns the code in the given file surrounding and including the provided line number.",
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "The filename to read from.",
                    },
                    "lineno": {
                        "type": "integer",
                        "description": "The line number to focus on. Some context before and after that line will be provided.",
                    },
                },
                "required": ["filename", "lineno"],
            },
        }

    def get_code_surrounding(self, filename: str, lineno: int) -> str:
        """Return numbered source lines around `lineno` in `filename`.

        The requested window runs from seven lines before `lineno` to three
        lines after it.
        """
        # NOTE(review): presumably `read_lines` clamps the window to the file
        # bounds and returns the first line number actually read - confirm.
        (lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
        return llm_utils.number_group_of_lines(lines, first)