diff --git a/src/handler.py b/src/handler.py
index 59137a8..c574551 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -1,6 +1,6 @@
 import os
 import runpod
-from typing import List, AsyncGenerator
+from typing import List, AsyncGenerator, Dict, Union
 from tensorrt_llm import LLM, SamplingParams
 from huggingface_hub import login
 from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
@@ -47,22 +47,7 @@ async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> Asy
 
 
 worker = TRTLLMWorker(model_path)
-
-# def handler(job):
-#     """Handler function that will be used to process jobs."""
-#     job_input = job['input']
-#     prompts = job_input.get('prompts', ["Hello, how are you?"])
-#     max_tokens = job_input.get('max_tokens', 100)
-#     streaming = job_input.get('streaming', False)
-
-#     try:
-#         results = worker.generate(prompts, max_tokens)
-#         return {"status": "success", "output": results}
-#     except Exception as e:
-#         return {"status": "error", "message": str(e)}
-
-
-async def handler(job):
+async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
     """Handler function that will be used to process jobs."""
     job_input = job['input']
     prompts = job_input.get('prompts', ["Hello, how are you?"])
@@ -78,8 +63,8 @@ async def handler(job):
             yield {"status": "success", "output": results}
         else:
             results = worker.generate(prompts, max_tokens)
-            return {"status": "success", "output": results}
+            yield {"status": "success", "output": results}
     except Exception as e:
-        return {"status": "error", "message": str(e)}
+        yield {"status": "error", "message": str(e)}
 
 runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
\ No newline at end of file