diff --git a/backend/custom_skills/build_directory_tree.py b/backend/custom_skills/build_directory_tree.py index 6524b232..861a3fd9 100644 --- a/backend/custom_skills/build_directory_tree.py +++ b/backend/custom_skills/build_directory_tree.py @@ -6,7 +6,7 @@ from backend.custom_skills.utils import check_directory_traversal -MAX_LENGTH = 3000 +MAX_LENGTH = 10000 class DirectoryNode: @@ -27,18 +27,20 @@ def __init__(self, path: Path, level: int): class BuildDirectoryTree(BaseTool): - """Print the structure of directories and files. - Directory traversal is not allowed (you cannot read /* or ../*). - """ + """Print the structure of directories and files. Directory traversal is not allowed (you cannot read /* or ../*).""" start_directory: Path = Field( default_factory=Path.cwd, description="The starting directory for the tree, defaults to the current working directory.", ) - file_extensions: set[str] = Field( - default_factory=set, - description="Set of file extensions to include in the tree. If empty, all files will be included. " - "Examples are {'.py', '.txt', '.md'}.", + file_extensions: list[str] = Field( + default_factory=list, + description="List of file extensions to include in the tree. If empty, all files will be included. " + "Examples are ['.py', '.txt', '.md'].", + ) + exclude_directories: list[str] = Field( + default_factory=list, + description="List of directories to exclude from the tree. Examples are ['__pycache__', '.git'].", ) _validate_start_directory = field_validator("start_directory", mode="after")(check_directory_traversal) @@ -53,7 +55,7 @@ def build_tree(self) -> DirectoryNode: children = [p for p in current_node.path.iterdir() if not p.name.startswith(".")] for child in children: - if child.is_dir(): + if child.is_dir() and child.name not in self.exclude_directories: dir_node = DirectoryNode(child, current_node.level + 1) current_node.children.append(dir_node) queue.append(dir_node) @@ -100,6 +102,7 @@ def run(self) -> str: if __name__ == "__main__": print( BuildDirectoryTree( - file_extensions={".py", ".md"}, + file_extensions=[], + exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"], ).run() ) diff --git a/backend/custom_skills/print_all_files_in_path.py b/backend/custom_skills/print_all_files_in_path.py index 9f3d2358..3782b2bd 100644 --- a/backend/custom_skills/print_all_files_in_path.py +++ b/backend/custom_skills/print_all_files_in_path.py @@ -8,7 +8,7 @@ class PrintAllFilesInPath(BaseTool): """Print the contents of all files in a start_path recursively. - The parameters are: start_path, file_extensions. + The parameters are: start_path, file_extensions, exclude_directories. Directory traversal is not allowed (you cannot read /* or ../*). """ @@ -17,10 +17,14 @@ class PrintAllFilesInPath(BaseTool): description="The starting path to search for files, defaults to the current working directory. " "Can be a filename or a directory.", ) - file_extensions: set[str] = Field( - default_factory=set, - description="Set of file extensions to include in the tree. If empty, all files will be included. " - "Examples are {'.py', '.txt', '.md'}.", + file_extensions: list[str] = Field( + default_factory=list, + description="List of file extensions to include in the tree. If empty, all files will be included. " + "Examples are ['.py', '.txt', '.md'].", + ) + exclude_directories: list[str] = Field( + default_factory=list, + description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].", ) truncate_to: int = Field( default=None, @@ -41,26 +45,41 @@ def run(self) -> str: return f"{str(start_path)}:\n```\n{read_file(start_path)}\n```\n" for path in start_path.rglob("*"): - # ignore files in hidden directories - if any(part.startswith(".") for part in path.parts): + # ignore files in hidden directories or excluded directories + if any(part.startswith(".") for part in path.parts) or any( + part in self.exclude_directories for part in path.parts + ): continue + if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): output.append(f"{str(path)}:\n```\n{read_file(path)}\n```\n") output_str = "\n".join(output) - if self.truncate_to and len(output_str) > self.truncate_to: output_str = ( output_str[: self.truncate_to] + "\n\n... (truncated output, please use a smaller directory or apply a filter)" ) + return output_str if __name__ == "__main__": + # list of extensions: ".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html" print( PrintAllFilesInPath( start_path=".", - file_extensions={".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html"}, + file_extensions=[], + exclude_directories=[ + "frontend", + "__pycache__", + ".git", + ".idea", + "venv", + ".vscode", + "node_modules", + "build", + "dist", + ], ).run() ) diff --git a/backend/custom_skills/summarize_all_code_in_path.py b/backend/custom_skills/summarize_all_code_in_path.py index c9c4386e..ede6106f 100644 --- a/backend/custom_skills/summarize_all_code_in_path.py +++ b/backend/custom_skills/summarize_all_code_in_path.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from agency_swarm import BaseTool @@ -10,8 +11,7 @@ SYSTEM_MESSAGE = """\ Your main job is to handle programming code from SEVERAL FILES. \ Each file's content is shown within triple backticks and has a FILE PATH as a title. \ -It's vital to KEEP the FILE PATHS. -Here's what to do: +It's vital to KEEP the FILE PATHS. Here's what to do: 1. ALWAYS KEEP the FILE PATHS for each file. 2. Start each file with a short SUMMARY of its content. Mention important points but don't repeat details found later. 3. KEEP important elements like non-trivial imports, function details, type hints, and key constants. \ @@ -29,7 +29,7 @@ class SummarizeAllCodeInPath(BaseTool): """Summarize code using GPT-3. The skill uses the `PrintAllFilesInPath` skill to get the code to summarize. - The parameters are: start_path, file_extensions. + The parameters are: start_path, file_extensions, exclude_directories. Directory traversal is not allowed (you cannot read /* or ../*). """ @@ -38,23 +38,27 @@ class SummarizeAllCodeInPath(BaseTool): description="The starting path to search for files, defaults to the current working directory. " "Can be a filename or a directory.", ) - file_extensions: set[str] = Field( - default_factory=set, - description="Set of file extensions to include in the tree. If empty, all files will be included. " - "Examples are {'.py', '.txt', '.md'}.", + file_extensions: list[str] = Field( + default_factory=list, + description="List of file extensions to include in the tree. If empty, all files will be included. " + "Examples are ['.py', '.txt', '.md'].", + ) + exclude_directories: list[str] | None = Field( + default_factory=list, + description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].", ) truncate_to: int = Field( default=None, description="Truncate the output to this many characters. If None or skipped, the output is not truncated.", ) - def run(self) -> str: + def run(self, api_key: str | None = None) -> str: """Run the skill and return the output.""" delimiter = "\n\n```\n" - full_code = PrintAllFilesInPath( start_path=self.start_path, file_extensions=self.file_extensions, + exclude_directories=self.exclude_directories, ).run() # Chunk the input based on token limit @@ -63,7 +67,11 @@ def run(self) -> str: outputs = [] for chunk in chunks: output = get_chat_completion( - system_message=SYSTEM_MESSAGE, user_prompt=chunk, temperature=0.0, model=settings.gpt_cheap_model + system_message=SYSTEM_MESSAGE, + user_prompt=chunk, + temperature=0.0, + model=settings.gpt_small_model, + api_key=api_key, ) outputs.append(output) @@ -82,6 +90,7 @@ def run(self) -> str: print( SummarizeAllCodeInPath( start_path=".", - file_extensions={".py"}, - ).run() + file_extensions=[".py"], + exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"], + ).run(api_key=os.getenv("API_KEY")) ) diff --git a/backend/custom_skills/summarize_code.py b/backend/custom_skills/summarize_code.py index 69963a67..ba0c3040 100644 --- a/backend/custom_skills/summarize_code.py +++ b/backend/custom_skills/summarize_code.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from agency_swarm import BaseTool @@ -32,10 +33,14 @@ class SummarizeCode(BaseTool): ..., description="The name of the file to be summarized. It can be a relative or absolute path." ) - def run(self) -> str: + def run(self, api_key: str | None = None) -> str: code = PrintFileContents(file_name=self.file_name).run() output = get_chat_completion( - system_message=SYSTEM_MESSAGE, user_prompt=code, temperature=0.0, model=settings.gpt_cheap_model + system_message=SYSTEM_MESSAGE, + user_prompt=code, + temperature=0.0, + model=settings.gpt_small_model, + api_key=api_key, ) return output @@ -44,5 +49,5 @@ def run(self) -> str: print( SummarizeCode( file_name="example.py", - ).run() + ).run(api_key=os.getenv("API_KEY")) ) diff --git a/backend/dependencies/dependencies.py b/backend/dependencies/dependencies.py index e3e61e20..8acb78b0 100644 --- a/backend/dependencies/dependencies.py +++ b/backend/dependencies/dependencies.py @@ -8,6 +8,7 @@ from backend.services.agency_manager import AgencyManager from backend.services.agent_manager import AgentManager from backend.services.caching.redis_cache_manager import RedisCacheManager +from backend.services.env_config_manager import EnvConfigManager from backend.services.session_manager import SessionManager from backend.settings import settings @@ -22,25 +23,31 @@ def get_redis_cache_manager(redis: aioredis.Redis = Depends(get_redis)) -> Redis return RedisCacheManager(redis) +def get_env_config_manager( + env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage), +) -> EnvConfigManager: + return EnvConfigManager(env_config_storage) + + def get_agent_manager( storage: AgentConfigFirestoreStorage = Depends(AgentConfigFirestoreStorage), - env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage), + env_config_manager: EnvConfigManager = Depends(get_env_config_manager), ) -> AgentManager: - return AgentManager(storage, env_config_storage) + return AgentManager(storage, env_config_manager) def get_agency_manager( cache_manager: RedisCacheManager = Depends(get_redis_cache_manager), agent_manager: AgentManager = Depends(get_agent_manager), agency_config_storage: AgencyConfigFirestoreStorage = Depends(AgencyConfigFirestoreStorage), - env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage), + env_config_manager: EnvConfigManager = Depends(get_env_config_manager), ) -> AgencyManager: - return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_storage) + return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_manager) def get_session_manager( session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage), - env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage), + env_config_manager: EnvConfigManager = Depends(get_env_config_manager), ) -> SessionManager: """Returns a SessionManager object""" - return SessionManager(session_storage=session_storage, env_config_storage=env_config_storage) + return SessionManager(session_storage=session_storage, env_config_manager=env_config_manager) diff --git a/backend/main.py b/backend/main.py index 86c8ab47..351aac4c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,10 +1,6 @@ -import json - -import firebase_admin import openai from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from firebase_admin import credentials from pydantic import ValidationError from starlette.staticfiles import StaticFiles @@ -16,7 +12,7 @@ from backend.exception_handlers import bad_request_exception_handler, unhandled_exception_handler # noqa # isort:skip from backend.routers.v1 import v1_router # noqa # isort:skip from backend.settings import settings # noqa # isort:skip -from backend.utils import init_webserver_folders # noqa # isort:skip +from backend.utils import init_webserver_folders, init_firebase_app # noqa # isort:skip # just a placeholder for compatibility with agency-swarm openai.api_key = "sk-1234567890" @@ -51,10 +47,7 @@ app.mount("/", StaticFiles(directory=folders["static_folder_root"], html=True), name="ui") # Initialize FireStore -if settings.google_credentials: - cred_json = json.loads(settings.google_credentials) - cred = credentials.Certificate(cred_json) - firebase_admin.initialize_app(cred) +init_firebase_app() if __name__ == "__main__": diff --git a/backend/routers/v1/api/message.py b/backend/routers/v1/api/message.py index cb781514..eab9c5fa 100644 --- a/backend/routers/v1/api/message.py +++ b/backend/routers/v1/api/message.py @@ -6,13 +6,13 @@ from fastapi import APIRouter, Depends, HTTPException from backend.dependencies.auth import get_current_user -from backend.dependencies.dependencies import get_agency_manager +from backend.dependencies.dependencies import get_agency_manager, get_env_config_manager from backend.models.auth import User from backend.models.request_models import SessionMessagePostRequest from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage -from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage from backend.repositories.session_firestore_storage import SessionConfigFirestoreStorage from backend.services.agency_manager import AgencyManager +from backend.services.env_config_manager import EnvConfigManager from backend.services.env_vars_manager import ContextEnvVarsManager from backend.services.oai_client import get_openai_client @@ -29,7 +29,7 @@ async def get_message_list( session_id: str, before: str | None = None, session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage), - env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage), + env_config_manager: EnvConfigManager = Depends(get_env_config_manager), ): """Return a list of last 20 messages for the given session.""" # check if the current_user has permissions to send a message to the agency @@ -45,7 +45,7 @@ async def get_message_list( ContextEnvVarsManager.set("owner_id", current_user.id) # use OpenAI's Assistants API to get the messages by thread_id=session_id - client = get_openai_client(env_config_storage) + client = get_openai_client(env_config_manager) messages = client.beta.threads.messages.list(thread_id=session_id, limit=20, before=before) return messages diff --git a/backend/services/agency_manager.py b/backend/services/agency_manager.py index 36d440fb..712f7edf 100644 --- a/backend/services/agency_manager.py +++ b/backend/services/agency_manager.py @@ -6,9 +6,9 @@ from backend.models.agency_config import AgencyConfig from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage -from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage from backend.services.agent_manager import AgentManager from backend.services.caching.redis_cache_manager import RedisCacheManager +from backend.services.env_config_manager import EnvConfigManager from backend.services.oai_client import get_openai_client logger = logging.getLogger(__name__) @@ -20,12 +20,12 @@ def __init__( cache_manager: RedisCacheManager, agent_manager: AgentManager, agency_config_storage: AgencyConfigFirestoreStorage, - env_config_storage: EnvConfigFirestoreStorage, + env_config_manager: EnvConfigManager, ) -> None: self.agency_config_storage = agency_config_storage self.agent_manager = agent_manager self.cache_manager = cache_manager - self.env_config_storage = env_config_storage + self.env_config_manager = env_config_manager async def get_agency(self, agency_id: str, session_id: str | None = None) -> Agency | None: cache_key = self.get_cache_key(agency_id, session_id) @@ -136,7 +136,7 @@ def _remove_client_objects(agency: Agency) -> Agency: def _set_client_objects(self, agency: Agency) -> Agency: """Restore all client objects within the agency object""" - client = get_openai_client(env_config_storage=self.env_config_storage) + client = get_openai_client(env_config_manager=self.env_config_manager) # Restore client for each agent in the agency for agent in agency.agents: agent.client = client diff --git a/backend/services/agent_manager.py b/backend/services/agent_manager.py index 6930ceb1..f751f7fc 100644 --- a/backend/services/agent_manager.py +++ b/backend/services/agent_manager.py @@ -6,7 +6,7 @@ from backend.custom_skills import SKILL_MAPPING from backend.models.agent_config import AgentConfig from backend.repositories.agent_config_firestore_storage import AgentConfigFirestoreStorage -from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage +from backend.services.env_config_manager import EnvConfigManager from backend.services.oai_client import get_openai_client from backend.settings import settings @@ -14,8 +14,8 @@ class AgentManager: - def __init__(self, storage: AgentConfigFirestoreStorage, env_config_storage: EnvConfigFirestoreStorage) -> None: - self.env_config_storage = env_config_storage + def __init__(self, storage: AgentConfigFirestoreStorage, env_config_manager: EnvConfigManager) -> None: + self.env_config_manager = env_config_manager self.storage = storage async def create_or_update_agent(self, config: AgentConfig) -> str: @@ -60,5 +60,5 @@ def _construct_agent(self, agent_config: AgentConfig) -> Agent: model=settings.gpt_model, ) # a workaround: agent.client must be replaced with a proper implementation - agent.client = get_openai_client(env_config_storage=self.env_config_storage) + agent.client = get_openai_client(env_config_manager=self.env_config_manager) return agent diff --git a/backend/services/oai_client.py b/backend/services/oai_client.py index f7965b95..fe64f0ba 100644 --- a/backend/services/oai_client.py +++ b/backend/services/oai_client.py @@ -1,10 +1,12 @@ import instructor import openai -from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage from backend.services.env_config_manager import EnvConfigManager -def get_openai_client(env_config_storage: EnvConfigFirestoreStorage): - api_key = EnvConfigManager(env_config_storage).get_by_key("OPENAI_API_KEY") +def get_openai_client(env_config_manager: EnvConfigManager | None = None, api_key: str | None = None) -> openai.OpenAI: + if not api_key: + if not env_config_manager: + raise ValueError("Either env_config_manager or api_key must be provided") + api_key = env_config_manager.get_by_key("OPENAI_API_KEY") return instructor.patch(openai.OpenAI(api_key=api_key, max_retries=5)) diff --git a/backend/services/session_manager.py b/backend/services/session_manager.py index 1785c3d3..878418f0 100644 --- a/backend/services/session_manager.py +++ b/backend/services/session_manager.py @@ -4,14 +4,14 @@ from agency_swarm.threads import Thread from backend.models.session_config import SessionConfig -from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage from backend.repositories.session_firestore_storage import SessionConfigFirestoreStorage +from backend.services.env_config_manager import EnvConfigManager from backend.services.oai_client import get_openai_client class SessionManager: - def __init__(self, session_storage: SessionConfigFirestoreStorage, env_config_storage: EnvConfigFirestoreStorage): - self.env_config_storage = env_config_storage + def __init__(self, session_storage: SessionConfigFirestoreStorage, env_config_manager: EnvConfigManager): + self.env_config_manager = env_config_manager self.session_storage = session_storage def create_session(self, agency: Agency, agency_id: str, owner_id: str) -> str: @@ -28,7 +28,7 @@ def create_session(self, agency: Agency, agency_id: str, owner_id: str) -> str: def _create_threads(self, agency: Agency) -> str: """Create new threads for the given agency and return the thread ID of the main thread.""" - client = get_openai_client(self.env_config_storage) + client = get_openai_client(self.env_config_manager) self._init_threads(agency, client) return agency.main_thread.id diff --git a/backend/services/skill_service.py b/backend/services/skill_service.py index e802a98a..6be7c21a 100644 --- a/backend/services/skill_service.py +++ b/backend/services/skill_service.py @@ -21,7 +21,7 @@ def generate_skill_description(code: str): system_message=SKILL_SUMMARY_SYSTEM_MESSAGE, user_prompt=f"{USER_PROMPT}```\n{code}\n```", temperature=0.0, - model=settings.gpt_cheap_model, + model=settings.gpt_small_model, ) return summary diff --git a/backend/settings.py b/backend/settings.py index 75db83ff..8bcc18c2 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,8 +1,8 @@ from pydantic import Field, RedisDsn from pydantic_settings import BaseSettings, SettingsConfigDict -LATEST_GPT_MODEL = "gpt-4-turbo-preview" -CHEAP_GPT_MODEL = "gpt-3.5-turbo-1106" +LARGE_GPT_MODEL = "gpt-4-turbo-preview" +SMALL_GPT_MODEL = "gpt-3.5-turbo" class Settings(BaseSettings): @@ -10,8 +10,8 @@ class Settings(BaseSettings): access_token_expire_minutes: int = Field(default=30) google_credentials: str | None = Field(default=None) - gpt_model: str = Field(default=LATEST_GPT_MODEL) - gpt_cheap_model: str = Field(default=CHEAP_GPT_MODEL) + gpt_model: str = Field(default=LARGE_GPT_MODEL) + gpt_small_model: str = Field(default=SMALL_GPT_MODEL) redis_tls_url: RedisDsn | None = Field(default=None) redis_url: RedisDsn = Field(default="redis://localhost:6379/1") secret_key: str = Field(default="") diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index 0ded762f..ea7c1fc2 100644 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -1,14 +1,27 @@ +import json import logging from pathlib import Path +import firebase_admin import tiktoken +from firebase_admin import credentials from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage +from backend.services.env_config_manager import EnvConfigManager from backend.services.oai_client import get_openai_client +from backend.settings import settings logger = logging.getLogger(__name__) +def init_firebase_app(): + """Initialize Firebase app.""" + if settings.google_credentials: + cred_json = json.loads(settings.google_credentials) + cred = credentials.Certificate(cred_json) + firebase_admin.initialize_app(cred) + + def init_webserver_folders(root_file_path: Path) -> dict[str, Path]: """ Initialize folders needed for a web server, such as static file directories @@ -25,10 +38,13 @@ def init_webserver_folders(root_file_path: Path) -> dict[str, Path]: return folders -def get_chat_completion(system_message: str, user_prompt: str, model: str, **kwargs) -> str: +def get_chat_completion(system_message: str, user_prompt: str, model: str, api_key: str | None = None, **kwargs) -> str: """Generate a chat completion based on a prompt and a system message. This function is a wrapper around the OpenAI API.""" - client = get_openai_client(env_config_storage=EnvConfigFirestoreStorage()) + if api_key: + client = get_openai_client(api_key=api_key) + else: + client = get_openai_client(env_config_manager=EnvConfigManager(EnvConfigFirestoreStorage())) completion = client.chat.completions.create( model=model, messages=[ diff --git a/frontend/src/components/utils.ts b/frontend/src/components/utils.ts index 1cb06a2f..248a86f6 100644 --- a/frontend/src/components/utils.ts +++ b/frontend/src/components/utils.ts @@ -249,7 +249,7 @@ export const formatDuration = (seconds: number) => { export const sampleWorkflowConfig = (type = "twoagents") => { const llm_model_config: IModelConfig[] = [ { - model: "gpt-4-1106-preview", + model: "gpt-4-turbo-preview", }, ]; @@ -337,10 +337,10 @@ export const sampleWorkflowConfig = (type = "twoagents") => { export const getModels = () => { const models = [ { - model: "gpt-4-1106-preview", + model: "gpt-4-turbo-preview", }, { - model: "gpt-3.5-turbo-16k", + model: "gpt-3.5-turbo", }, { model: "TheBloke/zephyr-7B-alpha-AWQ", diff --git a/frontend/src/components/views/builder/agents.tsx b/frontend/src/components/views/builder/agents.tsx index 7377856d..7d10c392 100644 --- a/frontend/src/components/views/builder/agents.tsx +++ b/frontend/src/components/views/builder/agents.tsx @@ -47,7 +47,7 @@ const AgentsView = ({}: any) => { llm_config: { config_list: [ { - model: "gpt-4-1106-preview", + model: "gpt-4-turbo-preview", }, ], temperature: 0.1, diff --git a/frontend/src/components/views/builder/models.tsx b/frontend/src/components/views/builder/models.tsx index d3e2f5b5..b5f0ebda 100644 --- a/frontend/src/components/views/builder/models.tsx +++ b/frontend/src/components/views/builder/models.tsx @@ -31,7 +31,7 @@ const ModelsView = ({}: any) => { }); const defaultModel: IModelConfig = { - model: "gpt-4-1106-preview", + model: "gpt-4-turbo-preview", description: "Sample OpenAI GPT-4 model", user_id: user?.email, }; diff --git a/tests/unit/custom_skills/test_build_directory_tree.py b/tests/unit/custom_skills/test_build_directory_tree.py index dbf17b67..1e359c41 100644 --- a/tests/unit/custom_skills/test_build_directory_tree.py +++ b/tests/unit/custom_skills/test_build_directory_tree.py @@ -8,16 +8,8 @@ def test_build_directory_tree_with_py_extension(temp_dir): Test if BuildDirectoryTree correctly lists only .py files in the directory tree. Sorted output without indentation is expected. """ - bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py"}) - expected_output = "\n".join( - sorted( - [ - "/sub/test.py", - "/sub", - "", - ] - ) - ) + bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions=[".py"], exclude_directories=["__pycache__"]) + expected_output = "\n".join(sorted(["/sub/test.py", "/sub", ""])) assert bdt.run() == expected_output @@ -26,17 +18,10 @@ def test_build_directory_tree_with_multiple_extensions(temp_dir): Test if BuildDirectoryTree lists files with multiple specified extensions. Sorted output without indentation is expected. """ - bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py", ".txt"}) - expected_output = "\n".join( - sorted( - [ - "/sub/test.py", - "/sub/test.txt", - "/sub", - "", - ] - ) + bdt = BuildDirectoryTree( + start_directory=temp_dir, file_extensions=[".py", ".txt"], exclude_directories=["__pycache__"] ) + expected_output = "\n".join(sorted(["/sub/test.py", "/sub/test.txt", "/sub", ""])) actual_output = bdt.run() assert actual_output == expected_output @@ -47,7 +32,8 @@ def test_build_directory_tree_default_settings(): """ bdt = BuildDirectoryTree() assert bdt.start_directory == Path.cwd() - assert bdt.file_extensions == set() + assert bdt.file_extensions == [] + assert bdt.exclude_directories == [] def test_build_directory_tree_output_length_limit(temp_dir): @@ -57,7 +43,20 @@ def test_build_directory_tree_output_length_limit(temp_dir): # Create a large number of files to exceed the limit for i in range(180): (temp_dir / f"file_{i}.txt").write_text("Dummy content") - - bdt = BuildDirectoryTree(start_directory=temp_dir) + bdt = BuildDirectoryTree(start_directory=temp_dir, exclude_directories=["__pycache__"]) output = bdt.run() assert len(output) <= 3000 # Adjusted to match the MAX_LENGTH constant + + +def test_build_directory_tree_exclude_directories(temp_dir): + """ + Test if BuildDirectoryTree correctly excludes specified directories. + """ + # Create a directory to be excluded + excluded_dir = temp_dir / "excluded_dir" + excluded_dir.mkdir() + (excluded_dir / "excluded_file.txt").write_text("Excluded content") + + bdt = BuildDirectoryTree(start_directory=temp_dir, exclude_directories=["excluded_dir"]) + output = bdt.run() + assert "excluded_dir" not in output diff --git a/tests/unit/custom_skills/test_print_all_files_in_path.py b/tests/unit/custom_skills/test_print_all_files_in_path.py index dfafb6a4..1a5d5e28 100644 --- a/tests/unit/custom_skills/test_print_all_files_in_path.py +++ b/tests/unit/custom_skills/test_print_all_files_in_path.py @@ -7,7 +7,7 @@ def test_print_all_files_no_extension_filter(temp_dir): """ Test if PrintAllFilesInPath correctly prints contents of all files when no file extension filter is applied. """ - pafid = PrintAllFilesInPath(start_path=temp_dir) + pafid = PrintAllFilesInPath(start_path=temp_dir, exclude_directories=["__pycache__"]) expected_output = { f"{temp_dir}/sub/test.py:\n```\nprint('hello')\n```", f"{temp_dir}/sub/test.txt:\n```\nhello world\n```", @@ -20,7 +20,7 @@ def test_print_all_files_with_py_extension(temp_dir): """ Test if PrintAllFilesInPath correctly prints contents of .py files only. """ - pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions={".py"}) + pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions=[".py"], exclude_directories=["__pycache__"]) expected_output = f"{temp_dir.joinpath('sub', 'test.py')}:\n```\nprint('hello')\n```\n" assert pafid.run() == expected_output @@ -29,7 +29,7 @@ def test_print_all_files_with_txt_extension(temp_dir): """ Test if PrintAllFilesInPath correctly prints contents of .txt files only. """ - pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions={".txt"}) + pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions=[".txt"], exclude_directories=["__pycache__"]) expected_output = f"{temp_dir.joinpath('sub', 'test.txt')}:\n```\nhello world\n```\n" assert pafid.run() == expected_output @@ -42,16 +42,14 @@ def test_print_all_files_error_reading_file(temp_dir): unreadable_file = temp_dir.joinpath("unreadable_file.txt") unreadable_file.write_text("content") unreadable_file.chmod(0o000) # make the file unreadable - - pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions={".txt"}) + pafid = PrintAllFilesInPath(start_path=temp_dir, file_extensions=[".txt"], exclude_directories=["__pycache__"]) assert "Error reading file" in pafid.run() - unreadable_file.chmod(0o644) # reset file permissions for cleanup @pytest.mark.parametrize("extension, expected_file", [(".py", "test.py"), (".txt", "test.txt")]) def test_print_all_files_with_extension_filter(temp_dir, extension, expected_file): - pafip = PrintAllFilesInPath(start_path=temp_dir, file_extensions={extension}) + pafip = PrintAllFilesInPath(start_path=temp_dir, file_extensions=[extension], exclude_directories=["__pycache__"]) expected_output = ( f"{temp_dir.joinpath('sub', expected_file)}:\n```\n" + temp_dir.joinpath("sub", expected_file).read_text() @@ -69,7 +67,23 @@ def create_file_in_path(tmp_path): def test_print_file_contents(create_file_in_path): - skill = PrintAllFilesInPath(start_path=create_file_in_path, file_extensions=[".txt"]) + skill = PrintAllFilesInPath( + start_path=create_file_in_path, file_extensions=[".txt"], exclude_directories=["__pycache__"] + ) result = skill.run() expected_result = f"{str(create_file_in_path)}:\n```\nFile content\n```\n" assert result == expected_result + + +def test_print_all_files_exclude_directories(temp_dir): + """ + Test if PrintAllFilesInPath correctly excludes specified directories. + """ + # Create a directory to be excluded + excluded_dir = temp_dir / "excluded_dir" + excluded_dir.mkdir() + (excluded_dir / "excluded_file.txt").write_text("Excluded content") + + pafid = PrintAllFilesInPath(start_path=temp_dir, exclude_directories=["excluded_dir"]) + output = pafid.run() + assert "excluded_dir" not in output diff --git a/tests/unit/custom_skills/test_summarize_all_code_in_path.py b/tests/unit/custom_skills/test_summarize_all_code_in_path.py index 6cc7a5ed..b6356720 100644 --- a/tests/unit/custom_skills/test_summarize_all_code_in_path.py +++ b/tests/unit/custom_skills/test_summarize_all_code_in_path.py @@ -22,8 +22,7 @@ def test_summarize_all_code_in_path_with_valid_codebase(mock_openai_client, mock # Create a simple Python file (tmp_path / "test.py").write_text('print("Hello, World!")') mock_openai_client.return_value.chat.completions.create.return_value = mock_openai_response - - summarize_skill = SummarizeAllCodeInPath(start_path=Path(tmp_path)) + summarize_skill = SummarizeAllCodeInPath(start_path=Path(tmp_path), exclude_directories=["__pycache__"]) results = summarize_skill.run() assert "Summary of the code" in results mock_openai_client.assert_called_once() @@ -33,8 +32,24 @@ def test_summarize_all_code_in_path_with_valid_codebase(mock_openai_client, mock def test_summarize_all_code_in_path_with_api_failure(mock_openai_client, tmp_path): # Create a simple Python file (tmp_path / "test.py").write_text('print("Hello, World!")') - summarize_skill = SummarizeAllCodeInPath(start_path=Path(tmp_path)) + summarize_skill = SummarizeAllCodeInPath(start_path=Path(tmp_path), exclude_directories=["__pycache__"]) with pytest.raises(Exception) as exc_info: summarize_skill.run() assert "API failed" in str(exc_info.value) mock_openai_client.assert_called_once() + + +def test_summarize_all_code_in_path_exclude_directories(tmp_path, mock_openai_response): + """ + Test if SummarizeAllCodeInPath correctly excludes specified directories. + """ + # Create a directory to be excluded + excluded_dir = tmp_path / "excluded_dir" + excluded_dir.mkdir() + (excluded_dir / "excluded_file.py").write_text('print("Excluded code")') + + summarize_skill = SummarizeAllCodeInPath(start_path=Path(tmp_path), exclude_directories=["excluded_dir"]) + with patch("backend.utils.get_openai_client") as mock_openai_client: + mock_openai_client.return_value.chat.completions.create.return_value = mock_openai_response + output = summarize_skill.run() + assert "excluded_dir" not in output diff --git a/tests/unit/services/test_agency_manager.py b/tests/unit/services/test_agency_manager.py index 864f8044..7c0aefab 100644 --- a/tests/unit/services/test_agency_manager.py +++ b/tests/unit/services/test_agency_manager.py @@ -3,6 +3,7 @@ import pytest from agency_swarm import Agency, Agent +from backend.dependencies.dependencies import get_env_config_manager from backend.models.agency_config import AgencyConfig from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage @@ -18,7 +19,7 @@ def agency_manager(): cache_manager=MagicMock(), agent_manager=MagicMock(), agency_config_storage=AgencyConfigFirestoreStorage(), - env_config_storage=EnvConfigFirestoreStorage(), + env_config_manager=get_env_config_manager(env_config_storage=EnvConfigFirestoreStorage()), ) diff --git a/tests/unit/services/test_oai_client.py b/tests/unit/services/test_oai_client.py index 3caf0a2d..4b7da0ab 100644 --- a/tests/unit/services/test_oai_client.py +++ b/tests/unit/services/test_oai_client.py @@ -3,17 +3,6 @@ import pytest -@pytest.fixture -def mock_env_config_firestore_storage(): - return MagicMock(name="EnvConfigFirestoreStorage") - - -@pytest.fixture -def mock_env_config_manager(recover_oai_client): # noqa: ARG001 - with patch("backend.services.oai_client.EnvConfigManager") as mock: - yield mock - - @pytest.fixture def mock_openai_client(recover_oai_client): # noqa: ARG001 with patch("backend.services.oai_client.openai.OpenAI") as mock: @@ -27,21 +16,19 @@ def mock_instructor_patch(recover_oai_client): # noqa: ARG001 yield mock -def test_get_openai_client_uses_correct_api_key( - mock_env_config_manager, mock_openai_client, mock_env_config_firestore_storage, mock_instructor_patch -): +def test_get_openai_client_uses_correct_api_key(mock_openai_client, mock_instructor_patch): # Setup from backend.services.oai_client import get_openai_client expected_api_key = "test_api_key" - mock_env_config_manager.return_value.get_by_key.return_value = expected_api_key + mock_env_config_manager = MagicMock() + mock_env_config_manager.get_by_key.return_value = expected_api_key # Execute - client = get_openai_client(mock_env_config_firestore_storage) + client = get_openai_client(mock_env_config_manager) # Verify - mock_env_config_manager.assert_called_with(mock_env_config_firestore_storage) - mock_env_config_manager.return_value.get_by_key.assert_called_with("OPENAI_API_KEY") + mock_env_config_manager.get_by_key.assert_called_with("OPENAI_API_KEY") mock_openai_client.assert_called_with(api_key=expected_api_key, max_retries=5) mock_instructor_patch.assert_called_once() assert client == mock_instructor_patch.return_value, "The function should return a patched OpenAI client" diff --git a/tests/unit/services/test_session_manager.py b/tests/unit/services/test_session_manager.py index 2a00a761..5d25d126 100644 --- a/tests/unit/services/test_session_manager.py +++ b/tests/unit/services/test_session_manager.py @@ -15,19 +15,14 @@ def agency_mock(): return agency -@pytest.fixture -def env_config_storage_mock(): - return MagicMock() - - @pytest.fixture def session_storage_mock(): return MagicMock() @pytest.fixture -def session_manager(env_config_storage_mock, session_storage_mock): - return SessionManager(env_config_storage=env_config_storage_mock, session_storage=session_storage_mock) +def session_manager(session_storage_mock): + return SessionManager(env_config_manager=MagicMock(), session_storage=session_storage_mock) @pytest.fixture diff --git a/tests/unit/services/test_skill_service.py b/tests/unit/services/test_skill_service.py index 780bd908..652f2be0 100644 --- a/tests/unit/services/test_skill_service.py +++ b/tests/unit/services/test_skill_service.py @@ -20,7 +20,7 @@ def test_generate_skill_description(): + code + "\n```", temperature=0.0, - model="gpt-3.5-turbo-1106", + model="gpt-3.5-turbo", ) assert result == "Summary of the skill", "The function did not return the expected summary"