From 8a01b9e40c5d00ef2a9c8ede92a4657c0e6f097d Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com> Date: Wed, 23 Oct 2024 02:20:43 +0530 Subject: [PATCH] Added implementations for get_agents_session, delete_agents_session and delete_agents (#267) --- llama_stack/apis/agents/agents.py | 6 +-- .../impls/meta_reference/agents/agents.py | 46 ++++++++++++++++--- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index db0b1a2691..e0eaacf517 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -438,14 +438,12 @@ async def create_agent_turn( @webmethod(route="/agents/turn/get") async def get_agents_turn( - self, - agent_id: str, - turn_id: str, + self, agent_id: str, session_id: str, turn_id: str ) -> Turn: ... @webmethod(route="/agents/step/get") async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, session_id: str, turn_id: str, step_id: str ) -> AgentStepResponse: ... @webmethod(route="/agents/session/create") diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 8b3ece978f..ca5a003599 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -138,13 +138,29 @@ async def _create_agent_turn_streaming( async for event in agent.create_and_execute_turn(request): yield event - async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn: - raise NotImplementedError() + async def get_agents_turn( + self, agent_id: str, session_id: str, turn_id: str + ) -> Turn: + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + return turn async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, session_id: str, turn_id: str, step_id: str ) -> AgentStepResponse: - raise NotImplementedError() + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + steps = turn.steps + for step in steps: + if step.step_id == step_id: + return AgentStepResponse(step=step) + raise ValueError(f"Provided step_id {step_id} could not be found") async def get_agents_session( self, @@ -152,10 +168,26 @@ async def get_agents_session( session_id: str, turn_ids: Optional[List[str]] = None, ) -> Session: - raise NotImplementedError() + session = await self.persistence_store.get(f"session:{agent_id}:{session_id}") + session = Session(**json.loads(session)) + turns = [] + if turn_ids: + for turn_id in turn_ids: + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + turns.append(turn) + return Session( + session_name=session.session_name, + session_id=session_id, + turns=turns if turns else [], + started_at=session.started_at, + ) async def delete_agents_session(self, agent_id: str, session_id: str) -> None: - raise NotImplementedError() + await self.persistence_store.delete(f"session:{agent_id}:{session_id}") async def delete_agents(self, agent_id: str) -> None: - raise NotImplementedError() + await self.persistence_store.delete(f"agent:{agent_id}")