Skip to content

Commit

Permalink
Add ability to stop queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
jtang613 committed Sep 29, 2024
1 parent 49f9d2c commit 3ddf484
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 38 deletions.
26 changes: 16 additions & 10 deletions src/llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class LlmApi:
You are an expert Python and Rust developer. You are familiar with common frameworks and libraries
such as WinSock, OpenSSL, MFC, etc. You are an expert in TCP/IP network programming and packet analysis.
You always respond to queries in a structured format using Markdown styling for headings and lists.
You format code blocks using back-tick code-fencing.
You format code blocks using back-tick code-fencing.\n
'''

FUNCTION_PROMPT = '''
USE THE PROVIDED TOOLS WHEN NECESSARY. YOU ALWAYS RESPOND WITH TOOL CALLS WHEN POSSIBLE.
USE THE PROVIDED TOOLS WHEN NECESSARY. YOU ALWAYS RESPOND WITH TOOL CALLS WHEN POSSIBLE.\n
'''

FORMAT_PROMPT = '''
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The output MUST strictly adhere to the following JSON format, do not include any other text.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
Expand All @@ -47,7 +47,7 @@ class LlmApi:
]
}
```
REMEMBER, YOU MUST ALWAYS PRODUCE A JSON LIST OF TOOL_CALLS!
REMEMBER, YOU MUST ALWAYS PRODUCE A JSON LIST OF TOOL_CALLS!\n
'''

def __init__(self):
Expand All @@ -56,6 +56,7 @@ def __init__(self):
"""
self.settings = Settings()
self.threads = [] # Keep a list of active threads
self.thread = None
self.initialize_database()
self.rag = RAG(self.settings.get_string('binassist.rag_db_path'))
self.api_provider = self.get_active_provider()
Expand Down Expand Up @@ -167,7 +168,7 @@ def explain(self, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str
f"present. But only fallback to strings or log messages that are clearly function " +\
f"names for this function.\n```\n" +\
f"{addr_to_text_func(bv, addr)}\n```"
self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
self.thread = self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
return query

def query(self, query, signal) -> str:
Expand All @@ -187,10 +188,10 @@ def query(self, query, signal) -> str:
if self.use_rag():
context = self._get_rag_context(query)
augmented_query = f"Context:\n{context}\n\nQuery: {query}"
self._start_thread(client, model, max_tokens, augmented_query, self.SYSTEM_PROMPT, signal)
self.thread = self._start_thread(client, model, max_tokens, augmented_query, self.SYSTEM_PROMPT, signal)
return augmented_query
else:
self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
self.thread = self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
return query

def analyze_function(self, action: str, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str:
Expand Down Expand Up @@ -220,10 +221,13 @@ def analyze_function(self, action: str, bv, addr, bin_type, il_type, addr_to_tex
raise ValueError(f"Unknown action type: {action}")

query = f"{prompt}\n{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}"
self._start_thread(client, model, max_tokens, query, f"{self.SYSTEM_PROMPT}{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}", signal, ToolCalling.FN_TEMPLATES)
self.thread = self._start_thread(client, model, max_tokens, query, f"{self.SYSTEM_PROMPT}{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}", signal, ToolCalling.FN_TEMPLATES)

return query

def isRunning(self):
    """Return True if the most recently started query thread is running.

    ``self.thread`` is initialized to ``None`` and only assigned once a
    query is started, so guard against the initial state instead of
    raising ``AttributeError`` when no query has been submitted yet.

    Returns:
        bool: True if a query thread exists and is currently running,
        False otherwise (including before any query has been started).
    """
    return self.thread is not None and self.thread.isRunning()

def _get_rag_context(self, query: str) -> str:
"""
Query the RAG database for query context.
Expand Down Expand Up @@ -271,6 +275,7 @@ def _start_thread(self, client, model, max_tokens, query, system, signal, tools=
thread.update_response.connect(signal)
self.threads.append(thread) # Keep track of the thread
thread.start()
return thread

def DataToText(self, bv: BinaryView, start_addr: int, end_addr: int) -> str:
"""
Expand Down Expand Up @@ -429,8 +434,9 @@ def stop_threads(self):
Stops all active threads used for handling LLM queries.
"""
for thread in self.threads:
thread.quit()
thread.wait()
thread.stop()
self.threads.clear() # Clear the list after stopping all threads


class ToolCalling:
"""
Expand Down
136 changes: 108 additions & 28 deletions src/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def _init_ui(self) -> None:
self.tabs.addTab(actions_tab, "Actions")
self.tabs.addTab(rag_management_tab, "RAG Management")

self.submit_button = None
self.submit_label = None

layout = QtWidgets.QVBoxLayout()
layout.addWidget(self.tabs)
self.setLayout(layout)
Expand Down Expand Up @@ -337,22 +340,55 @@ def refreshRAGDocumentList(self):
def onExplainILClicked(self) -> None:
"""
Handles the event when the 'Explain Function' button is clicked.
Toggles the button between 'Explain' and 'Stop'.
"""
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
func = self.get_func_text()
self.text_box.clear()
self.request = self.LlmApi.explain(self.bv, self.offset_addr, datatype, il_type, func, self.display_response)
self.submit_button = self.sender()

if self.submit_button.text() == "Explain Function":
self.submit_label = self.submit_button.text()
# Start explanation
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
func = self.get_func_text()
self.text_box.clear()

# Trigger LLM query and store request
self.request = self.LlmApi.explain(self.bv, self.offset_addr, datatype, il_type, func, self.display_response)

# Change the button text to "Stop"
self.submit_button.setText("Stop")
else:
# Stop the running query
self.LlmApi.stop_threads()

# Revert the button back to "Explain"
self.submit_button.setText(self.submit_label)

def onExplainLineClicked(self) -> None:
"""
Handles the event when the 'Explain Line' button is clicked.
Toggles the button between 'Explain' and 'Stop'.
"""
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
self.text_box.clear()
self.request = self.LlmApi.explain(self.bv, self.offset_addr, datatype, il_type, self.get_line_text, self.display_response)

self.submit_button = self.sender()

if self.submit_button.text() == "Explain Line":
self.submit_label = self.submit_button.text()
# Start explanation
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
self.text_box.clear()

# Trigger LLM query and store request
self.request = self.LlmApi.explain(self.bv, self.offset_addr, datatype, il_type, self.get_line_text, self.display_response)

# Change the button text to "Stop"
self.submit_button.setText("Stop")
else:
# Stop the running query
self.LlmApi.stop_threads()

# Revert the button back to "Explain"
self.submit_button.setText(self.submit_label)

def onClearTextClicked(self) -> None:
"""
Expand All @@ -364,31 +400,64 @@ def onClearTextClicked(self) -> None:

def onSubmitQueryClicked(self) -> None:
"""
Submits the custom query entered by the user when the 'Submit' button is clicked.
Submits the custom query or stops a running query based on the button state.
"""
query = self.query_edit.toPlainText()
query = self._process_custom_query(query)
self.session_log.append({"user": query, "assistant": "Awaiting response..."})
# Toggle functionality between Submit and Stop
self.submit_button = self.sender()

# Prepend the session log to the query for context
full_query = "\n".join([f"User: {entry['user']}\nAssistant: {entry['assistant']}" for entry in self.session_log]) + f"\nUser: {query}"
if self.submit_button.text() == "Submit":
self.submit_label = self.submit_button.text()
# Start a new query
query = self.query_edit.toPlainText()
query = self._process_custom_query(query)
self.session_log.append({"user": query, "assistant": "Awaiting response..."})

# Prepend the session log to the query for context
full_query = "\n".join([f"User: {entry['user']}\nAssistant: {entry['assistant']}" for entry in self.session_log]) + f"\nUser: {query}"

# Store the running request
self.request = self.LlmApi.query(full_query, self.display_custom_response)

# Update button to Stop
self.submit_button.setText("Stop")
else:
# Stop the running query
self.LlmApi.stop_threads()
# Revert button back to Submit
self.submit_button.setText(self.submit_label)

self.request = self.LlmApi.query(full_query, self.display_custom_response)

def onAnalyzeFunctionClicked(self) -> None:
"""
Event for the 'Analyze Function' button.
"""
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
func = self.get_func_text()

for fn_name, checkbox in self.filter_checkboxes.items():
if checkbox.isChecked():
action = fn_name.split(':')[0].replace(' ', '_')
self.request = self.LlmApi.analyze_function(
action, self.bv, self.offset_addr, datatype, il_type, func, self.display_analyze_response
)
Toggles the button between 'Analyze Function' and 'Stop'.
"""
self.submit_button = self.sender()

if self.submit_button.text() == "Analyze Function":
self.submit_label = self.submit_button.text()
# Start analysis
datatype = self.datatype.split(':')[1]
il_type = self.il_type.name
func = self.get_func_text()

for fn_name, checkbox in self.filter_checkboxes.items():
if checkbox.isChecked():
action = fn_name.split(':')[0].replace(' ', '_')

# Trigger LLM query and store request
self.request = self.LlmApi.analyze_function(
action, self.bv, self.offset_addr, datatype, il_type, func, self.display_analyze_response
)

# Change the button text to "Stop"
self.submit_button.setText("Stop")
else:
# Stop the running query
self.LlmApi.stop_threads()

# Revert the button back to "Analyze Function"
self.submit_button.setText(self.submit_label)

def onAnalyzeClearClicked(self) -> None:
"""
Expand Down Expand Up @@ -436,6 +505,10 @@ def display_response(self, response) -> None:
html_resp += self._generate_feedback_buttons()
self.response = response["response"]
self.text_box.setHtml(html_resp)
if(not self.LlmApi.isRunning()):
# Revert the button back to "Explain"
self.submit_button.setText(self.submit_label)


def display_custom_response(self, response) -> None:
"""
Expand All @@ -453,6 +526,9 @@ def display_custom_response(self, response) -> None:
html_resp += self._generate_feedback_buttons()
self.response = response["response"]
self.query_response_browser.setHtml(html_resp)
if(not self.LlmApi.isRunning()):
# Revert the button back to "Explain"
self.submit_button.setText(self.submit_label)

def display_analyze_response(self, response) -> None:
"""
Expand Down Expand Up @@ -497,6 +573,10 @@ def display_analyze_response(self, response) -> None:

# Resize columns to fit the content
self.actions_table.resizeColumnsToContents()
if(not self.LlmApi.isRunning()):
# Revert the button back to "Explain"
self.submit_button.setText(self.submit_label)



def _format_action(self, action: dict) -> str:
Expand Down
16 changes: 16 additions & 0 deletions src/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, client: OpenAI, model: str, max_tokens: int, query: str, syst
self.query = query
self.system = system
self.tools = tools or None
self.running = True

def run(self) -> None:
"""
Expand All @@ -53,6 +54,10 @@ def run(self) -> None:
max_tokens=self.max_tokens,
tools=self.tools,
)

if not self.running: # Check before processing response
return

if self.tools:
#print(f"finish_reason: {response.choices[0].finish_reason}")
#print(f"{response.choices[0].message.content}")
Expand Down Expand Up @@ -81,10 +86,21 @@ def run(self) -> None:
else: # Not self.tools
response_buffer = ""
for chunk in response:
if not self.running: # Stop consuming stream if interrupted
return
message_chunk = chunk.choices[0].delta.content or ""
response_buffer += message_chunk
self.update_response.emit({"response":response_buffer})

def stop(self):
    """Stop the thread, preferring a graceful shutdown.

    Clears the ``running`` flag so ``run()`` exits at its next check,
    asks the thread's event loop to quit, and waits for it to finish.
    ``terminate()`` is used only as a last resort if the thread does not
    stop within the grace period, because forcibly terminating a thread
    can leave locks held and data in an undefined state (per Qt docs).
    """
    self.running = False   # signal run() to stop processing the API response
    self.quit()            # graceful stop: ask the event loop to exit
    if not self.wait(5000):    # give the thread up to 5s to finish cleanly
        self.terminate()       # last resort: forcefully kill the thread
        self.wait()            # ensure the thread has completely stopped


def _generate_random_string(self, length=8):
"""
Expand Down

0 comments on commit 3ddf484

Please sign in to comment.