Improved API provider settings to allow storing multiple providers.
jtang613 committed Sep 12, 2024
1 parent af74bd2 commit ac10c3e
Showing 3 changed files with 60 additions and 40 deletions.
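In short: the four flat settings (`binassist.remote_host`, `binassist.api_key`, `binassist.model`, `binassist.max_tokens`) are replaced by a `binassist.api_providers` array plus a `binassist.active_provider` name that selects one entry. A sketch of the stored shape, taken from the default in the diff (the uneven underscore counts in the keys are deliberate, to control sort order in the Settings view):

```python
# Shape of the new binassist.api_providers setting (default entry from the diff).
api_providers = [
    {
        'api___name': 'GPT-4o-Mini',               # display name; doubles as the lookup key
        'api__host': 'https://api.openai.com/v1',  # OpenAI-compatible endpoint
        'api_key': '',                             # hidden; excluded from project/resource scopes
        'api__model': 'gpt-4o-mini',
        'api__max_tokens': 16384,
    },
]
active_provider = 'GPT-4o-Mini'  # binassist.active_provider picks an entry by api___name
```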
32 changes: 24 additions & 8 deletions src/llm_api.py
@@ -56,6 +56,15 @@ def __init__(self):
self.threads = [] # Keep a list of active threads
self.initialize_database()
self.rag = RAG(self.settings.get_string('binassist.rag_db_path'))
self.api_provider = self.get_active_provider()

def get_active_provider(self):
"""
Returns the currently active API provider.
"""
active_name = self.settings.get_string('binassist.active_provider')
providers = json.loads(self.settings.get_json('binassist.api_providers'))
return next((p for p in providers if p['api___name'] == active_name), None)

def rag_init(self, markdown_files: List[str]) -> None:
"""
@@ -86,8 +95,9 @@ def _create_client(self) -> OpenAI:
Returns:
OpenAI: Configured API client instance.
"""
base_url = self.settings.get_string('binassist.remote_host')
api_key = self.settings.get_string('binassist.api_key')
self.api_provider = self.get_active_provider()
base_url = self.api_provider['api__host']
api_key = self.api_provider['api_key']
return OpenAI(base_url=base_url, api_key=api_key, http_client=http_client)

def initialize_database(self) -> None:
@@ -145,6 +155,8 @@ def explain(self, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str
str: The query sent to the LLM.
"""
client = self._create_client()
model = self.api_provider['api__model']
max_tokens = self.api_provider['api__max_tokens']
query = f"Describe the functionality of the decompiled {bv.platform.name} {bin_type} code " +\
f"below (represented as {il_type}). Provide a summary paragraph section followed by " +\
f"an analysis section that lists the functionality of each line of code. The analysis " +\
@@ -153,7 +165,7 @@ def explain(self, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str
f"present. But only fallback to strings or log messages that are clearly function " +\
f"names for this function.\n```\n" +\
f"{addr_to_text_func(bv, addr)}\n```"
self._start_thread(client, query, self.SYSTEM_PROMPT, signal)
self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
return query

def query(self, query, signal) -> str:
@@ -168,13 +180,15 @@ def query(self, query, signal) -> str:
str: The query sent to the LLM.
"""
client = self._create_client()
model = self.api_provider['api__model']
max_tokens = self.api_provider['api__max_tokens']
if self.use_rag():
context = self._get_rag_context(query)
augmented_query = f"Context:\n{context}\n\nQuery: {query}"
self._start_thread(client, augmented_query, self.SYSTEM_PROMPT, signal)
self._start_thread(client, model, max_tokens, augmented_query, self.SYSTEM_PROMPT, signal)
return augmented_query
else:
self._start_thread(client, query, self.SYSTEM_PROMPT, signal)
self._start_thread(client, model, max_tokens, query, self.SYSTEM_PROMPT, signal)
return query

def analyze_function(self, action: str, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str:
@@ -195,14 +209,16 @@ def analyze_function(self, action: str, bv, addr, bin_type, il_type, addr_to_text_func, signal) -> str
str: The query sent to the LLM.
"""
client = self._create_client()
model = self.api_provider['api__model']
max_tokens = self.api_provider['api__max_tokens']
code = addr_to_text_func(bv, addr)
prompt = ToolCalling.ACTION_PROMPTS.get(action, "").format(code=code)

if not prompt:
raise ValueError(f"Unknown action type: {action}")

query = f"{prompt}\n{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}"
self._start_thread(client, query, f"{self.SYSTEM_PROMPT}{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}", signal, ToolCalling.FN_TEMPLATES)
self._start_thread(client, model, max_tokens, query, f"{self.SYSTEM_PROMPT}{self.FUNCTION_PROMPT}{self.FORMAT_PROMPT}", signal, ToolCalling.FN_TEMPLATES)

return query

@@ -238,7 +254,7 @@ def get_rag_document_list(self) -> List[str]:
"""
return self.rag.get_document_list()

def _start_thread(self, client, query, system, signal, tools=None) -> None:
def _start_thread(self, client, model, max_tokens, query, system, signal, tools=None) -> None:
"""
Starts a new thread to handle streaming responses from the LLM.
@@ -249,7 +265,7 @@ def _start_thread(self, client, query, system, signal, tools=None) -> None:
signal (Signal): Qt signal to update with the response.
tools (dict): A dictionary of available tools for the LLM to consider.
"""
thread = StreamingThread(client, query, system, tools)
thread = StreamingThread(client, model, max_tokens, query, system, tools)
thread.update_response.connect(signal)
self.threads.append(thread) # Keep track of the thread
thread.start()
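A caveat worth flagging for llm_api.py: `get_active_provider()` returns `None` when `binassist.active_provider` names an entry that is missing from `binassist.api_providers` (e.g. after a rename), and `_create_client()` would then fail with a `TypeError` on the subscript. A defensive variant, offered as a suggestion rather than part of this commit:

```python
def _create_client(self) -> OpenAI:
    """Create an OpenAI client for the currently active provider."""
    self.api_provider = self.get_active_provider()
    if self.api_provider is None:
        # Fail with a clear message instead of a TypeError on None['api__host'].
        raise ValueError(
            "BinAssist: the active provider was not found in "
            "binassist.api_providers; check the 'Active API Provider' setting."
        )
    return OpenAI(
        base_url=self.api_provider['api__host'],
        api_key=self.api_provider['api_key'],
        http_client=http_client,
    )
```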
58 changes: 29 additions & 29 deletions src/settings.py
@@ -4,7 +4,7 @@

class BinAssistSettings(Settings):
"""
Manages the configuration settings for the BinAssist plugin, including API keys, model settings,
Manages the configuration settings for the BinAssist plugin, including API providers, RAG settings,
and other preferences that need to be stored and retrieved across sessions.
"""

@@ -28,25 +28,36 @@ def _register_settings(self) -> None:
self.register_group('binassist', 'BinAssist')

settings_definitions = {
'binassist.remote_host': {
'title': 'Remote API Host',
'description': 'The API host endpoint used to make requests.',
'type': 'string',
'default': 'https://api.openai.com/v1'
},
'binassist.api_key': {
'title': 'API Key',
'description': 'The API key used to make requests.',
'type': 'string',
'default': None,
'ignore': ["SettingsProjectScope", "SettingsResourceScope"],
'hidden': True
# API Provider fields have odd underscores so they sort sanely in the Settings view.
'binassist.api_providers': {
'title': 'API Providers',
'description': 'List of API providers for BinAssist',
'type': 'array',
'elementType': 'object',
'default': [
{
'api___name': 'GPT-4o-Mini',
'api__host': 'https://api.openai.com/v1',
'api_key': '',
'api__model': 'gpt-4o-mini',
'api__max_tokens': 16384
}
],
'properties': {
'api___name': {'type': 'string', 'title': 'Provider Name'},
'api__host': {'type': 'string', 'title': 'Remote API Host'},
'api_key': {'type': 'string', 'title': 'API Key', 'hidden': True, 'ignore': ["SettingsProjectScope", "SettingsResourceScope"]},
'api__model': {'type': 'string', 'title': 'LLM Model'},
'api__max_tokens': {'type': 'number', 'title': 'Max Completion Tokens', 'minValue': 1, 'maxValue': 128*1024}
}
},
'binassist.model': {
'title': 'LLM Model',
'description': 'The LLM model used to generate the response.',
'binassist.active_provider': {
'title': 'Active API Provider',
'description': 'The currently selected API provider',
'type': 'string',
'default': 'gpt-4o-mini'
'default': 'GPT-4o-Mini',
# 'enum': ['GPT-4o-Mini'], # This will be dynamically updated
# 'uiSelectionAction': 'binassist_refresh_providers'
},
'binassist.rlhf_db': {
'title': 'RLHF Database Path',
@@ -55,14 +66,6 @@ def _register_settings(self) -> None:
'default': 'rlhf_feedback.db',
'uiSelectionAction': 'file'
},
'binassist.max_tokens': {
'title': 'Max Completion Tokens',
'description': 'The maximum number of tokens used for completion.',
'type': 'number',
'default': 8192,
'minValue': 1,
'maxValue': 128*1024
},
'binassist.rag_db_path': {
'title': 'RAG Database Path',
'description': 'Path to store the RAG vector database.',
@@ -79,7 +82,4 @@ def _register_settings(self) -> None:
}

for key, properties in settings_definitions.items():
if 'minValue' in properties and 'maxValue' in properties:
properties['message'] = f"Min: {properties['minValue']}, Max: {properties['maxValue']}"
self.register_setting(key, json.dumps(properties))
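Because the providers now live in a single JSON-backed array, adding one programmatically is a read-modify-write of that setting. A usage sketch against the Binary Ninja Settings API — the `set_json`/`set_string` calls mirror the `get_json`/`get_string` accessors used in llm_api.py, and the provider entry itself is invented for illustration:

```python
import json
from binaryninja.settings import Settings

settings = Settings()

# Append a hypothetical local provider and make it active.
providers = json.loads(settings.get_json('binassist.api_providers'))
providers.append({
    'api___name': 'Local-Llama',            # display name; must be unique, it is the lookup key
    'api__host': 'http://localhost:8080/v1',
    'api_key': 'not-needed',
    'api__model': 'llama-3.1-8b-instruct',
    'api__max_tokens': 8192,
})
settings.set_json('binassist.api_providers', json.dumps(providers))
settings.set_string('binassist.active_provider', 'Local-Llama')
```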

10 changes: 7 additions & 3 deletions src/threads.py
@@ -17,19 +17,23 @@ class StreamingThread(QtCore.QThread):

update_response = QtCore.Signal(dict)

def __init__(self, client: OpenAI, query: str, system: str, tools=None) -> None:
def __init__(self, client: OpenAI, model: str, max_tokens: int, query: str, system: str, tools=None) -> None:
"""
Initializes the thread with the necessary parameters for making a streaming API call.
Parameters:
client (OpenAI): The OpenAI client used for making API calls.
model (str): The model to use.
max_tokens (int): The maximum number of completion tokens.
query (str): The user's query to be processed.
system (str): System-level instructions or context for the API call.
tools (list): A list of tools that the LLM can call during the response.
"""
super().__init__()
self.settings = Settings()
self.client = client
self.model = model
self.max_tokens = max_tokens
self.query = query
self.system = system
self.tools = tools or None
@@ -40,13 +44,13 @@ def run(self) -> None:
signaling the main thread upon updates or when an error occurs.
"""
response = self.client.chat.completions.create(
model=self.settings.get_string('binassist.model'),
model=self.model,
messages=[
{"role": "system", "content": self.system},
{"role": "user", "content": self.query}
],
stream=False if self.tools else True,
max_tokens=self.settings.get_integer('binassist.max_tokens'),
max_tokens=self.max_tokens,
tools=self.tools,
)
if self.tools:
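With the threads.py change, callers hand the thread its model and token budget explicitly instead of the thread reading globals from `Settings`. A minimal usage sketch — the client values and the `on_update` slot are placeholders, and the import path assumes the plugin's package layout:

```python
from openai import OpenAI
from .threads import StreamingThread  # assumed relative import within the plugin

client = OpenAI(base_url='https://api.openai.com/v1', api_key='sk-...')  # placeholder key

def on_update(payload: dict) -> None:
    print(payload)  # placeholder slot; the plugin routes this into its Qt UI

thread = StreamingThread(
    client,
    'gpt-4o-mini',                               # api__model from the active provider
    16384,                                       # api__max_tokens from the active provider
    "Summarize this function.",                  # query
    "You are a reverse engineering assistant.",  # system prompt
)
thread.update_response.connect(on_update)
thread.start()
```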
