generated from runpod-workers/worker-template
Commit: Merge pull request #2 from runpod-workers/up-init (Up-init)
Showing 2 changed files with 85 additions and 30 deletions.
Dockerfile

```diff
@@ -1,29 +1,32 @@
-# Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile
-# DockerHub -> https://hub.docker.com/r/runpod/base/tags
-FROM runpod/base:0.4.0-cuda11.8.0
+# Start with NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04
 
-# The base image comes with many system dependencies pre-installed to help you get started quickly.
-# Please refer to the base image's Dockerfile for more information before adding additional dependencies.
-# IMPORTANT: The base image overrides the default huggingface cache location.
+# Avoid prompts from apt
+ENV DEBIAN_FRONTEND=noninteractive
 
-# --- Optional: System dependencies ---
-# COPY builder/setup.sh /setup.sh
-# RUN /bin/bash /setup.sh && \
-#     rm /setup.sh
+# Install system dependencies and Python
+RUN apt-get update -y && \
+    apt-get install -y python3-pip python3-dev git libopenmpi-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
 
-# Python dependencies
-COPY builder/requirements.txt /requirements.txt
-RUN python3.11 -m pip install --upgrade pip && \
-    python3.11 -m pip install --upgrade -r /requirements.txt --no-cache-dir && \
-    rm /requirements.txt
+# Clone TensorRT-LLM repository
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git /app/TensorRT-LLM
+
+# Set working directory
+WORKDIR /app/TensorRT-LLM/examples/llm-api
+
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
 
-# NOTE: The base image comes with multiple Python versions pre-installed.
-# It is reccommended to specify the version of Python when running your code.
+# Install additional dependencies for the serverless worker
+RUN pip3 install --upgrade runpod transformers
 
-# Add src files (Worker Template)
-ADD src .
+# Set the working directory to /app
+WORKDIR /app
 
-CMD python3.11 -u /handler.py
+# Copy the src directory containing handler.py
+COPY src /app/src
+
+# Command to run the serverless worker
+CMD ["python3", "/app/src/handler.py"]
```
src/handler.py

```diff
@@ -1,18 +1,70 @@
 """ Example handler file. """
 
+import os
 import runpod
+from typing import List, AsyncGenerator, Dict, Union
+from tensorrt_llm import LLM, SamplingParams
+from huggingface_hub import login
+from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
 
-# If your handler runs inference on a model, load the model here.
-# You will want models to be loaded into memory before starting serverless.
+# Enable build caching
+os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+# Optionally, set a custom cache directory
+# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+
+# HF_TOKEN for downloading models
+hf_token = os.environ["HF_TOKEN"]
+login(token=hf_token)
 
 
-def handler(job):
-    """ Handler function that will be used to process jobs. """
-    job_input = job['input']
-
-    name = job_input.get('name', 'World')
-
-    return f"Hello, {name}!"
+class TRTLLMWorker:
+    def __init__(self, model_path: str):
+        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+
+    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+        outputs = self.llm.generate(prompts, sampling_params)
+
+        results = []
+        for output in outputs:
+            results.append(output.outputs[0].text)
+        return results
+
+    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+
+        async for output in self.llm.generate_async(prompts, sampling_params):
+            for request_output in output.outputs:
+                if request_output.text:
+                    yield request_output.text
+
+# Initialize the worker outside the handler.
+# This ensures the model is loaded only once when the serverless function starts.
+# This path is a HF model id "<org_name>/model_name", e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
+model_path = os.environ["MODEL_PATH"]
+worker = TRTLLMWorker(model_path)
+
+
+async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
+    """Handler function that will be used to process jobs."""
+    job_input = job['input']
+    prompts = job_input.get('prompts', ["Hello, how are you?"])
+    max_tokens = job_input.get('max_tokens', 100)
+    streaming = job_input.get('streaming', False)
+
+    try:
+        if streaming:
+            results = []
+            async for chunk in worker.generate_async(prompts, max_tokens):
+                results.append(chunk)
+                yield {"status": "streaming", "chunk": chunk}
+            yield {"status": "success", "output": results}
+        else:
+            results = worker.generate(prompts, max_tokens)
+            yield {"status": "success", "output": results}
+    except Exception as e:
+        yield {"status": "error", "message": str(e)}
 
-runpod.serverless.start({"handler": handler})
+runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
```