diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index de710a94fd..761ee49f20 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -440,14 +440,12 @@ def create_agent_turn( @webmethod(route="/agents/turn/get") async def get_agents_turn( - self, - agent_id: str, - turn_id: str, + self, agent_id: str, session_id: str, turn_id: str ) -> Turn: ... @webmethod(route="/agents/step/get") async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, session_id: str, turn_id: str, step_id: str ) -> AgentStepResponse: ... @webmethod(route="/agents/session/create") diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 5a209d0b76..c8c9c7f3b7 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -138,13 +138,29 @@ async def _create_agent_turn_streaming( async for event in agent.create_and_execute_turn(request): yield event - async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn: - raise NotImplementedError() + async def get_agents_turn( + self, agent_id: str, session_id: str, turn_id: str + ) -> Turn: + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + return turn async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, session_id: str, turn_id: str, step_id: str ) -> AgentStepResponse: - raise NotImplementedError() + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + steps = turn.steps + for step in steps: + if step.step_id == step_id: + return AgentStepResponse(step=step) + raise ValueError(f"Provided step_id {step_id} could not be found") async def get_agents_session( self, @@ -152,10 +168,26 @@ async def get_agents_session( session_id: str, turn_ids: Optional[List[str]] = None, ) -> Session: - raise NotImplementedError() + session = await self.persistence_store.get(f"session:{agent_id}:{session_id}") + session = Session(**json.loads(session)) + turns = [] + if turn_ids: + for turn_id in turn_ids: + turn = await self.persistence_store.get( + f"session:{agent_id}:{session_id}:{turn_id}" + ) + turn = json.loads(turn) + turn = Turn(**turn) + turns.append(turn) + return Session( + session_name=session.session_name, + session_id=session_id, + turns=turns if turns else [], + started_at=session.started_at, + ) async def delete_agents_session(self, agent_id: str, session_id: str) -> None: - raise NotImplementedError() + await self.persistence_store.delete(f"session:{agent_id}:{session_id}") async def delete_agents(self, agent_id: str) -> None: - raise NotImplementedError() + await self.persistence_store.delete(f"agent:{agent_id}")