From a0094984ed59d73b90f85300da2ab75a5723482b Mon Sep 17 00:00:00 2001
From: Vikramjeet
Date: Fri, 17 Jan 2025 14:08:53 +0530
Subject: [PATCH] Add combined serve API for Mochi, LTX, Hunyuan, and Allegro

---
 api/combined_serve.py        | 301 +++++++++++++++++++++++++++++++++++
 api/serve.py                 | 283 --------------------------------
 configs/combined_settings.py |  44 ++---
 3 files changed, 325 insertions(+), 303 deletions(-)
 create mode 100644 api/combined_serve.py
 delete mode 100644 api/serve.py

diff --git a/api/combined_serve.py b/api/combined_serve.py
new file mode 100644
index 0000000..22a87b0
--- /dev/null
+++ b/api/combined_serve.py
@@ -0,0 +1,301 @@
+"""
+combined_serve.py
+
+A unified API for Mochi, LTX, Hunyuan, and Allegro video generation using batch-based processing.
+All requests (single or multi) are sent as a batch, and each item in the batch is processed in turn.
+
+Usage:
+  1. Place this file in your repo (e.g., in `api/combined_serve.py`).
+  2. Run: python combined_serve.py
+  3. POST requests to http://localhost:8000/api/v1/video/combined
+
+Expected request format:
+{
+  "batch": [
+     {
+       "model_name": "mochi",
+       "prompt": "A calm ocean scene at sunset",
+       "negative_prompt": "blurry, worst quality",
+       "num_inference_steps": 50,
+       "guidance_scale": 4.5,
+       "height": 480,
+       "width": 848
+       ...
+     },
+     {
+       "model_name": "hunyuan",
+       "prompt": "A beautiful mountain landscape",
+       "num_frames": 16,
+       "num_inference_steps": 50,
+       "fps": 8
+     }
+     ...
+  ]
+}
+"""
+
+import sys
+import os
+import torch
+import litserve as ls
+import json
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Any, Dict, List
+from loguru import logger
+from prometheus_client import CollectorRegistry, Histogram, make_asgi_app, multiprocess
+from api.ltx_serve import LTXVideoAPI
+from api.mochi_serve import MochiVideoAPI
+from api.hunyuan_serve import HunyuanVideoAPI
+from api.allegro_serve import AllegroVideoAPI
+from configs.combined_settings import CombinedBatchRequest, CombinedItemRequest
+import time
+from typing import Union
+
+os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc"
+if not os.path.exists("/tmp/prometheus_multiproc"):
+    os.makedirs("/tmp/prometheus_multiproc")
+
+registry = CollectorRegistry()
+multiprocess.MultiProcessCollector(registry)
+
+
+class PrometheusLogger(ls.Logger):
+    """
+    Prometheus metrics collector for combined service monitoring.
+
+    Implements per-operation performance tracking with multi-process support for
+    production deployments.
+
+    Metrics:
+        request_processing_seconds:
+            - Type: Histogram
+            - Labels: model_name, function_name
+            - Description: Processing time per operation per model
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.function_duration = Histogram(
+            "request_processing_seconds",
+            "Time spent processing request",
+            ["model_name", "function_name"],
+            registry=registry
+        )
+
+    def process(self, key: str, value: float) -> None:
+        """
+        Record a metric observation with operation-specific labeling.
+
+        Args:
+            key: Operation identifier in format "model_name:function_name"
+            value: Duration measurement in seconds
+        """
+        if ":" in key:
+            model_name, func_name = key.split(":", 1)
+        else:
+            model_name, func_name = "unknown", key
+
+        self.function_duration.labels(
+            model_name=model_name,
+            function_name=func_name
+        ).observe(value)
+
+
+class CombinedVideoAPI(ls.LitAPI):
+    """
+    Combined Video Generation API for Mochi, LTX, Hunyuan, and Allegro models.
+    This API handles requests in batch form, even for single items.
+
+    Steps:
+      1) setup(device): Initialize all sub-APIs on the specified device (CPU or GPU).
+      2) decode_request(request): Parse the request body using Pydantic `CombinedBatchRequest`.
+      3) predict(inputs): Process each item in the batch.
+      4) encode_response(outputs): Format the final JSON response.
+    """
+
+    def setup(self, device: str) -> None:
+        """
+        Called once at server startup.
+        Initializes all model APIs on the same device.
+        """
+        logger.info(f"Initializing CombinedVideoAPI on device={device}")
+        self.ltx_api = LTXVideoAPI()
+        self.mochi_api = MochiVideoAPI()
+        self.hunyuan_api = HunyuanVideoAPI()
+        self.allegro_api = AllegroVideoAPI()
+
+        self.ltx_api.setup(device=device)
+        self.mochi_api.setup(device=device)
+        self.hunyuan_api.setup(device=device)
+        self.allegro_api.setup(device=device)
+
+        self.model_apis = {
+            "ltx": self.ltx_api,
+            "mochi": self.mochi_api,
+            "hunyuan": self.hunyuan_api,
+            "allegro": self.allegro_api
+        }
+
+        logger.info("All sub-APIs (LTX, Mochi, Hunyuan, Allegro) successfully set up")
+
+    def decode_request(self, request: Any) -> Dict[str, List[Dict[str, Any]]]:
+        """
+        Interprets the raw request body as a batch, then validates it.
+        We unify single vs. multiple requests by requiring a `batch` array.
+
+        Args:
+            request: The raw request data (usually a dict from the body).
+
+        Returns:
+            A dictionary with key 'items' containing a list of validated dicts.
+
+        Raises:
+            ValidationError if the request doesn't match CombinedBatchRequest schema.
+        """
+        # If user directly posted an array, wrap it to match the expected schema
+        if isinstance(request, list):
+            request = {"batch": request}
+
+        # Validate using CombinedBatchRequest
+        validated_batch = CombinedBatchRequest(**request)
+
+        # Convert each CombinedItemRequest into a dict for usage in predict
+        items = [item.dict() for item in validated_batch.batch]
+        return {"items": items}
+
+    def predict(self, inputs: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
+        """
+        Execute inference for each item in the 'items' list.
+
+        Args:
+            inputs: Dictionary with key 'items' -> list of items.
+                    Each item is a dict with fields like 'model_name', 'prompt', etc.
+
+        Returns:
+            List of generation results or error details.
+        """
+        items = inputs["items"]
+        logger.info(f"Processing batch of {len(items)} request(s)")
+        results = []
+
+        for item in items:
+            try:
+                start_time = time.time()
+                model_name = item.get("model_name", "").lower()
+
+                if model_name not in self.model_apis:
+                    raise ValueError(f"Invalid model_name: {model_name}")
+
+                sub_api = self.model_apis[model_name]
+                sub_decoded = sub_api.decode_request(item)
+                sub_pred = sub_api.predict(sub_decoded)
+
+                if not sub_pred:
+                    raise RuntimeError("No result returned from sub-API")
+
+                end_time = time.time()
+                result = {
+                    "status": "success",
+                    "model_name": model_name,
+                    "generation_result": sub_pred[0],
+                    "generation_params": item,
+                    "time_taken": end_time - start_time
+                }
+                results.append(result)
+                logger.info(f"Generation completed for model {model_name}, prompt: {item.get('prompt', '')}")
+
+            except Exception as e:
+                logger.error(f"Error in generation for model {model_name}: {e}")
+                error_result = {
+                    "status": "error",
+                    "model_name": model_name,
+                    "error": str(e)
+                }
+                results.append(error_result)
+
+        return results if results else [{"status": "error", "error": "No results generated"}]
+
+    def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str, Any]:
+        """
+        Format generation results for API response.
+
+        Args:
+            output: Raw generation results or error information
+
+        Returns:
+            Formatted API response with standardized structure
+
+        Note:
+            Handles both success and error cases with consistent formatting
+        """
+        if isinstance(output, list):
+            output = output[0] if output else {"status": "error", "error": "No output generated"}
+
+        if output.get("status") == "success":
+            return {
+                "status": "success",
+                "model_name": output.get("model_name"),
+                "video_path": output.get("generation_result", {}).get("video_path"),
+                "generation_params": output.get("generation_params", {}),
+                "time_taken": output.get("time_taken"),
+                "metrics": {
+                    "total_time": output.get("time_taken")
+                }
+            }
+        else:
+            return {
+                "status": "error",
+                "model_name": output.get("model_name", "unknown"),
+                "error": output.get("error", "Unknown error occurred")
+            }
+
+def main():
+    """
+    Initialize and launch the combined video generation service.
+
+    Sets up the complete service infrastructure including:
+    - Prometheus metrics collection
+    - Structured logging
+    - API server configuration
+    - Error handling
+    """
+    prometheus_logger = PrometheusLogger()
+    prometheus_logger.mount(
+        path="/api/v1/metrics",
+        app=make_asgi_app(registry=registry)
+    )
+
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}",
+        level="INFO"
+    )
+    logger.add(
+        "logs/combined_api.log",
+        rotation="100 MB",
+        retention="1 week",
+        level="DEBUG"
+    )
+
+    try:
+        api = CombinedVideoAPI()
+        server = ls.LitServer(
+            api,
+            api_path='/api/v1/video/combined',
+            accelerator="auto",
+            devices="auto",
+            max_batch_size=4,
+            track_requests=True,
+            loggers=[prometheus_logger],
+            generate_client_file=False
+        )
+        logger.info("Starting combined video generation server on port 8000")
+        server.run(port=8000)
+    except Exception as e:
+        logger.error(f"Server failed to start: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/api/serve.py b/api/serve.py
deleted file mode 100644
index 1e7d5d2..0000000
--- a/api/serve.py
+++ /dev/null
@@ -1,283 +0,0 @@
-"""
-combined_serve.py
-
-A unified API for both LTX and Mochi video generation using batch-based parallel processing.
-All requests (single or multi) are sent as a batch, and we process them concurrently.
-
-Usage:
-  1. Place this file in your repo (e.g., in `api/combined_serve.py`).
-  2. Run: python combined_serve.py
-  3. POST requests to http://localhost:8000/api/v1/inference
-
-Expected request format:
-{
-  "batch": [
-     {
-       "model_name": "mochi",
-       "prompt": "A calm ocean scene at sunset",
-       "negative_prompt": "blurry, worst quality",
-       "num_inference_steps": 50,
-       "guidance_scale": 4.5,
-       "height": 480,
-       "width": 848
-       ...
-     },
-     ...
-  ]
-}
-"""
-
-import sys
-import os
-import torch
-import litserve as ls
-import json
-from concurrent.futures import ProcessPoolExecutor, as_completed
-from typing import Any, Dict, List
-from loguru import logger
-from prometheus_client import CollectorRegistry, Histogram, make_asgi_app, multiprocess
-from api.ltx_serve import LTXVideoAPI
-from api.mochi_serve import MochiVideoAPI
-from configs.combined_settings import CombinedBatchRequest, CombinedItemRequest
-os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc"
-if not os.path.exists("/tmp/prometheus_multiproc"):
-    os.makedirs("/tmp/prometheus_multiproc")
-
-registry = CollectorRegistry()
-multiprocess.MultiProcessCollector(registry)
-
-
-class PrometheusLogger(ls.Logger):
-    """
-    Custom logger for Prometheus metrics.
-    Tracks request duration for each (model_name, function_name) pair.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.function_duration = Histogram(
-            "combined_request_duration_seconds",
-            "Time spent processing video generation requests",
-            ["model_name", "function_name"],
-            registry=registry
-        )
-
-    def process(self, key: str, value: float) -> None:
-        """
-        Record metric observations with labels for both model_name and function_name.
-        `key` is expected to have the format "model_name:function_name".
-        """
-        if ":" in key:
-            model_name, func_name = key.split(":", 1)
-        else:
-            model_name, func_name = "unknown", key
-
-        self.function_duration.labels(
-            model_name=model_name,
-            function_name=func_name
-        ).observe(value)
-
-
-class CombinedVideoAPI(ls.LitAPI):
-    """
-    Combined Video Generation API for both LTX and Mochi models.
-    This API handles requests in batch form, even for single items.
- - Steps: - 1) setup(device): Initialize LTX and Mochi sub-APIs on the specified device (CPU, GPU). - 2) decode_request(request): Parse the request body using Pydantic `CombinedBatchRequest`. - 3) predict(inputs): Parallel process each item in the batch. - 4) encode_response(outputs): Format the final JSON response. - """ - - def setup(self, device: str) -> None: - """ - Called once at server startup. - Initializes both the LTX and Mochi APIs on the same device. - """ - logger.info(f"Initializing CombinedVideoAPI on device={device}") - self.ltx_api = LTXVideoAPI() - self.mochi_api = MochiVideoAPI() - - self.ltx_api.setup(device=device) - self.mochi_api.setup(device=device) - - self.model_apis = { - "ltx": self.ltx_api, - "mochi": self.mochi_api - } - - logger.info("All sub-APIs (LTX, Mochi) successfully set up") - - def decode_request(self, request: Any) -> Dict[str, List[Dict[str, Any]]]: - """ - Interprets the raw request body as a batch, then validates it. - We unify single vs. multiple requests by requiring a `batch` array. - - Args: - request: The raw request data (usually a dict from the body). - - Returns: - A dictionary with key 'items' containing a list of validated dicts. - - Raises: - ValidationError if the request doesn't match CombinedBatchRequest schema. - """ - # If user directly posted an array, wrap it to match the expected schema - if isinstance(request, list): - request = {"batch": request} - - # Validate using CombinedBatchRequest - validated_batch = CombinedBatchRequest(**request) - - # Convert each CombinedItemRequest into a dict for usage in predict - items = [item.dict() for item in validated_batch.batch] - return {"items": items} - - def predict(self, inputs: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]: - """ - Execute parallel inference for all items in the 'items' list. - - Args: - inputs: Dictionary with key 'items' -> list of items. - Each item is a dict with fields like 'model_name', 'prompt', etc. - - Returns: - Dictionary with 'batch_results': a list of output dicts, - each containing status, video_id, video_url, etc. - """ - items = inputs["items"] - logger.info(f"Processing batch of {len(items)} request(s) in parallel") - - # We'll define a helper function for one item - def process_single(item: Dict[str, Any]) -> Dict[str, Any]: - """ - Takes a single request dict, delegates to the correct sub-API (LTX or Mochi). - Returns the predicted result (video URL, etc.). - """ - model_name = item.get("model_name", "").lower() - if model_name not in self.model_apis: - return { - "status": "error", - "error": f"Invalid model_name: {model_name}" - } - - sub_api = self.model_apis[model_name] - - # Sub-API workflow: decode -> predict -> single result - # Note: sub_api.decode_request() often returns a list. We'll handle that carefully. - try: - # Prepare sub-request in their expected format - sub_decoded = sub_api.decode_request(item) - sub_pred = sub_api.predict(sub_decoded) - return sub_pred[0] if sub_pred else { - "status": "error", - "error": "No result returned from sub-API." 
- } - except Exception as e: - logger.error(f"[{model_name}] sub-api error: {e}") - return {"status": "error", "error": str(e), "model_name": model_name} - - # Use a ProcessPoolExecutor to handle CPU-heavy tasks concurrently - results = [] - with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor: - future_to_idx = {} - for idx, item in enumerate(items): - future = executor.submit(process_single, item) - future_to_idx[future] = idx - - for f in as_completed(future_to_idx): - idx = future_to_idx[f] - try: - out = f.result() - out["item_index"] = idx - if "model_name" not in out: - out["model_name"] = items[idx].get("model_name", "unknown") - results.append(out) - except Exception as e: - # If something catastrophic happened in process_single - results.append({"status": "error", "error": str(e), "item_index": idx}) - - # Sort results by item_index so response order matches input order - results.sort(key=lambda x: x["item_index"]) - return {"batch_results": results} - - def encode_response(self, outputs: Dict[str, Any]) -> Dict[str, Any]: - """ - Convert the raw dictionary from `predict` into a final response. - We unify single vs. multiple items: - - The client always receives "batch_results" - (with 1 result if originally a single item). - - Sub-APIs often have their own encode_response() method to standardize the final JSON. - We'll call that to keep consistent format. - - Returns: - The final JSON-serializable dict. - """ - if "batch_results" not in outputs: - return { - "status": "error", - "error": "No batch_results field found in predict output" - } - - for item in outputs["batch_results"]: - if item.get("status") == "error": - continue - - model_name = item.get("model_name", "").lower() - if model_name in self.model_apis: - sub_encoded = self.model_apis[model_name].encode_response(item) - item.update(sub_encoded) - - return outputs - - -def main(): - """ - Main entry point for the combined server, exposing /predict on port 8000. - This version logs metrics to Prometheus and logs to console + file. - """ - from litserve import LitServer - - # PROMETHEUS LOGGER - prometheus_logger = PrometheusLogger() - prometheus_logger.mount( - path="/metrics", - app=make_asgi_app(registry=registry) - ) - - # LOGGING - logger.remove() # Remove default handler - logger.add( - sys.stdout, - format="{time:YYYY-MM-DD HH:mm:ss} " - "| {level: <8} " - "| {name}:{function} - " - "{message}", - level="INFO" - ) - logger.add( - "logs/combined_api.log", - rotation="100 MB", - retention="1 week", - level="DEBUG" - ) - - api = CombinedVideoAPI() - server = LitServer( - api, - api_path="/api/v1/inference", - accelerator="auto", - devices="auto", - max_batch_size=4, - track_requests=True, - loggers=[prometheus_logger] - ) - - logger.info("Starting combined video generation server on port 8000") - server.run(port=8000) - - -if __name__ == "__main__": - main() diff --git a/configs/combined_settings.py b/configs/combined_settings.py index bd9d7d9..e5583c5 100644 --- a/configs/combined_settings.py +++ b/configs/combined_settings.py @@ -1,10 +1,10 @@ """ combined_settings.py -Central Pydantic models and optional unified config for combining Mochi and LTX requests/settings. +Central Pydantic models and optional unified config for combining Mochi, LTX, Hunyuan and Allegro requests/settings. 
1) CombinedItemRequest: - - Defines the schema for a single text-to-video request, including model_name (mochi/ltx) + - Defines the schema for a single text-to-video request, including model_name (mochi/ltx/hunyuan/allegro) and common fields like prompt, negative_prompt, resolution, etc. 2) CombinedBatchRequest: @@ -28,12 +28,11 @@ "width": 848 }, { - "model_name": "ltx", - "prompt": "Golden autumn leaves swirling", - "num_inference_steps": 40, - "guidance_scale": 3.0, - "height": 480, - "width": 704 + "model_name": "hunyuan", + "prompt": "A beautiful mountain landscape", + "num_frames": 16, + "num_inference_steps": 50, + "fps": 8 } ] } @@ -45,31 +44,35 @@ # If you want to embed or reference them here: from .mochi_settings import MochiSettings from .ltx_settings import LTXVideoSettings +from .hunyuan_config import HunyuanConfig +from .allegro_config import AllegroConfig from pydantic_settings import BaseSettings class CombinedItemRequest(BaseModel): """ - A single request object for either Mochi or LTX. + A single request object for Mochi, LTX, Hunyuan, or Allegro. Fields: - model_name (str): Which model to use: 'mochi' or 'ltx'. + model_name (str): Which model to use: 'mochi', 'ltx', 'hunyuan', or 'allegro'. prompt (str): Main prompt describing the video content. negative_prompt (Optional[str]): Text describing what to avoid. num_inference_steps (Optional[int]): Override for inference steps. guidance_scale (Optional[float]): Classifier-free guidance scale. height (Optional[int]): Video height in pixels. width (Optional[int]): Video width in pixels. - (Add additional fields as needed for your models.) + num_frames (Optional[int]): Number of frames (Hunyuan/Allegro). + fps (Optional[int]): Frames per second (Hunyuan/Allegro). """ - model_name: str = Field(..., description="Model to use: 'ltx' or 'mochi'.") + model_name: str = Field(..., description="Model to use: 'ltx', 'mochi', 'hunyuan', or 'allegro'.") prompt: str = Field(..., description="Prompt describing the video content.") negative_prompt: Optional[str] = Field(None, description="Things to avoid in generation.") num_inference_steps: Optional[int] = Field(40, description="Number of denoising steps.") guidance_scale: Optional[float] = Field(3.0, description="Guidance scale for generation.") height: Optional[int] = Field(480, description="Video height in pixels.") width: Optional[int] = Field(704, description="Video width in pixels.") - # Add any more fields your sub-models need, e.g. fps, frames, etc. + num_frames: Optional[int] = Field(16, description="Number of frames (Hunyuan/Allegro).") + fps: Optional[int] = Field(8, description="Frames per second (Hunyuan/Allegro).") class CombinedBatchRequest(BaseModel): @@ -80,7 +83,7 @@ class CombinedBatchRequest(BaseModel): { "batch": [ { "model_name": "mochi", "prompt": "...", ... }, - { "model_name": "ltx", "prompt": "...", ... } + { "model_name": "hunyuan", "prompt": "...", ... } ] } """ @@ -91,18 +94,19 @@ class CombinedBatchRequest(BaseModel): class CombinedConfig(BaseSettings): """ - Optional: A unified config that embeds or references your Mochi/LTX model settings. + Optional: A unified config that embeds or references your model settings. - This can be used if you want to store and manipulate both sets of settings in one place. - For example, you might define environment variables to override mochi or ltx defaults. + This can be used if you want to store and manipulate all model settings in one place. + For example, you might define environment variables to override model defaults. 
Usage: from configs.combined_settings import CombinedConfig combined_config = CombinedConfig() - # Access mochi or ltx settings: combined_config.mochi_config, combined_config.ltx_config """ - mochi_config: MochiSettings = MochiSettings() - ltx_config: LTXVideoSettings = LTXVideoSettings() + mochi: MochiSettings = MochiSettings() + ltx: LTXVideoSettings = LTXVideoSettings() + hunyuan: HunyuanConfig = HunyuanConfig() + allegro: AllegroConfig = AllegroConfig() class Config: env_prefix = "COMBINED_"
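
Usage sketch (illustrative, not part of the patch): the snippet below shows one way to exercise the new combined endpoint once api/combined_serve.py is running. It assumes a local server on port 8000 with the api_path configured in main() ('/api/v1/video/combined') and uses the third-party requests package; field values are examples mirroring the batch format documented in the module docstring.

import requests

# Single-item batch; additional items (e.g. "ltx", "hunyuan", "allegro")
# can be appended to the "batch" list using the fields from CombinedItemRequest.
payload = {
    "batch": [
        {
            "model_name": "mochi",
            "prompt": "A calm ocean scene at sunset",
            "negative_prompt": "blurry, worst quality",
            "num_inference_steps": 50,
            "guidance_scale": 4.5,
            "height": 480,
            "width": 848
        }
    ]
}

# Video generation can take minutes, so allow a generous timeout.
resp = requests.post(
    "http://localhost:8000/api/v1/video/combined",
    json=payload,
    timeout=600,
)
resp.raise_for_status()

# On success, encode_response returns a dict with status, model_name,
# video_path, generation_params, and timing metrics.
print(resp.json())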