Skip to content

Commit

Permalink
Add exclude_directories parameter in tools; fix file_extensions parameter (#31)
Browse files Browse the repository at this point in the history

* Fix file_extensions parameter in tools; add exclude parameter
* Refactor
* Update the GPT models
  • Loading branch information
bonk1t authored Mar 26, 2024
1 parent 2083d95 commit eeb53c3
Show file tree
Hide file tree
Showing 24 changed files with 206 additions and 141 deletions.
23 changes: 13 additions & 10 deletions backend/custom_skills/build_directory_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from backend.custom_skills.utils import check_directory_traversal

MAX_LENGTH = 3000
MAX_LENGTH = 10000


class DirectoryNode:
Expand All @@ -27,18 +27,20 @@ def __init__(self, path: Path, level: int):


class BuildDirectoryTree(BaseTool):
"""Print the structure of directories and files.
Directory traversal is not allowed (you cannot read /* or ../*).
"""
"""Print the structure of directories and files. Directory traversal is not allowed (you cannot read /* or ../*)."""

start_directory: Path = Field(
default_factory=Path.cwd,
description="The starting directory for the tree, defaults to the current working directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] = Field(
default_factory=list,
description="List of directories to exclude from the tree. Examples are ['__pycache__', '.git'].",
)

_validate_start_directory = field_validator("start_directory", mode="after")(check_directory_traversal)
Expand All @@ -53,7 +55,7 @@ def build_tree(self) -> DirectoryNode:
children = [p for p in current_node.path.iterdir() if not p.name.startswith(".")]

for child in children:
if child.is_dir():
if child.is_dir() and child.name not in self.exclude_directories:
dir_node = DirectoryNode(child, current_node.level + 1)
current_node.children.append(dir_node)
queue.append(dir_node)
Expand Down Expand Up @@ -100,6 +102,7 @@ def run(self) -> str:
if __name__ == "__main__":
print(
BuildDirectoryTree(
file_extensions={".py", ".md"},
file_extensions=[],
exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"],
).run()
)
37 changes: 28 additions & 9 deletions backend/custom_skills/print_all_files_in_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class PrintAllFilesInPath(BaseTool):
"""Print the contents of all files in a start_path recursively.
The parameters are: start_path, file_extensions.
The parameters are: start_path, file_extensions, exclude_directories.
Directory traversal is not allowed (you cannot read /* or ../*).
"""

Expand All @@ -17,10 +17,14 @@ class PrintAllFilesInPath(BaseTool):
description="The starting path to search for files, defaults to the current working directory. "
"Can be a filename or a directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] = Field(
default_factory=list,
description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].",
)
truncate_to: int = Field(
default=None,
Expand All @@ -41,26 +45,41 @@ def run(self) -> str:
return f"{str(start_path)}:\n```\n{read_file(start_path)}\n```\n"

for path in start_path.rglob("*"):
# ignore files in hidden directories
if any(part.startswith(".") for part in path.parts):
# ignore files in hidden directories or excluded directories
if any(part.startswith(".") for part in path.parts) or any(
part in self.exclude_directories for part in path.parts
):
continue

if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
output.append(f"{str(path)}:\n```\n{read_file(path)}\n```\n")

output_str = "\n".join(output)

if self.truncate_to and len(output_str) > self.truncate_to:
output_str = (
output_str[: self.truncate_to]
+ "\n\n... (truncated output, please use a smaller directory or apply a filter)"
)

return output_str


if __name__ == "__main__":
# list of extensions: ".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html"
print(
PrintAllFilesInPath(
start_path=".",
file_extensions={".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html"},
file_extensions=[],
exclude_directories=[
"frontend",
"__pycache__",
".git",
".idea",
"venv",
".vscode",
"node_modules",
"build",
"dist",
],
).run()
)
33 changes: 21 additions & 12 deletions backend/custom_skills/summarize_all_code_in_path.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
Expand All @@ -10,8 +11,7 @@
SYSTEM_MESSAGE = """\
Your main job is to handle programming code from SEVERAL FILES. \
Each file's content is shown within triple backticks and has a FILE PATH as a title. \
It's vital to KEEP the FILE PATHS.
Here's what to do:
It's vital to KEEP the FILE PATHS. Here's what to do:
1. ALWAYS KEEP the FILE PATHS for each file.
2. Start each file with a short SUMMARY of its content. Mention important points but don't repeat details found later.
3. KEEP important elements like non-trivial imports, function details, type hints, and key constants. \
Expand All @@ -29,7 +29,7 @@

class SummarizeAllCodeInPath(BaseTool):
"""Summarize code using GPT-3. The skill uses the `PrintAllFilesInPath` skill to get the code to summarize.
The parameters are: start_path, file_extensions.
The parameters are: start_path, file_extensions, exclude_directories.
Directory traversal is not allowed (you cannot read /* or ../*).
"""

Expand All @@ -38,23 +38,27 @@ class SummarizeAllCodeInPath(BaseTool):
description="The starting path to search for files, defaults to the current working directory. "
"Can be a filename or a directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] | None = Field(
default_factory=list,
description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].",
)
truncate_to: int = Field(
default=None,
description="Truncate the output to this many characters. If None or skipped, the output is not truncated.",
)

def run(self) -> str:
def run(self, api_key: str | None = None) -> str:
"""Run the skill and return the output."""
delimiter = "\n\n```\n"

full_code = PrintAllFilesInPath(
start_path=self.start_path,
file_extensions=self.file_extensions,
exclude_directories=self.exclude_directories,
).run()

# Chunk the input based on token limit
Expand All @@ -63,7 +67,11 @@ def run(self) -> str:
outputs = []
for chunk in chunks:
output = get_chat_completion(
system_message=SYSTEM_MESSAGE, user_prompt=chunk, temperature=0.0, model=settings.gpt_cheap_model
system_message=SYSTEM_MESSAGE,
user_prompt=chunk,
temperature=0.0,
model=settings.gpt_small_model,
api_key=api_key,
)
outputs.append(output)

Expand All @@ -82,6 +90,7 @@ def run(self) -> str:
print(
SummarizeAllCodeInPath(
start_path=".",
file_extensions={".py"},
).run()
file_extensions=[".py"],
exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"],
).run(api_key=os.getenv("API_KEY"))
)
11 changes: 8 additions & 3 deletions backend/custom_skills/summarize_code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
Expand Down Expand Up @@ -32,10 +33,14 @@ class SummarizeCode(BaseTool):
..., description="The name of the file to be summarized. It can be a relative or absolute path."
)

def run(self) -> str:
def run(self, api_key: str | None = None) -> str:
code = PrintFileContents(file_name=self.file_name).run()
output = get_chat_completion(
system_message=SYSTEM_MESSAGE, user_prompt=code, temperature=0.0, model=settings.gpt_cheap_model
system_message=SYSTEM_MESSAGE,
user_prompt=code,
temperature=0.0,
model=settings.gpt_small_model,
api_key=api_key,
)
return output

Expand All @@ -44,5 +49,5 @@ def run(self) -> str:
print(
SummarizeCode(
file_name="example.py",
).run()
).run(api_key=os.getenv("API_KEY"))
)
19 changes: 13 additions & 6 deletions backend/dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from backend.services.agency_manager import AgencyManager
from backend.services.agent_manager import AgentManager
from backend.services.caching.redis_cache_manager import RedisCacheManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.session_manager import SessionManager
from backend.settings import settings

Expand All @@ -22,25 +23,31 @@ def get_redis_cache_manager(redis: aioredis.Redis = Depends(get_redis)) -> Redis
return RedisCacheManager(redis)


def get_env_config_manager(
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
) -> EnvConfigManager:
return EnvConfigManager(env_config_storage)


def get_agent_manager(
storage: AgentConfigFirestoreStorage = Depends(AgentConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> AgentManager:
return AgentManager(storage, env_config_storage)
return AgentManager(storage, env_config_manager)


def get_agency_manager(
cache_manager: RedisCacheManager = Depends(get_redis_cache_manager),
agent_manager: AgentManager = Depends(get_agent_manager),
agency_config_storage: AgencyConfigFirestoreStorage = Depends(AgencyConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> AgencyManager:
return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_storage)
return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_manager)


def get_session_manager(
session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> SessionManager:
"""Returns a SessionManager object"""
return SessionManager(session_storage=session_storage, env_config_storage=env_config_storage)
return SessionManager(session_storage=session_storage, env_config_manager=env_config_manager)
11 changes: 2 additions & 9 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import json

import firebase_admin
import openai
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from firebase_admin import credentials
from pydantic import ValidationError
from starlette.staticfiles import StaticFiles

Expand All @@ -16,7 +12,7 @@
from backend.exception_handlers import bad_request_exception_handler, unhandled_exception_handler # noqa # isort:skip
from backend.routers.v1 import v1_router # noqa # isort:skip
from backend.settings import settings # noqa # isort:skip
from backend.utils import init_webserver_folders # noqa # isort:skip
from backend.utils import init_webserver_folders, init_firebase_app # noqa # isort:skip

# just a placeholder for compatibility with agency-swarm
openai.api_key = "sk-1234567890"
Expand Down Expand Up @@ -51,10 +47,7 @@
app.mount("/", StaticFiles(directory=folders["static_folder_root"], html=True), name="ui")

# Initialize FireStore
if settings.google_credentials:
cred_json = json.loads(settings.google_credentials)
cred = credentials.Certificate(cred_json)
firebase_admin.initialize_app(cred)
init_firebase_app()


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions backend/routers/v1/api/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from fastapi import APIRouter, Depends, HTTPException

from backend.dependencies.auth import get_current_user
from backend.dependencies.dependencies import get_agency_manager
from backend.dependencies.dependencies import get_agency_manager, get_env_config_manager
from backend.models.auth import User
from backend.models.request_models import SessionMessagePostRequest
from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage
from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage
from backend.repositories.session_firestore_storage import SessionConfigFirestoreStorage
from backend.services.agency_manager import AgencyManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.env_vars_manager import ContextEnvVarsManager
from backend.services.oai_client import get_openai_client

Expand All @@ -29,7 +29,7 @@ async def get_message_list(
session_id: str,
before: str | None = None,
session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
):
"""Return a list of last 20 messages for the given session."""
# check if the current_user has permissions to send a message to the agency
Expand All @@ -45,7 +45,7 @@ async def get_message_list(
ContextEnvVarsManager.set("owner_id", current_user.id)

# use OpenAI's Assistants API to get the messages by thread_id=session_id
client = get_openai_client(env_config_storage)
client = get_openai_client(env_config_manager)
messages = client.beta.threads.messages.list(thread_id=session_id, limit=20, before=before)
return messages

Expand Down
8 changes: 4 additions & 4 deletions backend/services/agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from backend.models.agency_config import AgencyConfig
from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage
from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage
from backend.services.agent_manager import AgentManager
from backend.services.caching.redis_cache_manager import RedisCacheManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.oai_client import get_openai_client

logger = logging.getLogger(__name__)
Expand All @@ -20,12 +20,12 @@ def __init__(
cache_manager: RedisCacheManager,
agent_manager: AgentManager,
agency_config_storage: AgencyConfigFirestoreStorage,
env_config_storage: EnvConfigFirestoreStorage,
env_config_manager: EnvConfigManager,
) -> None:
self.agency_config_storage = agency_config_storage
self.agent_manager = agent_manager
self.cache_manager = cache_manager
self.env_config_storage = env_config_storage
self.env_config_manager = env_config_manager

async def get_agency(self, agency_id: str, session_id: str | None = None) -> Agency | None:
cache_key = self.get_cache_key(agency_id, session_id)
Expand Down Expand Up @@ -136,7 +136,7 @@ def _remove_client_objects(agency: Agency) -> Agency:

def _set_client_objects(self, agency: Agency) -> Agency:
"""Restore all client objects within the agency object"""
client = get_openai_client(env_config_storage=self.env_config_storage)
client = get_openai_client(env_config_manager=self.env_config_manager)
# Restore client for each agent in the agency
for agent in agency.agents:
agent.client = client
Expand Down
Loading

0 comments on commit eeb53c3

Please sign in to comment.