diff --git a/requirements/pt2.txt b/requirements/pt2.txt index 824473ab..0a77ad05 100644 --- a/requirements/pt2.txt +++ b/requirements/pt2.txt @@ -40,3 +40,4 @@ wheel>=0.41.0 xformers>=0.0.20 gradio streamlit-keyup==0.2.0 +imageio[ffmpeg]==2.26.1 diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index e79fc193..47fa2fc6 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -7,6 +7,7 @@ import cv2 import imageio import numpy as np +import shutil import streamlit as st import torch import torch.nn as nn @@ -40,6 +41,12 @@ from torchvision import transforms from torchvision.utils import make_grid, save_image +# Additional options for lower end setups +USE_CUDA = True # Set this to `False`` if you want to force CPU-only mode +lowvram_mode = False # Set to `True` to enable low VRAM mode +# (low VRAM mode = float32 => float16, tested to work great on RTX 3060 w/ 12GB VRAM) + +device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): @@ -59,35 +66,29 @@ def init_st(version_dict, load_ckpt=True, load_filter=True): state["filter"] = DeepFloydDataFiltering(verbose=False) return state - def load_model(model): - model.cuda() - - -lowvram_mode = False - + global device # Use the global device variable + model.to(device) def set_lowvram_mode(mode): global lowvram_mode lowvram_mode = mode - def initial_model_load(model): + global device # Use the global device variable global lowvram_mode if lowvram_mode: - model.model.half() + model.model.half().to(device) else: - model.cuda() + model.to(device) return model - def unload_model(model): global lowvram_mode - if lowvram_mode: - model.cpu() + if lowvram_mode or not USE_CUDA: + model.cpu() # Move model to CPU to free GPU memory torch.cuda.empty_cache() - def load_model_from_config(config, ckpt=None, verbose=True): model = instantiate_from_config(config.model) @@ -497,13 +498,14 @@ def load_img( st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}") return img - def get_init_img(batch_size=1, key=None): - init_image = load_img(key=key).cuda() + device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") + + init_image = load_img(key=key).to(device) # Use `to(device)` to move to the correct device init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) + return init_image - def do_sample( model, sampler, @@ -529,9 +531,9 @@ def do_sample( st.text("Sampling") outputs = st.empty() - precision_scope = autocast + precision_scope = autocast if USE_CUDA else lambda device: device with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope("cuda" if USE_CUDA else "cpu"): with model.ema_scope(): if T is not None: num_samples = [num_samples, T] @@ -754,7 +756,7 @@ def do_img2img( outputs = st.empty() precision_scope = autocast with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope("cuda" if USE_CUDA else "cpu"): with model.ema_scope(): load_model(model.conditioner) batch, batch_uc = get_batch( @@ -783,20 +785,25 @@ def do_img2img( noise = torch.randn_like(z) - sigmas = sampler.discretization(sampler.num_steps).cuda() + # Move sigmas to the correct device (CUDA or CPU) + sigmas = sampler.discretization(sampler.num_steps).to(device) sigma = sigmas[0] st.info(f"all sigmas: {sigmas}") st.info(f"noising sigma: {sigma}") + + # Offset noise level handling if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim + torch.randn(z.shape[0], device=device), z.ndim ) + + # Add noise handling if add_noise: - noised_z = z + noise * append_dims(sigma, z.ndim).cuda() + noised_z = z + noise * append_dims(sigma, z.ndim).to(device) noised_z = noised_z / torch.sqrt( 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + ) # Hardcoded to DDPM-like scaling; generalize if needed else: noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) @@ -893,29 +900,39 @@ def load_img_for_prediction( st.image(pil_image) return image.to(device) * 2.0 - 1.0 - def save_video_as_grid_and_mp4( video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5 ): + # Check if FFmpeg is available + try: + import imageio_ffmpeg + except ImportError: + raise RuntimeError("FFmpeg support is not installed. Use 'pip install imageio[ffmpeg]' to install it.") + + if not shutil.which("ffmpeg"): + raise RuntimeError("System-level FFmpeg not found. Please install it and ensure it's in your PATH.") + os.makedirs(save_path, exist_ok=True) base_count = len(glob(os.path.join(save_path, "*.mp4"))) video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T) video_batch = embed_watermark(video_batch) + for vid in video_batch: save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4) video_path = os.path.join(save_path, f"{base_count:06d}.mp4") - vid = ( - (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8) - ) - imageio.mimwrite(video_path, vid, fps=fps) - - video_path_h264 = video_path[:-4] + "_h264.mp4" - os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'") - with open(video_path_h264, "rb") as f: - video_bytes = f.read() - os.remove(video_path_h264) - st.video(video_bytes) - + vid = (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8) + + # Use the correct writer for MP4 format + writer = imageio.get_writer(video_path, fps=fps, format='ffmpeg', codec='libx264') + for frame in vid: + writer.append_data(frame) + writer.close() + + # Confirm that the file was created + if os.path.exists(video_path): + print(f"Video saved successfully at: {video_path}") + base_count += 1 + \ No newline at end of file diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 1f4fcfc4..786baa11 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -180,6 +180,12 @@ if mode == "img2vid": img = load_img_for_prediction(W, H) + + # Check if the image is None and use a dummy image if necessary + if img is None: + st.warning("No image provided. Using a dummy tensor for initialization.") + img = torch.zeros([1, 3, H, W]).to(device) # Dummy tensor + if "sv3d" in version: cond_aug = 1e-5 else: