Skip to content

Commit

Permalink
Bug fixes: concurrency + Error handling of client disconnects + Tools
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 11, 2023
1 parent 0cda6f4 commit 4a6f52f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/nalgonda/custom_tools/build_directory_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def print_tree(self):
sub_indent = " " * 4 * (level + 1)

for f in files:
if self.file_extensions is None or f.endswith(tuple(self.file_extensions)):
if not self.file_extensions or f.endswith(tuple(self.file_extensions)):
tree_str += f"{sub_indent}{f}\n"

return tree_str
Expand Down
2 changes: 1 addition & 1 deletion src/nalgonda/custom_tools/print_all_files_in_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run(self) -> str:
output = []
for root, _, files in os.walk(self.start_directory, topdown=True):
for file in files:
if self.file_extensions is None or file.endswith(tuple(self.file_extensions)):
if not self.file_extensions or file.endswith(tuple(self.file_extensions)):
file_path = os.path.join(root, file)
output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n")
return "\n".join(output)
Expand Down
27 changes: 19 additions & 8 deletions src/nalgonda/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from agency_manager import AgencyManager
from agency_swarm import Agency
from agency_swarm.messages import MessageOutput
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from nalgonda.constants import DATA_DIR
from websockets import ConnectionClosedOK

# Ensure directories exist
DATA_DIR.mkdir(exist_ok=True)
Expand Down Expand Up @@ -79,6 +81,8 @@ async def websocket_endpoint(websocket: WebSocket, agency_id: str):

await process_message(user_message, agency, websocket)

except (WebSocketDisconnect, ConnectionClosedOK) as e:
raise e
except Exception as e:
logger.exception(e)
await ws_manager.send_message(f"Error: {e}\nPlease try again.", websocket)
Expand All @@ -87,22 +91,29 @@ async def websocket_endpoint(websocket: WebSocket, agency_id: str):
except WebSocketDisconnect:
await ws_manager.disconnect(websocket)
logger.info(f"WebSocket disconnected for agency_id: {agency_id}")
except ConnectionClosedOK:
logger.info(f"WebSocket disconnected for agency_id: {agency_id}")


async def process_message(user_message: str, agency: Agency, websocket: WebSocket):
"""Process the user message and send the response to the websocket."""
loop = asyncio.get_running_loop()

gen = agency.get_completion(message=user_message, yield_messages=True)

async for response in async_gen(gen):
response_text = response.get_formatted_content()
await ws_manager.send_message(response_text, websocket)
def get_next() -> MessageOutput | None:
try:
return next(gen)
except StopIteration:
return None

while True:
response = await loop.run_in_executor(None, get_next)
if response is None:
break

async def async_gen(gen):
"""Asynchronous wrapper for a synchronous generator."""
for value in gen:
# Offload the blocking operation to a separate thread
yield await asyncio.to_thread(lambda v=value: v)
response_text = response.get_formatted_content()
await ws_manager.send_message(response_text, websocket)


if __name__ == "__main__":
Expand Down

0 comments on commit 4a6f52f

Please sign in to comment.