Initial commit
Signed-off-by: pandyamarut <[email protected]>
pandyamarut committed Sep 17, 2024
1 parent 7837407 commit f1d493f
Showing 2 changed files with 61 additions and 34 deletions.
43 changes: 23 additions & 20 deletions Dockerfile
@@ -1,29 +1,32 @@
-# Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile
-# DockerHub -> https://hub.docker.com/r/runpod/base/tags
-FROM runpod/base:0.4.0-cuda11.8.0
+# Start with NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04
 
-# The base image comes with many system dependencies pre-installed to help you get started quickly.
-# Please refer to the base image's Dockerfile for more information before adding additional dependencies.
-# IMPORTANT: The base image overrides the default huggingface cache location.
+# Avoid prompts from apt
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install system dependencies and Python
+RUN apt-get update -y && \
+    apt-get install -y python3-pip python3-dev git libopenmpi-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
 
-# --- Optional: System dependencies ---
-# COPY builder/setup.sh /setup.sh
-# RUN /bin/bash /setup.sh && \
-# rm /setup.sh
+# Clone TensorRT-LLM repository
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git /app/TensorRT-LLM
+
+# Set working directory
+WORKDIR /app/TensorRT-LLM/examples/llm-api
 
-# Python dependencies
-COPY builder/requirements.txt /requirements.txt
-RUN python3.11 -m pip install --upgrade pip && \
-    python3.11 -m pip install --upgrade -r /requirements.txt --no-cache-dir && \
-    rm /requirements.txt
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
 
-# NOTE: The base image comes with multiple Python versions pre-installed.
-# It is recommended to specify the version of Python when running your code.
+# Install additional dependencies for the serverless worker
+RUN pip3 install runpod
+
+# Set the working directory to /app
+WORKDIR /app
 
-# Add src files (Worker Template)
-ADD src .
+# Copy the src directory containing handler.py
+COPY src /app/src
 
-CMD python3.11 -u /handler.py
+# Command to run the serverless worker
+CMD ["python3", "/app/src/handler.py"]
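To sanity-check the new Dockerfile locally, a build-and-run sketch along these lines should work. The image tag is illustrative, not part of the commit, and running TensorRT-LLM requires an NVIDIA GPU plus the NVIDIA Container Toolkit on the host:

    # Build the worker image from the repository root (tag "trtllm-worker" is illustrative)
    docker build -t trtllm-worker .

    # Run with GPU access; the container starts the serverless handler defined in CMD
    docker run --rm --gpus all trtllm-worker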
52 changes: 38 additions & 14 deletions src/handler.py
@@ -1,18 +1,42 @@
 """ Example handler file. """
 
+import os
 import runpod
 
-# If your handler runs inference on a model, load the model here.
-# You will want models to be loaded into memory before starting serverless.
+from typing import List
+from tensorrt_llm import LLM, SamplingParams
+
+# Enable build caching
+os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+# Optionally, set a custom cache directory
+# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+
+class TRTLLMWorker:
+    def __init__(self, model_path: str):
+        self.llm = LLM(model=model_path, enable_build_cache=True)
+
+    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+        outputs = self.llm.generate(prompts, sampling_params)
+
+        results = []
+        for output in outputs:
+            results.append(output.outputs[0].text)
+
+        return results
+
+# Initialize the worker outside the handler
+# This ensures the model is loaded only once when the serverless function starts
+worker = TRTLLMWorker("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
 def handler(job):
-    """ Handler function that will be used to process jobs. """
+    """Handler function that will be used to process jobs."""
     job_input = job['input']
 
-    name = job_input.get('name', 'World')
-
-    return f"Hello, {name}!"
-
-
-runpod.serverless.start({"handler": handler})
+    prompts = job_input.get('prompts', ["Hello, how are you?"])
+    max_tokens = job_input.get('max_tokens', 100)
+
+    try:
+        results = worker.generate(prompts, max_tokens)
+        return {"status": "success", "output": results}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+if __name__ == "__main__":
+    runpod.serverless.start({"handler": handler})
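With this handler, a job's input carries prompts (a list of strings) and max_tokens. A sketch of a local smoke test before deployment, assuming the RunPod Python SDK's documented --test_input flag for running a handler outside the platform:

    # Feed a one-off job payload to the handler locally (flag per RunPod's SDK docs)
    python3 src/handler.py --test_input '{"input": {"prompts": ["What is TensorRT-LLM?"], "max_tokens": 64}}'

    # On success the handler returns: {"status": "success", "output": ["...generated text..."]}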
