from dataclasses import dataclass from pathlib import Path import logging import base64 import random import gc import os import numpy as np import torch from typing import Dict, Any, Optional, List, Union, Tuple import json from omegaconf import OmegaConf from PIL import Image import io from pipeline import CausalInferencePipeline from demo_utils.constant import ZERO_VAE_CACHE from demo_utils.vae_block3 import VAEDecoderWrapper from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Get token from environment hf_token = os.getenv("HF_API_TOKEN") # Constraints MAX_LARGE_SIDE = 1280 MAX_SMALL_SIDE = 768 MAX_FRAMES = 169 # Based on Wan model capabilities @dataclass class GenerationConfig: """Configuration for video generation using Wan model""" # general content settings prompt: str = "" negative_prompt: str = "worst quality, lowres, blurry, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles" # video model settings width: int = 960 # Wan model default width height: int = 576 # Wan model default height # number of frames (based on Wan model block structure) num_frames: int = 105 # 7 blocks * 15 frames per block # guidance and sampling settings guidance_scale: float = 7.5 num_inference_steps: int = 4 # Distilled model uses fewer steps # reproducible generation settings seed: int = -1 # -1 means random seed # output settings fps: int = 15 # FPS of the final video quality: int = 18 # Video quality (CRF) # advanced settings mixed_precision: bool = True use_taehv: bool = False # Whether to use TAEHV decoder use_trt: bool = False # Whether to use TensorRT optimized decoder def validate_and_adjust(self) -> 'GenerationConfig': """Validate and adjust parameters to meet constraints""" # Ensure dimensions are multiples of 32 and within limits self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32)) self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32)) # Ensure frame count is reasonable self.num_frames = min(self.num_frames, MAX_FRAMES) # Set random seed if not specified if self.seed == -1: self.seed = random.randint(0, 2**32 - 1) return self def load_image_to_tensor_with_resize_and_crop( image_input: Union[str, bytes], target_height: int = 576, target_width: int = 960, quality: int = 100 ) -> torch.Tensor: """Load and process an image into a tensor for Wan model. Args: image_input: Either a file path (str) or image data (bytes) target_height: Desired height of output tensor target_width: Desired width of output tensor quality: JPEG quality to use when re-encoding """ # Handle base64 data URI if isinstance(image_input, str) and image_input.startswith('data:'): header, encoded = image_input.split(",", 1) image_data = base64.b64decode(encoded) image = Image.open(io.BytesIO(image_data)).convert("RGB") # Handle raw bytes elif isinstance(image_input, bytes): image = Image.open(io.BytesIO(image_input)).convert("RGB") # Handle file path elif isinstance(image_input, str): image = Image.open(image_input).convert("RGB") else: raise ValueError("image_input must be either a file path, bytes, or base64 data URI") # Apply JPEG compression if quality < 100 if quality < 100: buffer = io.BytesIO() image.save(buffer, format="JPEG", quality=quality) buffer.seek(0) image = Image.open(buffer).convert("RGB") # Resize and crop to target dimensions input_width, input_height = image.size aspect_ratio_target = target_width / target_height aspect_ratio_frame = input_width / input_height if aspect_ratio_frame > aspect_ratio_target: new_width = int(input_height * aspect_ratio_target) new_height = input_height x_start = (input_width - new_width) // 2 y_start = 0 else: new_width = input_width new_height = int(input_width / aspect_ratio_target) x_start = 0 y_start = (input_height - new_height) // 2 image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) image = image.resize((target_width, target_height)) # Convert to tensor format expected by Wan model frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float() frame_tensor = (frame_tensor / 127.5) - 1.0 return frame_tensor.unsqueeze(0) def initialize_vae_decoder(use_taehv=False, use_trt=False, device="cuda"): """Initialize VAE decoder based on configuration""" if use_trt: from demo_utils.vae import VAETRTWrapper print("Initializing TensorRT VAE Decoder...") vae_decoder = VAETRTWrapper() elif use_taehv: print("Initializing TAEHV VAE Decoder...") from demo_utils.taehv import TAEHV taehv_checkpoint_path = "/repository/taehv/taew2_1.pth" if not os.path.exists(taehv_checkpoint_path): print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...") os.makedirs("checkpoints", exist_ok=True) import urllib.request download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth" try: urllib.request.urlretrieve(download_url, taehv_checkpoint_path) except Exception as e: raise RuntimeError(f"Failed to download taew2_1.pth: {e}") class DotDict(dict): __getattr__ = dict.get class TAEHVDiffusersWrapper(torch.nn.Module): def __init__(self): super().__init__() self.dtype = torch.float16 self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype) self.config = DotDict(scaling_factor=1.0) def decode(self, latents, return_dict=None): return self.taehv.decode_video(latents, parallel=True).mul_(2).sub_(1) vae_decoder = TAEHVDiffusersWrapper() else: print("Initializing Default VAE Decoder...") vae_decoder = VAEDecoderWrapper() try: # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1" #vae_state_dict = torch.load('/repository/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu") vae_state_dict = torch.load('/repository/wan2.1/Wan2.1_VAE.pth', map_location="cpu") decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k} vae_decoder.load_state_dict(decoder_state_dict) except FileNotFoundError: print("Warning: Default VAE weights not found.") vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device) return vae_decoder def create_wan_pipeline( config: GenerationConfig, device: str = "cuda" ) -> CausalInferencePipeline: """Create and configure the Wan video pipeline""" # Load configuration try: wan_config = OmegaConf.load("/repository/configs/self_forcing_dmd.yaml") default_config = OmegaConf.load("/repository/configs/default_config.yaml") wan_config = OmegaConf.merge(default_config, wan_config) except FileNotFoundError as e: logger.error(f"Error loading config file: {e}") raise RuntimeError(f"Config files not found: {e}") # Initialize model components text_encoder = WanTextEncoder() transformer = WanDiffusionWrapper(is_causal=True) # Load checkpoint checkpoint_path = "/repository/self-forcing/checkpoints/self_forcing_dmd.pt" try: state_dict = torch.load(checkpoint_path, map_location="cpu") transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator'))) except FileNotFoundError as e: logger.error(f"Error loading checkpoint: {e}") raise RuntimeError(f"Checkpoint not found: {checkpoint_path}") # Move to device and set precision text_encoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device) transformer.eval().to(dtype=torch.float16).requires_grad_(False).to(device) # Initialize VAE decoder vae_decoder = initialize_vae_decoder( use_taehv=config.use_taehv, use_trt=config.use_trt, device=device ) # Create pipeline pipeline = CausalInferencePipeline( wan_config, device=device, generator=transformer, text_encoder=text_encoder, vae=vae_decoder ) pipeline.to(dtype=torch.float16).to(device) return pipeline def frames_to_video_bytes(frames: List[np.ndarray], fps: int = 15, quality: int = 18) -> bytes: """Convert frames to MP4 video bytes""" import tempfile import subprocess with tempfile.TemporaryDirectory() as temp_dir: # Save frames as images frame_paths = [] for i, frame in enumerate(frames): frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") Image.fromarray(frame).save(frame_path) frame_paths.append(frame_path) # Create video using ffmpeg output_path = os.path.join(temp_dir, "output.mp4") cmd = [ "ffmpeg", "-y", "-framerate", str(fps), "-i", os.path.join(temp_dir, "frame_%06d.png"), "-c:v", "libx264", "-crf", str(quality), "-pix_fmt", "yuv420p", "-movflags", "faststart", output_path ] try: subprocess.run(cmd, check=True, capture_output=True) with open(output_path, "rb") as f: return f.read() except subprocess.CalledProcessError as e: logger.error(f"FFmpeg error: {e}") raise RuntimeError(f"Video encoding failed: {e}") class EndpointHandler: """Handler for the Wan Video endpoint""" def __init__(self, model_path: str = "./"): """Initialize the endpoint handler Args: model_path: Path to model weights """ # Enable TF32 for potential speedup on Ampere GPUs torch.backends.cuda.matmul.allow_tf32 = True # The pipeline will be loaded during inference to save memory self.pipeline = None self.device = "cuda" if torch.cuda.is_available() else "cpu" # Perform warm-up inference if GPU is available if self.device == "cuda": logger.info("Performing warm-up inference...") self._warmup() logger.info("Warm-up completed!") else: logger.info("CPU device detected, skipping warm-up") def _warmup(self): """Perform a warm-up inference to prepare the model for future requests""" try: # Create a simple test configuration test_config = GenerationConfig( prompt="a cat walking", negative_prompt="worst quality, lowres", width=480, # Smaller resolution for faster warm-up height=320, num_frames=33, # Fewer frames for faster warm-up guidance_scale=7.5, num_inference_steps=2, # Fewer steps for faster warm-up seed=42, # Fixed seed for consistent warm-up fps=15, mixed_precision=True, ).validate_and_adjust() # Create the pipeline if it doesn't exist if self.pipeline is None: self.pipeline = create_wan_pipeline(test_config, self.device) # Run a quick inference with torch.no_grad(): # Set seeds for reproducibility random.seed(test_config.seed) np.random.seed(test_config.seed) torch.manual_seed(test_config.seed) # Generate video frames (simplified version) conditional_dict = self.pipeline.text_encoder(text_prompts=[test_config.prompt]) for key, value in conditional_dict.items(): conditional_dict[key] = value.to(dtype=torch.float16) rnd = torch.Generator(self.device).manual_seed(int(test_config.seed)) self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device) self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device) # Generate a small noise tensor for testing noise = torch.randn([1, 3, 8, 20, 32], device=self.device, dtype=torch.float16, generator=rnd) # Clean up del noise, conditional_dict torch.cuda.empty_cache() gc.collect() logger.info("Warm-up successful!") except Exception as e: # Log the error but don't fail initialization import traceback error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}" logger.warning(error_message) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """Process inference requests Args: data: Request data containing inputs and parameters Returns: Dictionary with generated video and metadata """ # Extract inputs and parameters inputs = data.get("inputs", {}) # Support both formats: # 1. {"inputs": {"prompt": "...", "image": "..."}} # 2. {"inputs": "..."} (prompt only) if isinstance(inputs, str): input_prompt = inputs input_image = None else: input_prompt = inputs.get("prompt", "") input_image = inputs.get("image") params = data.get("parameters", {}) if not input_prompt: raise ValueError("Prompt must be provided") # Create and validate configuration config = GenerationConfig( # general content settings prompt=input_prompt, negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt), # video model settings width=params.get("width", GenerationConfig.width), height=params.get("height", GenerationConfig.height), num_frames=params.get("num_frames", GenerationConfig.num_frames), guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale), num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps), # reproducible generation settings seed=params.get("seed", GenerationConfig.seed), # output settings fps=params.get("fps", GenerationConfig.fps), quality=params.get("quality", GenerationConfig.quality), # advanced settings mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision), use_taehv=params.get("use_taehv", GenerationConfig.use_taehv), use_trt=params.get("use_trt", GenerationConfig.use_trt), ).validate_and_adjust() try: with torch.no_grad(): # Set random seeds for reproducibility random.seed(config.seed) np.random.seed(config.seed) torch.manual_seed(config.seed) # Create pipeline if not already created if self.pipeline is None: self.pipeline = create_wan_pipeline(config, self.device) # Prepare text conditioning conditional_dict = self.pipeline.text_encoder(text_prompts=[config.prompt]) for key, value in conditional_dict.items(): conditional_dict[key] = value.to(dtype=torch.float16) # Initialize caches rnd = torch.Generator(self.device).manual_seed(int(config.seed)) self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device) self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device) # Generate noise tensor noise = torch.randn( [1, 21, 16, config.height // 16, config.width // 16], device=self.device, dtype=torch.float16, generator=rnd ) # Initialize VAE cache vae_cache = None latents_cache = None if not config.use_taehv and not config.use_trt: vae_cache = [c.to(device=self.device, dtype=torch.float16) for c in ZERO_VAE_CACHE] # Generation parameters num_blocks = 7 current_start_frame = 0 all_num_frames = [self.pipeline.num_frame_per_block] * num_blocks all_frames = [] # Generate video blocks for idx, current_num_frames in enumerate(all_num_frames): logger.info(f"Processing block {idx+1}/{num_blocks}") noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames] # Denoising steps for step_idx, current_timestep in enumerate(self.pipeline.denoising_step_list): timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep _, denoised_pred = self.pipeline.generator( noisy_image_or_video=noisy_input, conditional_dict=conditional_dict, timestep=timestep, kv_cache=self.pipeline.kv_cache1, crossattn_cache=self.pipeline.crossattn_cache, current_start=current_start_frame * self.pipeline.frame_seq_length ) if step_idx < len(self.pipeline.denoising_step_list) - 1: next_timestep = self.pipeline.denoising_step_list[step_idx + 1] noisy_input = self.pipeline.scheduler.add_noise( denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)), next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long) ).unflatten(0, denoised_pred.shape[:2]) # Update cache for next block if idx < len(all_num_frames) - 1: self.pipeline.generator( noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict, timestep=torch.zeros_like(timestep), kv_cache=self.pipeline.kv_cache1, crossattn_cache=self.pipeline.crossattn_cache, current_start=current_start_frame * self.pipeline.frame_seq_length, ) # Decode to pixels if config.use_trt: pixels, vae_cache = self.pipeline.vae.forward(denoised_pred.half(), *vae_cache) elif config.use_taehv: if latents_cache is None: latents_cache = denoised_pred else: denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1) latents_cache = denoised_pred[:, -3:] pixels = self.pipeline.vae.decode(denoised_pred) else: pixels, vae_cache = self.pipeline.vae(denoised_pred.half(), *vae_cache) # Handle frame skipping if idx == 0 and not config.use_trt: pixels = pixels[:, 3:] elif config.use_taehv and idx > 0: pixels = pixels[:, 12:] # Convert frames to numpy for frame_idx in range(pixels.shape[1]): frame_tensor = pixels[0, frame_idx] frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5 frame_np = frame_np.to(torch.uint8).cpu().numpy() frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC all_frames.append(frame_np) current_start_frame += current_num_frames # Convert frames to video video_bytes = frames_to_video_bytes(all_frames, fps=config.fps, quality=config.quality) # Convert to base64 data URI video_b64 = base64.b64encode(video_bytes).decode('utf-8') video_uri = f"data:video/mp4;base64,{video_b64}" # Prepare metadata metadata = { "width": config.width, "height": config.height, "num_frames": len(all_frames), "fps": config.fps, "duration": len(all_frames) / config.fps, "seed": config.seed, "prompt": config.prompt, } # Clean up to prevent CUDA OOM errors del noise, conditional_dict, pixels if self.device == "cuda": torch.cuda.empty_cache() gc.collect() return { "video": video_uri, "content-type": "video/mp4", "metadata": metadata } except Exception as e: # Log the error and reraise import traceback error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}" logger.error(error_message) raise RuntimeError(error_message)