Skip to content

Commit

Permalink
Update abstra-lib
Browse files Browse the repository at this point in the history
  • Loading branch information
abstra-bot committed Feb 5, 2025
1 parent da689ab commit f1ff4ad
Show file tree
Hide file tree
Showing 283 changed files with 1,998 additions and 871 deletions.
18 changes: 18 additions & 0 deletions abstra/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List, Literal

from abstra_internals.controllers.sdk_context import SDKContextStore
from abstra_internals.entities.agents import ConnectionModel


def get_connections(role: Literal["client", "agent"]) -> List[ConnectionModel]:
if role == "client":
return (
SDKContextStore.get_by_thread().repositories.role_clients.get_connections()
)
else:
return (
SDKContextStore.get_by_thread().repositories.role_agents.get_connections()
)


__all__ = ["get_connections"]
5 changes: 5 additions & 0 deletions abstra/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import fire

from abstra_internals.interface.cli.agents import add_agent
from abstra_internals.interface.cli.deploy import deploy
from abstra_internals.interface.cli.dir import select_dir
from abstra_internals.interface.cli.editor import editor
Expand Down Expand Up @@ -42,6 +43,10 @@ def restore(self, root_dir: str = "."):
SettingsController.set_root_path(root_dir)
restore()

def add_agent(self, agent_id: str, agent_name: str, root_dir: Optional[str] = None):
SettingsController.set_root_path(root_dir or select_dir())
add_agent(agent_project_id=agent_id, agent_title=agent_name)

def start(self, root_dir: Optional[str] = None, token: Optional[str] = None):
SettingsController.set_root_path(root_dir or select_dir())
start(token)
Expand Down
4 changes: 4 additions & 0 deletions abstra_internals/cloud/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abstra_internals.cloud.server_application import CustomApplication
from abstra_internals.cloud.server_hooks import GunicornOptionsBuilder
from abstra_internals.controllers.main import MainController
from abstra_internals.controllers.service.roles.client import RoleClientController
from abstra_internals.environment import DEFAULT_PORT
from abstra_internals.logger import AbstraLogger
from abstra_internals.repositories.factory import get_prodution_app_repositories
Expand All @@ -17,6 +18,9 @@ def run():
controller = MainController(repositories=get_prodution_app_repositories())
StdioPatcher.apply(controller)

role_client_controller = RoleClientController(controller.repositories)
role_client_controller.sync_connection_pool()

options = GunicornOptionsBuilder(controller).build()
app = get_cloud_app(controller)

Expand Down
45 changes: 36 additions & 9 deletions abstra_internals/cloud_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
CloudApiCliBuildCreateResponse,
)
from abstra_internals.credentials import resolve_headers
from abstra_internals.environment import CLOUD_API_CLI_URL, CLOUD_API_ENDPOINT, HOST
from abstra_internals.environment import (
CLOUD_API_CLI_URL,
CLOUD_API_ENDPOINT,
HOST,
IS_PRODUCTION,
PROJECT_ID,
)
from abstra_internals.logger import AbstraLogger
from abstra_internals.settings import Settings

Expand Down Expand Up @@ -82,18 +88,32 @@ def cancel_all(headers: dict, thread_id: str):
r.raise_for_status()


def get_project_info(headers: dict):
url = f"{CLOUD_API_CLI_URL}/project"
def get_project_info(headers: dict, project_id: Optional[str] = None):
if project_id:
url = f"{CLOUD_API_CLI_URL}/project/{project_id}"
else:
url = f"{CLOUD_API_CLI_URL}/project"
r = requests.get(url, headers=headers)
r.raise_for_status()
return r.json()


def get_project_id():
if IS_PRODUCTION:
return PROJECT_ID

else:
headers = resolve_headers()
if headers is None:
return None
return get_project_info(headers)["id"]


class TunnelRequest(BaseModel):
method: str
path: str
headers: dict
body: Optional[dict]
body: Optional[str]
query: dict
sessionPath: str
requestId: str
Expand Down Expand Up @@ -139,13 +159,15 @@ def loop():
kwargs: Any = dict(
headers=request.headers,
params=request.query,
**dict(data=json.dumps(request.body) if request.body else {}),
**dict(data=request.body if request.body else {}),
)
if not request.path.startswith("/_hooks/"):
if not request.path.startswith(
"/_hooks/"
) and not request.path.startswith("/_tasks"):
response = TunnelResponse(
status=403,
headers={},
text="Forbidden",
text=f"Forbidden path: {request.path}",
sessionPath=request.sessionPath,
requestId=request.requestId,
)
Expand All @@ -166,9 +188,14 @@ def loop():
ws.send(response_json)

else:
global session
session = SessionPathMessage.model_validate_json(message)
public_url = (
f"{CLOUD_API_ENDPOINT}/tunnel/forward/{session.sessionPath}"
)
Settings.set_public_url(public_url)
print(
f"Hooks can also be fired from {Fore.GREEN} {CLOUD_API_ENDPOINT}/tunnel/forward/{session.sessionPath}/_hooks/:hook-path{Fore.RESET}"
f"Hooks can also be fired from {Fore.GREEN} {public_url}/_hooks/:hook-path{Fore.RESET}"
)
except simple_websocket.ConnectionClosed as e:
print(f"Connection closed: {e}")
Expand All @@ -184,4 +211,4 @@ def loop():
ws.close()
ws = None

Thread(target=loop).start()
Thread(target=loop, daemon=True).start()
Empty file.
119 changes: 119 additions & 0 deletions abstra_internals/controllers/common/task_executors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Optional

import requests

from abstra_internals.email_templates import task_waiting_template
from abstra_internals.entities.execution import Execution, PreExecution
from abstra_internals.entities.execution_context import ScriptContext
from abstra_internals.environment import IS_PRODUCTION
from abstra_internals.repositories.factory import Repositories
from abstra_internals.repositories.project.project import (
AgentStage,
ClientStage,
FormStage,
ProjectRepository,
ScriptStage,
Stage,
)
from abstra_internals.repositories.tasks import TaskDTO, TaskPayload


class TaskExecutor:
def __init__(self, repos: Repositories) -> None:
self.project = ProjectRepository.load()
self.repos = repos

def send_task(
self,
type: str,
current_stage: Stage,
payload: TaskPayload,
execution: Optional[Execution] = None,
) -> None:
project = ProjectRepository.load()
next_stages = [
project.get_stage_raises(t.target_id)
for t in current_stage.workflow_transitions
if t.matches(type)
]

for stage in next_stages:
task = self.repos.tasks.send_task(
type=type,
payload=payload,
source_stage_id=current_stage.id,
target_stage_id=stage.id,
execution_id=execution.id if execution else None,
)
self._send_waiting_thread_notification(task)
if execution:
execution.context.sent_tasks.append(task.id)
if isinstance(stage, ScriptStage):
self.repos.producer.submit(
PreExecution(
context=ScriptContext(task_id=task.id),
stage_id=stage.id,
)
)
elif (
isinstance(stage, AgentStage)
and stage.project_id is not None
and stage.client_stage_id is not None
):
agent = self.repos.role_clients.get_agent(stage.project_id)
conn = next(
c
for c in self.repos.role_clients.get_connections()
if c.agent_project_id == agent.project_id
and c.client_stage_id == stage.id
)

assert conn is not None, "Connection for agent not found"

requests.post(
agent.tasks_url + "/agent",
json={
"task_data": {
"type": type,
"payload": {
**payload,
"connection_token": conn.token,
},
},
"target_stage_id": (stage.client_stage_id),
"execution_id": (
execution.id if IS_PRODUCTION and execution else None
),
},
headers={"authorization": conn.token},
).raise_for_status()

elif isinstance(stage, ClientStage):
assert isinstance(payload["connection_token"], str)
conn = self.repos.role_agents.get_connection_by_token(
payload["connection_token"]
)

requests.post(
conn.client_task_url,
json=task.model_dump(),
headers={"Authorization": conn.token},
).raise_for_status()

def _send_waiting_thread_notification(self, task: TaskDTO):
stage = self.project.get_stage(task.target_stage_id)
if not stage:
raise Exception(f"Stage {task.target_stage_id} not found")

if not (isinstance(stage, FormStage) and stage.notification_trigger.enabled):
return

recipient_emails = stage.notification_trigger.get_recipients(task.payload)
if not recipient_emails:
return

self.repos.email.send(
task_waiting_template.generate_email(
recipient_emails=recipient_emails, form=stage
)
)
4 changes: 2 additions & 2 deletions abstra_internals/controllers/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ScriptContext,
)
from abstra_internals.repositories.factory import Repositories
from abstra_internals.repositories.project.project import Stage
from abstra_internals.repositories.project.project import Stage, StageWithFile


class NotStartedException(Exception):
Expand Down Expand Up @@ -42,7 +42,7 @@ def submit(
def run(
self,
*,
stage: Stage,
stage: StageWithFile,
client: Optional[ExecutionClient] = None,
context: Optional[ClientContext] = None,
):
Expand Down
4 changes: 2 additions & 2 deletions abstra_internals/controllers/execution_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abstra_internals.entities.execution import ClientContext
from abstra_internals.environment import set_SERVER_UUID, set_WORKER_UUID
from abstra_internals.logger import AbstraLogger, Environment
from abstra_internals.repositories.project.project import Stage
from abstra_internals.repositories.project.project import StageWithFile
from abstra_internals.settings import Settings
from abstra_internals.stdio_patcher import StdioPatcher

Expand All @@ -17,7 +17,7 @@ def ExecutionProcess(
server_port: int,
worker_uuid: str,
arbiter_uuid: str,
stage: Stage,
stage: StageWithFile,
controller: MainController,
environment: Optional[Environment],
request: Optional[ClientContext] = None,
Expand Down
4 changes: 2 additions & 2 deletions abstra_internals/controllers/execution_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from abstra_internals.logger import AbstraLogger
from abstra_internals.modules import import_as_new
from abstra_internals.repositories.factory import Repositories
from abstra_internals.repositories.project.project import Stage
from abstra_internals.repositories.project.project import StageWithFile
from abstra_internals.usage import execution_usage
from abstra_internals.utils.datetime import now_str

Expand All @@ -19,7 +19,7 @@

def ExecutionTarget(
*,
stage: Stage,
stage: StageWithFile,
execution: Execution,
client: ExecutionClient,
repositories: Repositories,
Expand Down
10 changes: 5 additions & 5 deletions abstra_internals/controllers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
LogEntry,
)
from abstra_internals.repositories.factory import Repositories
from abstra_internals.repositories.jwt_signer import (
EditorJWTRepository,
JWTRepository,
)
from abstra_internals.repositories.jwt_signer import EditorJWTRepository, JWTRepository
from abstra_internals.repositories.keyvalue import KVRepository
from abstra_internals.repositories.producer import ProducerRepository
from abstra_internals.repositories.project.project import (
Expand All @@ -35,6 +32,7 @@
ProjectRepository,
ScriptStage,
Stage,
StageWithFile,
StyleSettingsWithSidebar,
)
from abstra_internals.repositories.roles import RolesRepository
Expand Down Expand Up @@ -386,7 +384,9 @@ def update_stage(self, id: str, changes: Dict[str, Any]) -> Stage:
if not stage:
raise Exception(f"Stage with id {id} not found")

if code_content := changes.pop("code_content", None):
if isinstance(stage, StageWithFile) and (
code_content := changes.pop("code_content", None)
):
Settings.root_path.joinpath(stage.file_path).write_text(
code_content, encoding="utf-8"
)
Expand Down
Loading

0 comments on commit f1ff4ad

Please sign in to comment.