generated from runpod-workers/worker-template
Commit
Signed-off-by: pandyamarut <[email protected]>
1 parent abd769a · commit 1a014c2 · 5 changed files with 1,210 additions and 53 deletions.
@@ -1,70 +1,151 @@
 import os
+import asyncio
+from typing import Optional, Dict, Any
+from dataclasses import dataclass
 import runpod
-from typing import List, AsyncGenerator, Dict, Union
-from tensorrt_llm import LLM, SamplingParams
+from transformers import AutoTokenizer
+from tensorrt_llm import LLM, BuildConfig
+from tensorrt_llm import LlmArgs
+from serve import OpenAIServer
+from dotenv import load_dotenv
 from huggingface_hub import login
-from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
+import requests

-# Enable build caching
-os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
-# Optionally, set a custom cache directory
-# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
-#HF_TOKEN for downloading models
+@dataclass
+class ServerConfig:
+    model: str
+    tokenizer: Optional[str] = None
+    max_beam_width: Optional[int] = BuildConfig.max_beam_width
+    max_batch_size: Optional[int] = BuildConfig.max_batch_size
+    max_num_tokens: Optional[int] = BuildConfig.max_num_tokens
+    max_seq_len: Optional[int] = BuildConfig.max_seq_len
+    tp_size: Optional[int] = 1
+    pp_size: Optional[int] = 1
+    kv_cache_free_gpu_memory_fraction: Optional[float] = 0.9
+    trust_remote_code: bool = False
+
+    @classmethod
+    def from_env(cls) -> 'ServerConfig':
+        model = os.getenv('TRTLLM_MODEL')
+        if not model:
+            raise ValueError("TRTLLM_MODEL environment variable must be set")
+
+        return cls(
+            model=model,
+            tokenizer=os.getenv('TRTLLM_TOKENIZER'),
+            max_beam_width=int(os.getenv('TRTLLM_MAX_BEAM_WIDTH', str(BuildConfig.max_beam_width))) if os.getenv('TRTLLM_MAX_BEAM_WIDTH') else None,
+            max_batch_size=int(os.getenv('TRTLLM_MAX_BATCH_SIZE', str(BuildConfig.max_batch_size))) if os.getenv('TRTLLM_MAX_BATCH_SIZE') else None,
+            max_num_tokens=int(os.getenv('TRTLLM_MAX_NUM_TOKENS', str(BuildConfig.max_num_tokens))) if os.getenv('TRTLLM_MAX_NUM_TOKENS') else None,
+            max_seq_len=int(os.getenv('TRTLLM_MAX_SEQ_LEN', str(BuildConfig.max_seq_len))) if os.getenv('TRTLLM_MAX_SEQ_LEN') else None,
+            tp_size=int(os.getenv('TRTLLM_TP_SIZE', '1')) if os.getenv('TRTLLM_TP_SIZE') else None,
+            pp_size=int(os.getenv('TRTLLM_PP_SIZE', '1')) if os.getenv('TRTLLM_PP_SIZE') else None,
+            kv_cache_free_gpu_memory_fraction=float(os.getenv('TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION', '0.9')) if os.getenv('TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION') else None,
+            trust_remote_code=os.getenv('TRTLLM_TRUST_REMOTE_CODE', '').lower() in ('true', '1', 'yes')
+        )

-hf_token = os.environ["HF_TOKEN"]
-login(token=hf_token)
+    def validate(self) -> None:
+        if not self.model:
+            raise ValueError("Model path or name must be provided")

+class TensorRTLLMServer:
+    """
+    Singleton class to manage TensorRT-LLM server instance and handle requests
+    """
+    # _instance = None
+    # _initialized = False
+
+    # def __new__(cls):
+    #     if cls._instance is None:
+    #         cls._instance = super(TensorRTLLMServer, cls).__new__(cls)
+    #     return cls._instance
+
-class TRTLLMWorker:
-    def __init__(self, model_path: str):
-        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+    def __init__(self):
+        self._initialize_server()
+        self.host = '0.0.0.0'
+        self.port = 8000
+
+    def _initialize_server(self):
+        """Initialize the TensorRT-LLM server and load model"""
+        # Load environment variables
+        load_dotenv()
+
-    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
-        sampling_params = SamplingParams(max_new_tokens=max_tokens)
-        outputs = self.llm.generate(prompts, sampling_params)
+        # Handle HuggingFace login
+        huggingface_token = os.getenv("HF_TOKEN")
+        if huggingface_token:
+            print("Logging in to Hugging Face...")
+            login(huggingface_token)
+
+        # Initialize configuration
+        self.config = ServerConfig.from_env()
+        self.config.validate()
+
+        # Create build configuration
+        build_config = BuildConfig(
+            max_batch_size=self.config.max_batch_size,
+            max_num_tokens=self.config.max_num_tokens,
+            max_beam_width=self.config.max_beam_width,
+            max_seq_len=self.config.max_seq_len
+        )
+
+        # Initialize LLM
+        self.llm = LLM(
+            model=self.config.model,
+            tokenizer=self.config.tokenizer,
+            tensor_parallel_size=self.config.tp_size,
+            pipeline_parallel_size=self.config.pp_size,
+            trust_remote_code=self.config.trust_remote_code,
+            build_config=build_config
+        )
+
+        # Initialize tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config.tokenizer or self.config.model,
+            trust_remote_code=self.config.trust_remote_code
+        )
+
+        # Initialize OpenAI compatible server
+        self.server = OpenAIServer(
+            llm=self.llm,
+            model=self.config.model,
+            hf_tokenizer=self.tokenizer
+        )
+
-        results = []
-        for output in outputs:
-            results.append(output.outputs[0].text)
-        return results
+        asyncio.run(self.server(self.host, self.port))
+
+# Initialize the server at module load time
+server = TensorRTLLMServer()
+
+async def async_handler(job):
+    """Handle the requests asynchronously."""
+    job_input = job["input"]
+    print(f"JOB_INPUT: {job_input}")
+
-    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
-        sampling_params = SamplingParams(max_new_tokens=max_tokens)
-
-        async for output in self.llm.generate_async(prompts, sampling_params):
-            for request_output in output.outputs:
-                if request_output.text:
-                    yield request_output.text
-
-# Initialize the worker outside the handler
-# This ensures the model is loaded only once when the serverless function starts
-# this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
-model_path = os.environ["MODEL_PATH"]
-worker = TRTLLMWorker(model_path)
-
-
-async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
-    """Handler function that will be used to process jobs."""
-    job_input = job['input']
-    prompts = job_input.get('prompts', ["Hello, how are you?"])
-    max_tokens = job_input.get('max_tokens', 100)
-    streaming = job_input.get('streaming', False)
+    base_url = "http://0.0.0.0:8000"
+
-    try:
-        if streaming:
-            results = []
-            async for chunk in worker.generate_async(prompts, max_tokens):
-                results.append(chunk)
-                yield {"status": "streaming", "chunk": chunk}
-            yield {"status": "success", "output": results}
+    if job_input.get("openai_route"):
+        openai_route, openai_input = job_input.get("openai_route"), job_input.get("openai_input")
+
+        openai_url = f"{base_url}" + openai_route
+        headers = {"Content-Type": "application/json"}
+
+        response = requests.post(openai_url, headers=headers, json=openai_input)
+        # Process the streamed response
+        if openai_input.get("stream", False):
+            for formated_chunk in response:
+                yield formated_chunk
+        else:
+            for chunk in response.iter_lines():
+                if chunk:
+                    decoded_chunk = chunk.decode('utf-8')
+                    yield decoded_chunk
+    else:
+        generate_url = f"{base_url}/generate"
+        headers = {"Content-Type": "application/json"}
+        # Directly pass `job_input` to `json`. Can we tell users the possible fields of `job_input`?
+        response = requests.post(generate_url, json=job_input, headers=headers)
+        if response.status_code == 200:
+            yield response.json()
         else:
-            results = worker.generate(prompts, max_tokens)
-            yield {"status": "success", "output": results}
-    except Exception as e:
-        yield {"status": "error", "message": str(e)}
+            yield {"error": f"Generate request failed with status code {response.status_code}", "details": response.text}

-runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
+runpod.serverless.start({"handler": async_handler, "return_aggregate_stream": True})
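For context on how the new async_handler is driven, here is a minimal sketch of job inputs that would exercise its two branches. Only the top-level fields (openai_route, openai_input, and the pass-through of the whole input to /generate) come from the diff above; the route path and the request-body fields shown are assumptions about what serve.OpenAIServer and the /generate endpoint accept.

# Illustrative job inputs for async_handler (sketch; payload fields beyond
# openai_route / openai_input are assumptions, not taken from the diff).

# OpenAI-compatible branch: proxied to http://0.0.0.0:8000 + openai_route.
openai_job = {
    "input": {
        "openai_route": "/v1/chat/completions",  # assumed route exposed by serve.OpenAIServer
        "openai_input": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "stream": False,  # read by the handler to choose chunk- vs line-iteration
        },
    }
}

# Default branch: the entire job input is POSTed unchanged to /generate.
generate_job = {
    "input": {
        "prompt": "Hello, how are you?",  # hypothetical field; the /generate schema is defined by the server
    }
}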
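Deployment-wise, ServerConfig.from_env reads its settings from TRTLLM_* environment variables (plus HF_TOKEN for the Hugging Face login). A minimal sketch of an environment for a single-GPU run follows; the variable names come from the diff, the values are placeholders.

# Example environment for ServerConfig.from_env(); values are illustrative only.
import os

os.environ["TRTLLM_MODEL"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # required
os.environ["TRTLLM_MAX_BATCH_SIZE"] = "8"          # optional; falls back to BuildConfig.max_batch_size
os.environ["TRTLLM_MAX_SEQ_LEN"] = "4096"          # optional; falls back to BuildConfig.max_seq_len
os.environ["TRTLLM_TP_SIZE"] = "1"                 # tensor-parallel degree
os.environ["TRTLLM_PP_SIZE"] = "1"                 # pipeline-parallel degree
os.environ["TRTLLM_TRUST_REMOTE_CODE"] = "true"    # parsed as true/1/yes
# os.environ["HF_TOKEN"] = "<token>"               # needed for gated models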