Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exclude_directories parameter in tools; fix file_extensions parameter #31

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions backend/custom_skills/build_directory_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from backend.custom_skills.utils import check_directory_traversal

MAX_LENGTH = 3000
MAX_LENGTH = 10000


class DirectoryNode:
Expand All @@ -27,18 +27,20 @@ def __init__(self, path: Path, level: int):


class BuildDirectoryTree(BaseTool):
"""Print the structure of directories and files.
Directory traversal is not allowed (you cannot read /* or ../*).
"""
"""Print the structure of directories and files. Directory traversal is not allowed (you cannot read /* or ../*)."""

start_directory: Path = Field(
default_factory=Path.cwd,
description="The starting directory for the tree, defaults to the current working directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] = Field(
default_factory=list,
description="List of directories to exclude from the tree. Examples are ['__pycache__', '.git'].",
)

_validate_start_directory = field_validator("start_directory", mode="after")(check_directory_traversal)
Expand All @@ -53,7 +55,7 @@ def build_tree(self) -> DirectoryNode:
children = [p for p in current_node.path.iterdir() if not p.name.startswith(".")]

for child in children:
if child.is_dir():
if child.is_dir() and child.name not in self.exclude_directories:
dir_node = DirectoryNode(child, current_node.level + 1)
current_node.children.append(dir_node)
queue.append(dir_node)
Expand Down Expand Up @@ -100,6 +102,7 @@ def run(self) -> str:
if __name__ == "__main__":
print(
BuildDirectoryTree(
file_extensions={".py", ".md"},
file_extensions=[],
exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"],
).run()
)
37 changes: 28 additions & 9 deletions backend/custom_skills/print_all_files_in_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class PrintAllFilesInPath(BaseTool):
"""Print the contents of all files in a start_path recursively.
The parameters are: start_path, file_extensions.
The parameters are: start_path, file_extensions, exclude_directories.
Directory traversal is not allowed (you cannot read /* or ../*).
"""

Expand All @@ -17,10 +17,14 @@ class PrintAllFilesInPath(BaseTool):
description="The starting path to search for files, defaults to the current working directory. "
"Can be a filename or a directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] = Field(
default_factory=list,
description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].",
)
truncate_to: int = Field(
default=None,
Expand All @@ -41,26 +45,41 @@ def run(self) -> str:
return f"{str(start_path)}:\n```\n{read_file(start_path)}\n```\n"

for path in start_path.rglob("*"):
# ignore files in hidden directories
if any(part.startswith(".") for part in path.parts):
# ignore files in hidden directories or excluded directories
if any(part.startswith(".") for part in path.parts) or any(
part in self.exclude_directories for part in path.parts
):
continue

if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
output.append(f"{str(path)}:\n```\n{read_file(path)}\n```\n")

output_str = "\n".join(output)

if self.truncate_to and len(output_str) > self.truncate_to:
output_str = (
output_str[: self.truncate_to]
+ "\n\n... (truncated output, please use a smaller directory or apply a filter)"
)

return output_str


if __name__ == "__main__":
# list of extensions: ".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html"
print(
PrintAllFilesInPath(
start_path=".",
file_extensions={".py", ".json", ".yaml", ".yml", ".md", ".txt", ".tsx", ".ts", ".js", ".jsx", ".html"},
file_extensions=[],
exclude_directories=[
"frontend",
"__pycache__",
".git",
".idea",
"venv",
".vscode",
"node_modules",
"build",
"dist",
],
).run()
)
33 changes: 21 additions & 12 deletions backend/custom_skills/summarize_all_code_in_path.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
Expand All @@ -10,8 +11,7 @@
SYSTEM_MESSAGE = """\
Your main job is to handle programming code from SEVERAL FILES. \
Each file's content is shown within triple backticks and has a FILE PATH as a title. \
It's vital to KEEP the FILE PATHS.
Here's what to do:
It's vital to KEEP the FILE PATHS. Here's what to do:
1. ALWAYS KEEP the FILE PATHS for each file.
2. Start each file with a short SUMMARY of its content. Mention important points but don't repeat details found later.
3. KEEP important elements like non-trivial imports, function details, type hints, and key constants. \
Expand All @@ -29,7 +29,7 @@

class SummarizeAllCodeInPath(BaseTool):
"""Summarize code using GPT-3. The skill uses the `PrintAllFilesInPath` skill to get the code to summarize.
The parameters are: start_path, file_extensions.
The parameters are: start_path, file_extensions, exclude_directories.
Directory traversal is not allowed (you cannot read /* or ../*).
"""

Expand All @@ -38,23 +38,27 @@ class SummarizeAllCodeInPath(BaseTool):
description="The starting path to search for files, defaults to the current working directory. "
"Can be a filename or a directory.",
)
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included. "
"Examples are {'.py', '.txt', '.md'}.",
file_extensions: list[str] = Field(
default_factory=list,
description="List of file extensions to include in the tree. If empty, all files will be included. "
"Examples are ['.py', '.txt', '.md'].",
)
exclude_directories: list[str] | None = Field(
default_factory=list,
description="List of directories to exclude from the search. Examples are ['__pycache__', '.git'].",
)
truncate_to: int = Field(
default=None,
description="Truncate the output to this many characters. If None or skipped, the output is not truncated.",
)

def run(self) -> str:
def run(self, api_key: str | None = None) -> str:
"""Run the skill and return the output."""
delimiter = "\n\n```\n"

full_code = PrintAllFilesInPath(
start_path=self.start_path,
file_extensions=self.file_extensions,
exclude_directories=self.exclude_directories,
).run()

# Chunk the input based on token limit
Expand All @@ -63,7 +67,11 @@ def run(self) -> str:
outputs = []
for chunk in chunks:
output = get_chat_completion(
system_message=SYSTEM_MESSAGE, user_prompt=chunk, temperature=0.0, model=settings.gpt_cheap_model
system_message=SYSTEM_MESSAGE,
user_prompt=chunk,
temperature=0.0,
model=settings.gpt_small_model,
api_key=api_key,
)
outputs.append(output)

Expand All @@ -82,6 +90,7 @@ def run(self) -> str:
print(
SummarizeAllCodeInPath(
start_path=".",
file_extensions={".py"},
).run()
file_extensions=[".py"],
exclude_directories=["__pycache__", ".git", ".idea", "venv", ".vscode", "node_modules", "build", "dist"],
).run(api_key=os.getenv("API_KEY"))
)
11 changes: 8 additions & 3 deletions backend/custom_skills/summarize_code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
Expand Down Expand Up @@ -32,10 +33,14 @@ class SummarizeCode(BaseTool):
..., description="The name of the file to be summarized. It can be a relative or absolute path."
)

def run(self) -> str:
def run(self, api_key: str | None = None) -> str:
code = PrintFileContents(file_name=self.file_name).run()
output = get_chat_completion(
system_message=SYSTEM_MESSAGE, user_prompt=code, temperature=0.0, model=settings.gpt_cheap_model
system_message=SYSTEM_MESSAGE,
user_prompt=code,
temperature=0.0,
model=settings.gpt_small_model,
api_key=api_key,
)
return output

Expand All @@ -44,5 +49,5 @@ def run(self) -> str:
print(
SummarizeCode(
file_name="example.py",
).run()
).run(api_key=os.getenv("API_KEY"))
)
19 changes: 13 additions & 6 deletions backend/dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from backend.services.agency_manager import AgencyManager
from backend.services.agent_manager import AgentManager
from backend.services.caching.redis_cache_manager import RedisCacheManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.session_manager import SessionManager
from backend.settings import settings

Expand All @@ -22,25 +23,31 @@ def get_redis_cache_manager(redis: aioredis.Redis = Depends(get_redis)) -> Redis
return RedisCacheManager(redis)


def get_env_config_manager(
    env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
) -> EnvConfigManager:
    """Return an EnvConfigManager built on the injected env-config storage.

    Intended to be used as a dependency provider (other providers in this
    module request it via ``Depends(get_env_config_manager)``).
    """
    manager = EnvConfigManager(env_config_storage)
    return manager


def get_agent_manager(
storage: AgentConfigFirestoreStorage = Depends(AgentConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> AgentManager:
return AgentManager(storage, env_config_storage)
return AgentManager(storage, env_config_manager)


def get_agency_manager(
cache_manager: RedisCacheManager = Depends(get_redis_cache_manager),
agent_manager: AgentManager = Depends(get_agent_manager),
agency_config_storage: AgencyConfigFirestoreStorage = Depends(AgencyConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> AgencyManager:
return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_storage)
return AgencyManager(cache_manager, agent_manager, agency_config_storage, env_config_manager)


def get_session_manager(
session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
) -> SessionManager:
"""Returns a SessionManager object"""
return SessionManager(session_storage=session_storage, env_config_storage=env_config_storage)
return SessionManager(session_storage=session_storage, env_config_manager=env_config_manager)
11 changes: 2 additions & 9 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import json

import firebase_admin
import openai
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from firebase_admin import credentials
from pydantic import ValidationError
from starlette.staticfiles import StaticFiles

Expand All @@ -16,7 +12,7 @@
from backend.exception_handlers import bad_request_exception_handler, unhandled_exception_handler # noqa # isort:skip
from backend.routers.v1 import v1_router # noqa # isort:skip
from backend.settings import settings # noqa # isort:skip
from backend.utils import init_webserver_folders # noqa # isort:skip
from backend.utils import init_webserver_folders, init_firebase_app # noqa # isort:skip

# just a placeholder for compatibility with agency-swarm
openai.api_key = "sk-1234567890"
Expand Down Expand Up @@ -51,10 +47,7 @@
app.mount("/", StaticFiles(directory=folders["static_folder_root"], html=True), name="ui")

# Initialize FireStore
if settings.google_credentials:
cred_json = json.loads(settings.google_credentials)
cred = credentials.Certificate(cred_json)
firebase_admin.initialize_app(cred)
init_firebase_app()


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions backend/routers/v1/api/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from fastapi import APIRouter, Depends, HTTPException

from backend.dependencies.auth import get_current_user
from backend.dependencies.dependencies import get_agency_manager
from backend.dependencies.dependencies import get_agency_manager, get_env_config_manager
from backend.models.auth import User
from backend.models.request_models import SessionMessagePostRequest
from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage
from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage
from backend.repositories.session_firestore_storage import SessionConfigFirestoreStorage
from backend.services.agency_manager import AgencyManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.env_vars_manager import ContextEnvVarsManager
from backend.services.oai_client import get_openai_client

Expand All @@ -29,7 +29,7 @@ async def get_message_list(
session_id: str,
before: str | None = None,
session_storage: SessionConfigFirestoreStorage = Depends(SessionConfigFirestoreStorage),
env_config_storage: EnvConfigFirestoreStorage = Depends(EnvConfigFirestoreStorage),
env_config_manager: EnvConfigManager = Depends(get_env_config_manager),
):
"""Return a list of last 20 messages for the given session."""
# check if the current_user has permissions to send a message to the agency
Expand All @@ -45,7 +45,7 @@ async def get_message_list(
ContextEnvVarsManager.set("owner_id", current_user.id)

# use OpenAI's Assistants API to get the messages by thread_id=session_id
client = get_openai_client(env_config_storage)
client = get_openai_client(env_config_manager)
messages = client.beta.threads.messages.list(thread_id=session_id, limit=20, before=before)
return messages

Expand Down
8 changes: 4 additions & 4 deletions backend/services/agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from backend.models.agency_config import AgencyConfig
from backend.repositories.agency_config_firestore_storage import AgencyConfigFirestoreStorage
from backend.repositories.env_config_firestore_storage import EnvConfigFirestoreStorage
from backend.services.agent_manager import AgentManager
from backend.services.caching.redis_cache_manager import RedisCacheManager
from backend.services.env_config_manager import EnvConfigManager
from backend.services.oai_client import get_openai_client

logger = logging.getLogger(__name__)
Expand All @@ -20,12 +20,12 @@ def __init__(
cache_manager: RedisCacheManager,
agent_manager: AgentManager,
agency_config_storage: AgencyConfigFirestoreStorage,
env_config_storage: EnvConfigFirestoreStorage,
env_config_manager: EnvConfigManager,
) -> None:
self.agency_config_storage = agency_config_storage
self.agent_manager = agent_manager
self.cache_manager = cache_manager
self.env_config_storage = env_config_storage
self.env_config_manager = env_config_manager

async def get_agency(self, agency_id: str, session_id: str | None = None) -> Agency | None:
cache_key = self.get_cache_key(agency_id, session_id)
Expand Down Expand Up @@ -136,7 +136,7 @@ def _remove_client_objects(agency: Agency) -> Agency:

def _set_client_objects(self, agency: Agency) -> Agency:
"""Restore all client objects within the agency object"""
client = get_openai_client(env_config_storage=self.env_config_storage)
client = get_openai_client(env_config_manager=self.env_config_manager)
# Restore client for each agent in the agency
for agent in agency.agents:
agent.client = client
Expand Down
Loading
Loading