diff --git a/agent/src/og_agent/agent_api_server.py b/agent/src/og_agent/agent_api_server.py index d9f6d0d..e62e62a 100644 --- a/agent/src/og_agent/agent_api_server.py +++ b/agent/src/og_agent/agent_api_server.py @@ -178,11 +178,10 @@ class TaskRequest(BaseModel): async def run_task(task: TaskRequest, key): - async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files): + async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files, context_id=task.context_id): response = StepResponse.new_from(respond).model_dump(exclude_none=True) yield "data: %s\n" % json.dumps(response) - @app.post("/process") async def process_task( task: TaskRequest, diff --git a/sdk/src/og_sdk/agent_sdk.py b/sdk/src/og_sdk/agent_sdk.py index 14ac5ad..17595cc 100644 --- a/sdk/src/og_sdk/agent_sdk.py +++ b/sdk/src/og_sdk/agent_sdk.py @@ -182,11 +182,11 @@ async def add_kernel(self, key, endpoint, api_key): response = await self.stub.add_kernel(request, metadata=metadata) return response - async def prompt(self, prompt, api_key, files=[]): + async def prompt(self, prompt, api_key, files=[], context_id=None): metadata = aio.Metadata( ("api_key", api_key), ) - request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files) + request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files, context_id=context_id) async for respond in self.stub.process_task(request, metadata=metadata): yield respond