generated from runpod-workers/worker-template
Signed-off-by: pandyamarut <[email protected]>
1 parent 7837407 · commit f1d493f
Showing 2 changed files with 61 additions and 34 deletions.
Dockerfile

```diff
@@ -1,29 +1,32 @@
-# Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile
-# DockerHub -> https://hub.docker.com/r/runpod/base/tags
-FROM runpod/base:0.4.0-cuda11.8.0
+# Start with NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04

-# The base image comes with many system dependencies pre-installed to help you get started quickly.
-# Please refer to the base image's Dockerfile for more information before adding additional dependencies.
-# IMPORTANT: The base image overrides the default huggingface cache location.
+# Avoid prompts from apt
+ENV DEBIAN_FRONTEND=noninteractive

+# Install system dependencies and Python
+RUN apt-get update -y && \
+    apt-get install -y python3-pip python3-dev git libopenmpi-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*

-# --- Optional: System dependencies ---
-# COPY builder/setup.sh /setup.sh
-# RUN /bin/bash /setup.sh && \
-#     rm /setup.sh
+# Clone TensorRT-LLM repository
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git /app/TensorRT-LLM

+# Set working directory
+WORKDIR /app/TensorRT-LLM/examples/llm-api

-# Python dependencies
-COPY builder/requirements.txt /requirements.txt
-RUN python3.11 -m pip install --upgrade pip && \
-    python3.11 -m pip install --upgrade -r /requirements.txt --no-cache-dir && \
-    rm /requirements.txt
+# Install Python dependencies
+RUN pip3 install -r requirements.txt

-# NOTE: The base image comes with multiple Python versions pre-installed.
-#       It is reccommended to specify the version of Python when running your code.
+# Install additional dependencies for the serverless worker
+RUN pip3 install runpod

+# Set the working directory to /app
+WORKDIR /app

-# Add src files (Worker Template)
-ADD src .
+# Copy the src directory containing handler.py
+COPY src /app/src

-CMD python3.11 -u /handler.py
+# Command to run the serverless worker
+CMD ["python3", "/app/src/handler.py"]
```
src/handler.py

```diff
@@ -1,18 +1,42 @@
-""" Example handler file. """
-
+import os
 import runpod

-# If your handler runs inference on a model, load the model here.
-# You will want models to be loaded into memory before starting serverless.
+from typing import List
+from tensorrt_llm import LLM, SamplingParams

+# Enable build caching
+os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+# Optionally, set a custom cache directory
+# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+
+class TRTLLMWorker:
+    def __init__(self, model_path: str):
+        self.llm = LLM(model=model_path, enable_build_cache=True)
+
+    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+        outputs = self.llm.generate(prompts, sampling_params)
+
+        results = []
+        for output in outputs:
+            results.append(output.outputs[0].text)
+
+        return results
+
+# Initialize the worker outside the handler
+# This ensures the model is loaded only once when the serverless function starts
+worker = TRTLLMWorker("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
 def handler(job):
-    """ Handler function that will be used to process jobs. """
+    """Handler function that will be used to process jobs."""
     job_input = job['input']

-    name = job_input.get('name', 'World')
-
-    return f"Hello, {name}!"
-
-
-runpod.serverless.start({"handler": handler})
+    prompts = job_input.get('prompts', ["Hello, how are you?"])
+    max_tokens = job_input.get('max_tokens', 100)
+
+    try:
+        results = worker.generate(prompts, max_tokens)
+        return {"status": "success", "output": results}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+if __name__ == "__main__":
+    runpod.serverless.start({"handler": handler})
```
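For a quick local test of the new handler, one option is to call `handler()` directly with a job-shaped dict, as sketched below. The prompts and token budget are arbitrary, and importing `handler.py` first downloads TinyLlama and builds its TensorRT engine, so this needs a GPU and some warm-up time:

```python
# Hypothetical local smoke test for src/handler.py (not part of this commit).
# Importing the module triggers worker = TRTLLMWorker(...) at import time;
# the __main__ guard keeps runpod.serverless.start() from running here.
from handler import handler  # assumes handler.py is on the import path

sample_job = {
    "input": {
        "prompts": ["Hello, how are you?", "Give one use case for TensorRT-LLM."],
        "max_tokens": 64,
    }
}

result = handler(sample_job)
# Expected on success: {"status": "success", "output": ["<text>", "<text>"]}
# Expected on failure: {"status": "error", "message": "<exception text>"}
print(result)
```

The same `{"input": {...}}` payload shape is what the worker receives per job once the image is deployed as a RunPod serverless endpoint.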