diff --git a/src/handler.py b/src/handler.py
index cc78440..59137a8 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -1,8 +1,9 @@
 import os
 import runpod
-from typing import List
+from typing import List, AsyncGenerator
 from tensorrt_llm import LLM, SamplingParams
 from huggingface_hub import login
+from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
 
 # Enable build caching
 os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
@@ -19,7 +20,7 @@ class TRTLLMWorker:
 
 
     def __init__(self, model_path: str):
-        self.llm = LLM(model=model_path, enable_build_cache=True)
+        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
 
     def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
         sampling_params = SamplingParams(max_new_tokens=max_tokens)
@@ -30,25 +31,55 @@ def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
             results.append(output.outputs[0].text)
         return results
+
+    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+
+        async for output in self.llm.generate_async(prompts, sampling_params):
+            for request_output in output.outputs:
+                if request_output.text:
+                    yield request_output.text
+
 
 
 # Initialize the worker outside the handler
 # This ensures the model is loaded only once when the serverless function starts
-# this path is hf model "/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
+# MODEL_PATH is a Hugging Face model id, e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
 model_path = os.environ["MODEL_PATH"]
 worker = TRTLLMWorker(model_path)
 
 
-def handler(job):
+# def handler(job):
+#     """Handler function that will be used to process jobs."""
+#     job_input = job['input']
+#     prompts = job_input.get('prompts', ["Hello, how are you?"])
+#     max_tokens = job_input.get('max_tokens', 100)
+#     streaming = job_input.get('streaming', False)
+
+#     try:
+#         results = worker.generate(prompts, max_tokens)
+#         return {"status": "success", "output": results}
+#     except Exception as e:
+#         return {"status": "error", "message": str(e)}
+
+
+async def handler(job):
     """Handler function that will be used to process jobs."""
     job_input = job['input']
     prompts = job_input.get('prompts', ["Hello, how are you?"])
     max_tokens = job_input.get('max_tokens', 100)
+    streaming = job_input.get('streaming', False)
 
     try:
-        results = worker.generate(prompts, max_tokens)
-        return {"status": "success", "output": results}
+        if streaming:
+            results = []
+            async for chunk in worker.generate_async(prompts, max_tokens):
+                results.append(chunk)
+                yield {"status": "streaming", "chunk": chunk}
+            yield {"status": "success", "output": results}
+        else:
+            results = worker.generate(prompts, max_tokens)
+            yield {"status": "success", "output": results}
     except Exception as e:
-        return {"status": "error", "message": str(e)}
+        yield {"status": "error", "message": str(e)}
 
-
-runpod.serverless.start({"handler": handler})
\ No newline at end of file
+runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
\ No newline at end of file
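
A quick local sanity check for the new streaming path: the async handler only reads prompts, max_tokens, and the new streaming flag from job["input"], so a payload shaped like the sketch below exercises it. This is an illustrative example, not part of the diff; the values are made up, and the note on return_aggregate_stream reflects how the RunPod SDK is documented to aggregate generator output for non-streaming requests.

# Hypothetical test payload mirroring the keys the handler reads from job["input"];
# values are illustrative only.
example_job = {
    "input": {
        "prompts": ["Hello, how are you?"],
        "max_tokens": 100,
        "streaming": True,  # False falls back to the synchronous worker.generate() path
    }
}
# With streaming=True the handler yields {"status": "streaming", "chunk": ...}
# events followed by a final {"status": "success", "output": [...]} event;
# "return_aggregate_stream": True asks RunPod to also return the aggregated
# yields for non-streaming requests.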