From 85815b596b722ca53b56869f09de6ae5b63e4447 Mon Sep 17 00:00:00 2001
From: pandyamarut
Date: Tue, 17 Sep 2024 15:21:49 -0700
Subject: [PATCH] update handler

Signed-off-by: pandyamarut
---
 src/handler.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/handler.py b/src/handler.py
index 8c55fa1..cc78440 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -2,11 +2,20 @@ import runpod
 from typing import List
 from tensorrt_llm import LLM, SamplingParams
+from huggingface_hub import login
 
 # Enable build caching
 os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
 # Optionally, set a custom cache directory
 # os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
 
+# HF_TOKEN for downloading models
+
+
+
+hf_token = os.environ["HF_TOKEN"]
+login(token=hf_token)
+
+
 
 class TRTLLMWorker:
     def __init__(self, model_path: str):
@@ -19,12 +28,15 @@ def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
         results = []
         for output in outputs:
             results.append(output.outputs[0].text)
-
         return results
 
 # Initialize the worker outside the handler
 # This ensures the model is loaded only once when the serverless function starts
-worker = TRTLLMWorker("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+# MODEL_PATH is a Hugging Face model ID ("org/model_name"), e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
+model_path = os.environ["MODEL_PATH"]
+worker = TRTLLMWorker(model_path)
+
+
 
 def handler(job):
     """Handler function that will be used to process jobs."""
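
Note: the sketch below is one way to smoke-test this change locally; it is
not part of the patch. It relies on the runpod SDK's --test_input flag for
running a handler outside the serverless platform, and the payload keys
("prompts", "max_tokens") are guesses based on TRTLLMWorker.generate(),
since handler()'s body is not shown in this diff. MODEL_PATH can be any
Hugging Face model ID you have access to; HF_TOKEN must be a real token
when the model is gated.

    # Local smoke test (a sketch under the assumptions above).
    import json
    import os
    import subprocess

    env = {
        **os.environ,
        # The two environment variables this patch reads at import time.
        "MODEL_PATH": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "HF_TOKEN": "hf_xxx",  # placeholder; supply a real token
    }

    # Assumed job shape; adjust to whatever handler() actually expects.
    payload = {"input": {"prompts": ["Hello, world"], "max_tokens": 64}}

    subprocess.run(
        ["python", "src/handler.py", "--test_input", json.dumps(payload)],
        env=env,
        check=True,
    )

Because login() and TRTLLMWorker() now run at import time, a missing
HF_TOKEN or MODEL_PATH will raise KeyError before the handler serves its
first job, which surfaces misconfiguration early in worker logs.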