From d9e611d0347fb801f3ad88f6ba743fa6769e54ed Mon Sep 17 00:00:00 2001 From: Nikita Bobrovskiy <39348559+bonk1t@users.noreply.github.com> Date: Tue, 19 Dec 2023 17:07:36 +0000 Subject: [PATCH] Minor AI refactoring (#8) Testing out self-improvement capabilities. Most changes are made by the default agency itself Core functionality (see Postman collection) is tested manually * AI refactoring * Update README.md * Apply suggestions from code review Co-authored-by: Guilherme Parpinelli <59970362+guiparpinelli@users.noreply.github.com> --- README.md | 80 ++++++------------- src/nalgonda/agency_config_lock_manager.py | 17 +++- src/nalgonda/agency_manager.py | 21 ++--- src/nalgonda/config.py | 68 ---------------- src/nalgonda/constants.py | 9 ++- src/nalgonda/custom_tools/__init__.py | 3 +- .../custom_tools/build_directory_tree.py | 5 +- .../custom_tools/generate_proposal.py | 16 +--- .../print_all_files_in_directory.py | 5 +- src/nalgonda/custom_tools/search_web.py | 5 +- src/nalgonda/custom_tools/utils.py | 45 +++++------ .../custom_tools/write_and_save_program.py | 16 ++-- src/nalgonda/data/default_config.json | 3 +- src/nalgonda/database/__init__.py | 2 + src/nalgonda/exceptions.py | 4 +- src/nalgonda/main.py | 17 ++-- src/nalgonda/models/__init__.py | 2 + src/nalgonda/models/agency_config.py | 45 +++++++++++ src/nalgonda/models/agent_config.py | 12 +++ src/nalgonda/models/request_models.py | 8 +- src/nalgonda/routers/__init__.py | 2 + src/nalgonda/routers/v1/__init__.py | 2 - src/nalgonda/routers/v1/api/__init__.py | 4 +- src/nalgonda/routers/v1/api/agency.py | 20 +++-- src/nalgonda/routers/v1/websocket.py | 75 ++++++++--------- src/nalgonda/settings.py | 14 ++++ ...ger.py => websocket_connection_manager.py} | 10 +-- tests/custom_tools/test_utils.py | 9 ++- 28 files changed, 258 insertions(+), 261 deletions(-) delete mode 100644 src/nalgonda/config.py create mode 100644 src/nalgonda/models/agency_config.py create mode 100644 src/nalgonda/models/agent_config.py create mode 100644 src/nalgonda/settings.py rename src/nalgonda/{connection_manager.py => websocket_connection_manager.py} (71%) diff --git a/README.md b/README.md index e3a0a114..12e8d931 100644 --- a/README.md +++ b/README.md @@ -2,73 +2,45 @@ ## Overview -Project Nalgonda is a tool for managing and executing AI agents. -It is built on top of the [OpenAI Assitants API](https://platform.openai.com/docs/assistants/overview) -and provides a simple interface for configuring agents and executing them. +Project Nalgonda is an innovative platform for managing and executing AI-driven swarm agencies. +It is built upon the foundational [OpenAI Assistants API](https://platform.openai.com/docs/assistants/overview) +and extends its functionality through a suite of specialized tools and a sophisticated management system for AI agencies. + +## Key Components + +- **Agency Configuration Manager**: Manages configurations for AI agencies, ensuring they're loaded and saved properly. +- **WebSocket Connection Manager**: Handles real-time WebSocket connections for interactive agency-client communication. +- **Custom Tools**: A collection of tools including `SearchWeb`, `GenerateProposal`, `BuildDirectoryTree`, and more, +providing specialized functionalities tailored to the needs of different agency roles. +- **Data Persistence**: Utilizes JSON-based configurations to maintain agency states and preferences across sessions. ## Features - **Agency Configuration**: Configure agencies with agents - **Tool Configuration**: Configure tools with custom parameters -- **Tool Execution**: Execute tools and returns results +- **Tool Execution**: Execute tools and return results - **Agent Configuration**: Configure agents with their knowledge and tools - **User Management**: Manage users and their access to different agencies [TODO] -## Getting Started - -### Prerequisites - -- Python 3.11 or higher -- FastAPI -- Uvicorn (for running the server) -- Additional Python packages as listed in `pyproject.toml` - -### Installation - -1. **Clone the Repository** - - ```sh - git clone https://github.com/bonk1t/nalgonda.git - cd nalgonda - ``` - -2. **Install Dependencies** - - Using poetry: - - ```sh - poetry install - ``` +## Installation - Or using pip: +Ensure you have Python 3.11 or higher and follow these steps to get started: - ```sh - pip install -r requirements.txt - ``` +1. Install the required dependencies (from `requirements.txt` or using Poetry). +2. Set up the necessary environment variables, including `OPENAI_API_KEY`. +3. Use the provided JSON configuration files as templates to configure your own AI agencies. +4. Start the FastAPI server (`uvicorn nalgonda.main:app --reload`) to interact with the system. -3. **Set up Environment Variables** - - Ensure to set up the necessary environment variables such as `OPENAI_API_KEY`. - -### Running the Application - -1. **Start the FastAPI Server** - - ```sh - uvicorn nalgonda.main:app --reload - ``` - - The API will be available at `http://localhost:8000`. - -2. **Accessing the Endpoints** - - Use a tool like Postman or Swagger UI to interact with the API endpoints. +Note: Refer to individual class and method docstrings for detailed instructions and usage. ## Usage ### API Endpoints -Send a POST request to the /create_agency endpoint to create an agency. The response will contain the following: -- agency_id: The ID of the agency -### WebSocket Endpoints -After creating an agency, you can connect to the WebSocket endpoint at /ws/{agency_id} to communicate with the agency. +Send POST requests to endpoints such as `POST /v1/api/agency` and `POST /v1/api/agency/message` to perform operations +like creating new agencies and sending messages to them. + +### WebSocket Communication + +Connect to WebSocket endpoints (e.g., `/v1/ws/{agency_id}`, `/v1/ws/{agency_id}/{thread_id}`) +to engage in real-time communication with configured AI agencies. diff --git a/src/nalgonda/agency_config_lock_manager.py b/src/nalgonda/agency_config_lock_manager.py index 7efa4487..727215d2 100644 --- a/src/nalgonda/agency_config_lock_manager.py +++ b/src/nalgonda/agency_config_lock_manager.py @@ -3,10 +3,23 @@ class AgencyConfigLockManager: - """Lock manager for agency config files""" + """Manages locking for agency configuration files. + This manager guarantees that each agency configuration has a unique lock, + preventing simultaneous access and modification by multiple processes. + """ + + # Mapping from agency ID to its corresponding Lock. _locks: dict[str, threading.Lock] = defaultdict(threading.Lock) @classmethod - def get_lock(cls, agency_id): + def get_lock(cls, agency_id: str) -> threading.Lock: + """Retrieves the lock for a given agency ID, creating it if not present. + + Args: + agency_id (str): The unique identifier for the agency. + + Returns: + threading.Lock: The lock associated with the agency ID. + """ return cls._locks[agency_id] diff --git a/src/nalgonda/agency_manager.py b/src/nalgonda/agency_manager.py index ce10c4ab..20222c0d 100644 --- a/src/nalgonda/agency_manager.py +++ b/src/nalgonda/agency_manager.py @@ -5,19 +5,19 @@ from agency_swarm import Agency, Agent -from nalgonda.config import AgencyConfig from nalgonda.custom_tools import TOOL_MAPPING +from nalgonda.models.agency_config import AgencyConfig logger = logging.getLogger(__name__) class AgencyManager: - def __init__(self): - self.cache = {} # agency_id+thread_id: agency + def __init__(self) -> None: + self.cache: dict[str, Agency] = {} # Mapping from agency_id+thread_id to Agency class instance self.lock = asyncio.Lock() async def create_agency(self, agency_id: str | None = None) -> tuple[Agency, str]: - """Create the agency for the given agency ID.""" + """Create an agency and return the agency and the agency_id.""" agency_id = agency_id or uuid.uuid4().hex async with self.lock: @@ -27,9 +27,9 @@ async def create_agency(self, agency_id: str | None = None) -> tuple[Agency, str return agency, agency_id async def get_agency(self, agency_id: str, thread_id: str | None) -> Agency | None: - """Get the agency for the given agency ID and thread ID.""" + """Get the agency from the cache.""" async with self.lock: - return self.cache.get(self.get_cache_key(agency_id, thread_id), None) + return self.cache.get(self.get_cache_key(agency_id, thread_id)) async def cache_agency(self, agency: Agency, agency_id: str, thread_id: str | None) -> None: """Cache the agency for the given agency ID and thread ID.""" @@ -39,9 +39,10 @@ async def cache_agency(self, agency: Agency, agency_id: str, thread_id: str | No async def delete_agency_from_cache(self, agency_id: str, thread_id: str | None) -> None: async with self.lock: - self.cache.pop(self.get_cache_key(agency_id, thread_id), None) + cache_key = self.get_cache_key(agency_id, thread_id) + self.cache.pop(cache_key, None) - async def refresh_thread_id(self, agency, agency_id, thread_id) -> str | None: + async def refresh_thread_id(self, agency: Agency, agency_id: str, thread_id: str | None) -> str | None: new_thread_id = agency.main_thread.id if thread_id != new_thread_id: await self.cache_agency(agency, agency_id, new_thread_id) @@ -87,8 +88,8 @@ def load_agency_from_config(agency_id: str) -> Agency: # It saves all the settings in the settings.json file (in the root folder, not thread safe) agency = Agency(agency_chart, shared_instructions=config.agency_manifesto) - config.update_agent_ids_in_config(agency_id, agents=agency.agents) - config.save(agency_id) + config.update_agent_ids_in_config(agency.agents) + config.save() logger.info(f"Agency creation took {time.time() - start} seconds. Session ID: {agency_id}") return agency diff --git a/src/nalgonda/config.py b/src/nalgonda/config.py deleted file mode 100644 index 66dcfce2..00000000 --- a/src/nalgonda/config.py +++ /dev/null @@ -1,68 +0,0 @@ -from pathlib import Path - -from agency_config_lock_manager import AgencyConfigLockManager -from agency_swarm import Agent -from pydantic import BaseModel, Field -from pydantic_settings import BaseSettings, SettingsConfigDict - -from nalgonda.constants import CONFIG_FILE, DEFAULT_CONFIG_FILE - -LATEST_GPT_MODEL = "gpt-4-1106-preview" - - -class Settings(BaseSettings): - openai_api_key: str = Field(validation_alias="OPENAI_API_KEY") - gpt_model: str = Field(default=LATEST_GPT_MODEL, validation_alias="GPT_MODEL") - - model_config = SettingsConfigDict() - - -settings = Settings() - - -class AgentConfig(BaseModel): - """Config for an agent""" - - id: str | None = None - role: str - description: str - instructions: str - files_folder: str | None = None - tools: list[str] = Field(default_factory=list) - - -class AgencyConfig(BaseModel): - """Config for the agency""" - - agency_manifesto: str = Field(default="Agency Manifesto") - agents: list[AgentConfig] - agency_chart: list[str | list[str]] # contains agent roles - - def update_agent_ids_in_config(self, agency_id: str, agents: list[Agent]) -> None: - """Update the agent IDs in the config file""" - for agent in agents: - for agent_conf in self.agents: - if agent.name == f"{agent_conf.role}_{agency_id}": - agent_conf.id = agent.id - break - - @classmethod - def load(cls, agency_id: str) -> "AgencyConfig": - """Load the config from a file""" - config_file_name = cls.get_config_name(agency_id) - config_file_name = config_file_name if config_file_name.exists() else DEFAULT_CONFIG_FILE - - lock = AgencyConfigLockManager.get_lock(agency_id) - with lock as _, open(config_file_name) as f: - return cls.model_validate_json(f.read()) - - def save(self, agency_id: str) -> None: - """Save the config to a file""" - lock = AgencyConfigLockManager.get_lock(agency_id) - with lock as _, open(self.get_config_name(agency_id), "w") as f: - f.write(self.model_dump_json(indent=2)) - - @staticmethod - def get_config_name(agency_id: str) -> Path: - """Get the name of the config file""" - return Path(f"{CONFIG_FILE}_{agency_id}.json") diff --git a/src/nalgonda/constants.py b/src/nalgonda/constants.py index b7f39ad8..6e8c4d88 100644 --- a/src/nalgonda/constants.py +++ b/src/nalgonda/constants.py @@ -1,6 +1,9 @@ from pathlib import Path -# File and Directory Constants -DATA_DIR = Path(__file__).resolve().parent / "data" +# Constants representing base and data directories +BASE_DIR = Path(__file__).resolve(strict=True).parent +DATA_DIR = BASE_DIR / "data" + +# Constants for default configuration files DEFAULT_CONFIG_FILE = DATA_DIR / "default_config.json" -CONFIG_FILE = DATA_DIR / "config" +CONFIG_FILE_BASE = DATA_DIR / "config.json" diff --git a/src/nalgonda/custom_tools/__init__.py b/src/nalgonda/custom_tools/__init__.py index b106b00a..15132955 100644 --- a/src/nalgonda/custom_tools/__init__.py +++ b/src/nalgonda/custom_tools/__init__.py @@ -5,6 +5,7 @@ from nalgonda.custom_tools.generate_proposal import GenerateProposal from nalgonda.custom_tools.print_all_files_in_directory import PrintAllFilesInDirectory from nalgonda.custom_tools.search_web import SearchWeb +from nalgonda.custom_tools.write_and_save_program import WriteAndSaveProgram TOOL_MAPPING = { "CodeInterpreter": CodeInterpreter, @@ -13,5 +14,5 @@ "GenerateProposal": GenerateProposal, "PrintAllFilesInDirectory": PrintAllFilesInDirectory, "SearchWeb": SearchWeb, - # "WriteAndSaveProgram": WriteAndSaveProgram, + "WriteAndSaveProgram": WriteAndSaveProgram, } diff --git a/src/nalgonda/custom_tools/build_directory_tree.py b/src/nalgonda/custom_tools/build_directory_tree.py index 7d0127b2..83f81249 100644 --- a/src/nalgonda/custom_tools/build_directory_tree.py +++ b/src/nalgonda/custom_tools/build_directory_tree.py @@ -17,10 +17,11 @@ class BuildDirectoryTree(BaseTool): ) file_extensions: set[str] = Field( default_factory=set, - description="Set of file extensions to include in the tree. If empty, all files will be included.", + description="Set of file extensions to include in the tree. If empty, all files will be included. " + "Examples are {'.py', '.txt', '.md'}.", ) - _validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) + _validate_start_directory = field_validator("start_directory", mode="after")(check_directory_traversal) def run(self) -> str: """Recursively print the tree of directories and files using pathlib.""" diff --git a/src/nalgonda/custom_tools/generate_proposal.py b/src/nalgonda/custom_tools/generate_proposal.py index bb5195bc..e51dc2e9 100644 --- a/src/nalgonda/custom_tools/generate_proposal.py +++ b/src/nalgonda/custom_tools/generate_proposal.py @@ -3,7 +3,7 @@ from nalgonda.custom_tools.utils import get_chat_completion -USER_PROMPT_PREFIX = "Please draft a proposal for the following project brief: " +USER_PROMPT_PREFIX = "Please draft a proposal for the following project brief: \n" SYSTEM_MESSAGE = """\ You are a professional proposal drafting assistant. \ Do not include any actual technologies or technical details into proposal unless \ @@ -19,14 +19,6 @@ class GenerateProposal(BaseTool): project_brief: str = Field(..., description="The project brief to generate a proposal for.") def run(self) -> str: - user_prompt = self.get_user_prompt() - message = get_chat_completion( - user_prompt=user_prompt, - system_message=SYSTEM_MESSAGE, - temperature=0.6, - ) - - return message - - def get_user_prompt(self): - return f"{USER_PROMPT_PREFIX}\n{self.project_brief}" + user_prompt = f"{USER_PROMPT_PREFIX}{self.project_brief}" + response = get_chat_completion(user_prompt=user_prompt, system_message=SYSTEM_MESSAGE, temperature=0.6) + return response diff --git a/src/nalgonda/custom_tools/print_all_files_in_directory.py b/src/nalgonda/custom_tools/print_all_files_in_directory.py index fb20b5cc..d03b911b 100644 --- a/src/nalgonda/custom_tools/print_all_files_in_directory.py +++ b/src/nalgonda/custom_tools/print_all_files_in_directory.py @@ -17,10 +17,11 @@ class PrintAllFilesInDirectory(BaseTool): ) file_extensions: set[str] = Field( default_factory=set, - description="Set of file extensions to include in the output. If empty, all files will be included.", + description="Set of file extensions to include in the tree. If empty, all files will be included. " + "Examples are {'.py', '.txt', '.md'}.", ) - _validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) + _validate_start_directory = field_validator("start_directory", mode="after")(check_directory_traversal) def run(self) -> str: """ diff --git a/src/nalgonda/custom_tools/search_web.py b/src/nalgonda/custom_tools/search_web.py index 22b74561..b7cd814b 100644 --- a/src/nalgonda/custom_tools/search_web.py +++ b/src/nalgonda/custom_tools/search_web.py @@ -10,7 +10,8 @@ class SearchWeb(BaseTool): ..., description="The search phrase you want to use. " "Optimize the search phrase for an internet search engine.", ) + max_results: int = Field(default=10, description="The maximum number of search results to return, default is 10.") - def run(self): + def run(self) -> str: with DDGS() as ddgs: - return str("\n".join(str(r) for r in ddgs.text(self.phrase, max_results=10))) + return "\n".join(str(result) for result in ddgs.text(self.phrase, max_results=self.max_results)) diff --git a/src/nalgonda/custom_tools/utils.py b/src/nalgonda/custom_tools/utils.py index 65336547..000e2634 100644 --- a/src/nalgonda/custom_tools/utils.py +++ b/src/nalgonda/custom_tools/utils.py @@ -1,44 +1,37 @@ -import tempfile from pathlib import Path from agency_swarm.util import get_openai_client +from nalgonda.settings import settings -def get_chat_completion(user_prompt, system_message, **kwargs) -> str: - """ - Generate a chat completion based on a prompt and a system message. - This function is a wrapper around the OpenAI API. - """ - from config import settings +def get_chat_completion(user_prompt: str, system_message: str, **kwargs) -> str: + """Generate a chat completion based on a prompt and a system message. + This function is a wrapper around the OpenAI API.""" client = get_openai_client() completion = client.chat.completions.create( model=settings.gpt_model, messages=[ - { - "role": "system", - "content": system_message, - }, - { - "role": "user", - "content": user_prompt, - }, + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, ], **kwargs, ) + return completion.choices[0].message.content - return str(completion.choices[0].message.content) - -def check_directory_traversal(dir_path: str) -> Path: - """ - Ensures that the given directory path is within allowed paths. - """ - path = Path(dir_path) - if ".." in path.parts: +def check_directory_traversal(directory: Path) -> Path: + """Ensures that the given directory path is within allowed paths.""" + if ".." in directory.parts: raise ValueError("Directory traversal is not allowed.") - allowed_bases = [Path(tempfile.gettempdir()).resolve(), Path.home().resolve()] - if not any(str(path.resolve()).startswith(str(base)) for base in allowed_bases): + base_directory = Path.cwd() + + # Resolve the directory path against the base directory + resolved_directory = (base_directory / directory).resolve() + + # Check if the resolved directory is a subpath of the base directory + if not resolved_directory.is_relative_to(base_directory): raise ValueError("Directory traversal is not allowed.") - return path + + return resolved_directory diff --git a/src/nalgonda/custom_tools/write_and_save_program.py b/src/nalgonda/custom_tools/write_and_save_program.py index 4b71d558..b03b54f9 100644 --- a/src/nalgonda/custom_tools/write_and_save_program.py +++ b/src/nalgonda/custom_tools/write_and_save_program.py @@ -17,6 +17,9 @@ class File(BaseTool): description="The name of the file including the extension and the file path from your current directory " "if needed.", ) + chain_of_thought: str = Field( + ..., description="Think step by step to determine the correct plan that is needed to write the file." + ) body: str = Field(..., description="Correct contents of a file") def run(self): @@ -24,12 +27,15 @@ def run(self): return "Invalid file path. Directory traversal is not allowed." # Extract the directory path from the file name - directory = DATA_DIR / self._agency_id / os.path.dirname(self.file_name) # TODO: pass agency_id to all tools - full_path = directory / self.file_name + directory_path = Path(self.file_name).parent + agency_id = "test_agency_id" # agency_id = self.context["agency_id"] # TODO: pass agency_id to all tools + directory = DATA_DIR / agency_id / directory_path + + # Ensure the directory exists + directory.mkdir(parents=True, exist_ok=True) - # If the directory is not empty, check if it exists and create it if not - if directory and not os.path.exists(directory): - os.makedirs(directory) + # Construct the full path using the directory and file name + full_path = directory / Path(self.file_name).name # Write the file with open(full_path, "w") as f: diff --git a/src/nalgonda/data/default_config.json b/src/nalgonda/data/default_config.json index dac5b540..2e61317c 100644 --- a/src/nalgonda/data/default_config.json +++ b/src/nalgonda/data/default_config.json @@ -10,7 +10,8 @@ "tools": [ "CodeInterpreter", "BuildDirectoryTree", - "PrintAllFilesInDirectory" + "PrintAllFilesInDirectory", + "WriteAndSaveProgram" ] } ], diff --git a/src/nalgonda/database/__init__.py b/src/nalgonda/database/__init__.py index e69de29b..432d56de 100644 --- a/src/nalgonda/database/__init__.py +++ b/src/nalgonda/database/__init__.py @@ -0,0 +1,2 @@ +# __init__.py for the database module +# This file can be used to initialize database connections and define related utilities. diff --git a/src/nalgonda/exceptions.py b/src/nalgonda/exceptions.py index ddd37f4a..5908bc5e 100644 --- a/src/nalgonda/exceptions.py +++ b/src/nalgonda/exceptions.py @@ -1,2 +1,2 @@ -class AgencyNotFound(Exception): - pass +# exceptions.py - Custom exceptions for the Nalgonda project +# This file is reserved for custom exception classes. diff --git a/src/nalgonda/main.py b/src/nalgonda/main.py index f583db98..6db8a6e7 100644 --- a/src/nalgonda/main.py +++ b/src/nalgonda/main.py @@ -1,28 +1,25 @@ import logging from fastapi import FastAPI -from routers.v1 import v1_router from nalgonda.constants import DATA_DIR +from nalgonda.routers.v1 import v1_router -# Ensure directories exist -DATA_DIR.mkdir(exist_ok=True) - -app = FastAPI() +# Ensure data directory exists +DATA_DIR.mkdir(parents=True, exist_ok=True) +# Logging configuration logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - # logging.FileHandler(DATA_DIR / "logs.log"), - logging.StreamHandler(), - ], + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) +# FastAPI app initialization +app = FastAPI() app.include_router(v1_router) - if __name__ == "__main__": import uvicorn diff --git a/src/nalgonda/models/__init__.py b/src/nalgonda/models/__init__.py index e69de29b..ba35b251 100644 --- a/src/nalgonda/models/__init__.py +++ b/src/nalgonda/models/__init__.py @@ -0,0 +1,2 @@ +# __init__.py for the models module +# This file can be used to import model classes to make them easily accessible. diff --git a/src/nalgonda/models/agency_config.py b/src/nalgonda/models/agency_config.py new file mode 100644 index 00000000..ae1250f8 --- /dev/null +++ b/src/nalgonda/models/agency_config.py @@ -0,0 +1,45 @@ +import json +from pathlib import Path + +from agency_swarm import Agent +from pydantic import BaseModel, Field + +from nalgonda.agency_config_lock_manager import AgencyConfigLockManager +from nalgonda.constants import CONFIG_FILE_BASE, DEFAULT_CONFIG_FILE +from nalgonda.models.agent_config import AgentConfig + + +class AgencyConfig(BaseModel): + """Agency configuration model""" + + agency_id: str = Field(...) + agency_manifesto: str = Field(default="Agency Manifesto") + agents: list[AgentConfig] = Field(...) + agency_chart: list[str | list[str]] = Field(...) # contains agent roles + + def update_agent_ids_in_config(self, agents: list[Agent]) -> None: + """Update agent ids in config with the ids of the agents in the swarm""" + for agent in agents: + for agent_config in self.agents: + if agent.name == f"{agent_config.role}_{self.agency_id}": + agent_config.id = agent.id + + @classmethod + def load(cls, agency_id: str) -> "AgencyConfig": + """Load agency config from file""" + config_file_path = cls.get_config_path(agency_id) + if not config_file_path.is_file(): + config_file_path = DEFAULT_CONFIG_FILE + + with AgencyConfigLockManager.get_lock(agency_id), config_file_path.open() as file: + config = json.load(file) + config["agency_id"] = agency_id + return cls.model_validate(config) + + def save(self) -> None: + with AgencyConfigLockManager.get_lock(self.agency_id), self.get_config_path(self.agency_id).open("w") as file: + file.write(self.model_dump_json(indent=2)) + + @staticmethod + def get_config_path(agency_id: str) -> Path: + return CONFIG_FILE_BASE.with_name(f"config_{agency_id}.json") diff --git a/src/nalgonda/models/agent_config.py b/src/nalgonda/models/agent_config.py new file mode 100644 index 00000000..40fa7b07 --- /dev/null +++ b/src/nalgonda/models/agent_config.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, Field + + +class AgentConfig(BaseModel): + """Config for an agent""" + + id: str | None = None + role: str + description: str + instructions: str + files_folder: str | None = None + tools: list[str] = Field(default_factory=list) diff --git a/src/nalgonda/models/request_models.py b/src/nalgonda/models/request_models.py index 544fdb4a..2a1d15ff 100644 --- a/src/nalgonda/models/request_models.py +++ b/src/nalgonda/models/request_models.py @@ -1,7 +1,7 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class AgencyMessagePostRequest(BaseModel): - agency_id: str - message: str - thread_id: str | None = None + agency_id: str = Field(..., description="The unique identifier for the agency.") + message: str = Field(..., description="The message to be sent to the agency.") + thread_id: str | None = Field(None, description="The identifier for the conversational thread, if applicable.") diff --git a/src/nalgonda/routers/__init__.py b/src/nalgonda/routers/__init__.py index e69de29b..a87acd0e 100644 --- a/src/nalgonda/routers/__init__.py +++ b/src/nalgonda/routers/__init__.py @@ -0,0 +1,2 @@ +# __init__.py for the routers module +# This file is the entry point to the API routers, handling the routing of requests. diff --git a/src/nalgonda/routers/v1/__init__.py b/src/nalgonda/routers/v1/__init__.py index 573ec3a5..1d10a550 100644 --- a/src/nalgonda/routers/v1/__init__.py +++ b/src/nalgonda/routers/v1/__init__.py @@ -1,5 +1,3 @@ -# nalgonda/routers/v1/__init__.py - from fastapi import APIRouter from .api import api_router diff --git a/src/nalgonda/routers/v1/api/__init__.py b/src/nalgonda/routers/v1/api/__init__.py index 20873f98..8e1b889e 100644 --- a/src/nalgonda/routers/v1/api/__init__.py +++ b/src/nalgonda/routers/v1/api/__init__.py @@ -2,11 +2,11 @@ from fastapi import APIRouter -from .agency import agency_api_router +from .agency import agency_router api_router = APIRouter( prefix="/api", responses={404: {"description": "Not found"}}, ) -api_router.include_router(agency_api_router) +api_router.include_router(agency_router) diff --git a/src/nalgonda/routers/v1/api/agency.py b/src/nalgonda/routers/v1/api/agency.py index ad9685e6..1dfe7799 100644 --- a/src/nalgonda/routers/v1/api/agency.py +++ b/src/nalgonda/routers/v1/api/agency.py @@ -7,16 +7,14 @@ from nalgonda.models.request_models import AgencyMessagePostRequest logger = logging.getLogger(__name__) -agency_manager = AgencyManager() - -agency_api_router = APIRouter( - prefix="/agency", +agency_router = APIRouter( responses={404: {"description": "Not found"}}, ) +agency_manager = AgencyManager() -@agency_api_router.post("/") -async def create_agency(): +@agency_router.post("/agency") +async def create_agency() -> dict: """Create a new agency and return its id.""" # TODO: Add authentication: check if user is logged in and has permission to create an agency @@ -24,14 +22,14 @@ async def create_agency(): return {"agency_id": agency_id} -@agency_api_router.post("/message") -async def send_message(payload: AgencyMessagePostRequest) -> dict: +@agency_router.post("/agency/message") +async def post_agency_message(request: AgencyMessagePostRequest) -> dict: """Send a message to the CEO of the given agency.""" # TODO: Add authentication: check if agency_id is valid for the given user - user_message = payload.message - agency_id = payload.agency_id - thread_id = payload.thread_id + user_message = request.message + agency_id = request.agency_id + thread_id = request.thread_id logger.info(f"Received message: {user_message}, agency_id: {agency_id}, thread_id: {thread_id}") diff --git a/src/nalgonda/routers/v1/websocket.py b/src/nalgonda/routers/v1/websocket.py index cb17a04a..e3c955c7 100644 --- a/src/nalgonda/routers/v1/websocket.py +++ b/src/nalgonda/routers/v1/websocket.py @@ -5,31 +5,30 @@ from agency_swarm import Agency from agency_swarm.messages import MessageOutput from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from websockets import ConnectionClosedOK +from websockets.exceptions import ConnectionClosedOK from nalgonda.agency_manager import AgencyManager -from nalgonda.connection_manager import ConnectionManager +from nalgonda.websocket_connection_manager import WebSocketConnectionManager logger = logging.getLogger(__name__) -ws_manager = ConnectionManager() +connection_manager = WebSocketConnectionManager() agency_manager = AgencyManager() ws_router = APIRouter( - prefix="/ws", tags=["websocket"], responses={404: {"description": "Not found"}}, ) -@ws_router.websocket("/{agency_id}") +@ws_router.websocket("/ws/{agency_id}") async def websocket_initial_endpoint(websocket: WebSocket, agency_id: str): """WebSocket endpoint for initial connection.""" await base_websocket_endpoint(websocket, agency_id) -@ws_router.websocket("/{agency_id}/{thread_id}") +@ws_router.websocket("/ws/{agency_id}/{thread_id}") async def websocket_thread_endpoint(websocket: WebSocket, agency_id: str, thread_id: str): """WebSocket endpoint for maintaining conversation with a specific thread.""" - await base_websocket_endpoint(websocket, agency_id, thread_id) + await base_websocket_endpoint(websocket, agency_id, thread_id=thread_id) async def base_websocket_endpoint(websocket: WebSocket, agency_id: str, thread_id: str | None = None): @@ -38,7 +37,7 @@ async def base_websocket_endpoint(websocket: WebSocket, agency_id: str, thread_i # TODO: Add authentication: check if agency_id is valid for the given user - await ws_manager.connect(websocket) + await connection_manager.connect(websocket) logger.info(f"WebSocket connected for agency_id: {agency_id}, thread_id: {thread_id}") agency = await agency_manager.get_agency(agency_id, thread_id) @@ -46,39 +45,43 @@ async def base_websocket_endpoint(websocket: WebSocket, agency_id: str, thread_i # TODO: remove this once Redis is used for storing agencies: # the problem now is that cache is empty in the websocket thread agency, _ = await agency_manager.create_agency(agency_id) - # await ws_manager.send_message("Agency not found", websocket) - # await ws_manager.disconnect(websocket) + # await connection_manager.send_message("Agency not found", websocket) + # await connection_manager.disconnect(websocket) # await websocket.close() # return try: - while True: - try: - user_message = await websocket.receive_text() - - if not user_message.strip(): - await ws_manager.send_message("message not provided", websocket) - continue - - await process_ws_message(user_message, agency, websocket) - - new_thread_id = await agency_manager.refresh_thread_id(agency, agency_id, thread_id) - if new_thread_id is not None: - await ws_manager.send_message(json.dumps({"thread_id": new_thread_id}), websocket) - thread_id = new_thread_id - - except (WebSocketDisconnect, ConnectionClosedOK) as e: - raise e - except Exception as e: - logger.exception(e) - await ws_manager.send_message(f"Error: {e}\nPlease try again.", websocket) + await websocket_receive_and_process_messages(websocket, agency_id, agency, thread_id) + except (WebSocketDisconnect, ConnectionClosedOK): + await connection_manager.disconnect(websocket) + logger.info(f"WebSocket disconnected for agency_id: {agency_id}") + + +async def websocket_receive_and_process_messages( + websocket: WebSocket, agency_id: str, agency: Agency, thread_id: str | None +) -> None: + """Receive messages from the websocket and process them.""" + while True: + try: + user_message = await websocket.receive_text() + + if not user_message.strip(): + await connection_manager.send_message("message not provided", websocket) continue - except WebSocketDisconnect: - await ws_manager.disconnect(websocket) - logger.info(f"WebSocket disconnected for agency_id: {agency_id}") - except ConnectionClosedOK: - logger.info(f"WebSocket disconnected for agency_id: {agency_id}") + await process_ws_message(user_message, agency, websocket) + + new_thread_id = await agency_manager.refresh_thread_id(agency, agency_id, thread_id) + if new_thread_id is not None: + await connection_manager.send_message(json.dumps({"thread_id": new_thread_id}), websocket) + thread_id = new_thread_id + + except (WebSocketDisconnect, ConnectionClosedOK) as e: + raise e + except Exception: + logger.exception(f"Exception while processing message: agency_id: {agency_id}, thread_id: {thread_id}") + await connection_manager.send_message("Something went wrong. Please try again.", websocket) + continue async def process_ws_message(user_message: str, agency: Agency, websocket: WebSocket): @@ -99,4 +102,4 @@ def get_next() -> MessageOutput | None: break response_text = response.get_formatted_content() - await ws_manager.send_message(response_text, websocket) + await connection_manager.send_message(response_text, websocket) diff --git a/src/nalgonda/settings.py b/src/nalgonda/settings.py new file mode 100644 index 00000000..89f3e6b0 --- /dev/null +++ b/src/nalgonda/settings.py @@ -0,0 +1,14 @@ +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + +LATEST_GPT_MODEL = "gpt-4-1106-preview" + + +class Settings(BaseSettings): + # openai_api_key: str = Field(validation_alias="OPENAI_API_KEY") + gpt_model: str = Field(default=LATEST_GPT_MODEL, validation_alias="GPT_MODEL") + + model_config = SettingsConfigDict(env_file=".env", env_prefix="AINHAND_", case_sensitive=True) + + +settings = Settings() diff --git a/src/nalgonda/connection_manager.py b/src/nalgonda/websocket_connection_manager.py similarity index 71% rename from src/nalgonda/connection_manager.py rename to src/nalgonda/websocket_connection_manager.py index e41a0018..0c31b353 100644 --- a/src/nalgonda/connection_manager.py +++ b/src/nalgonda/websocket_connection_manager.py @@ -3,18 +3,18 @@ from starlette.websockets import WebSocket -class ConnectionManager: +class WebSocketConnectionManager: def __init__(self): self.active_connections: list[WebSocket] = [] - self.connections_lock = asyncio.Lock() + self._connections_lock = asyncio.Lock() async def connect(self, websocket: WebSocket): - await websocket.accept() - async with self.connections_lock: + async with self._connections_lock: + await websocket.accept() self.active_connections.append(websocket) async def disconnect(self, websocket: WebSocket): - async with self.connections_lock: + async with self._connections_lock: if websocket in self.active_connections: self.active_connections.remove(websocket) diff --git a/tests/custom_tools/test_utils.py b/tests/custom_tools/test_utils.py index fbb6ee56..0d3f39cb 100644 --- a/tests/custom_tools/test_utils.py +++ b/tests/custom_tools/test_utils.py @@ -1,11 +1,18 @@ +from pathlib import Path + import pytest from nalgonda.custom_tools.utils import check_directory_traversal +@pytest.mark.parametrize("path", [".", "tests", "tests/custom_tools"]) +def test_check_directory_traversal_does_not_raise_for_valid_paths(path): + check_directory_traversal(Path(path)) + + @pytest.mark.parametrize("path", ["..", "/", "/sbin"]) def test_check_directory_traversal_raises_for_attempts(path): with pytest.raises(ValueError) as e: - check_directory_traversal(path) + check_directory_traversal(Path(path)) assert e.errisinstance(ValueError) assert "Directory traversal is not allowed." in str(e.value)