import gradio as gr

import tempfile
import random
import os
import shutil
import hashlib
import uuid
import logging
from pathlib import Path
from typing import Optional, Tuple

import torch
import numpy as np

from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Storage layout (root overridable via the STORAGE_PATH environment variable)
STORAGE_PATH = Path(os.getenv('STORAGE_PATH', './data'))
LORA_PATH = STORAGE_PATH / "loras"
OUTPUT_PATH = STORAGE_PATH / "output"

MODEL_VERSION = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DEFAULT_PROMPT_PREFIX = ""

STORAGE_PATH.mkdir(parents=True, exist_ok=True)
LORA_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Module-level cache: the pipeline and the ID of the currently loaded LoRA are kept
# across requests so repeated generations do not reload the model from disk.
pipe = None
current_lora_id = None

def format_time(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string"""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}m")
    if secs > 0 or not parts:
        parts.append(f"{secs}s")

    return " ".join(parts)

def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
    """Upload a LoRA file and return a hash-based ID for future reference

    Args:
        file: Uploaded file object from Gradio

    Returns:
        Tuple[str, str]: Hash-based ID for the stored file (returned twice, once
        for each output component)
    """
    if file is None:
        return "", ""

    try:
        # Hash the file contents in chunks; the SHA-256 digest doubles as a
        # stable, deduplicating ID for the stored weights.
        sha256_hash = hashlib.sha256()
        with open(file.name, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)
        file_hash = sha256_hash.hexdigest()

        dest_path = LORA_PATH / f"{file_hash}.safetensors"

        # If the same file was uploaded before, reuse the stored copy.
        if dest_path.exists():
            logger.info(f"LoRA file {file_hash} already exists, skipping copy")
            return file_hash, file_hash

        shutil.copy(file.name, dest_path)
        logger.info(f"Uploaded new LoRA file with ID {file_hash}")
        return file_hash, file_hash
    except Exception as e:
        logger.error(f"Error uploading LoRA file: {e}")
        raise gr.Error(f"Failed to upload LoRA file: {str(e)}")

def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
    """Get the path to a LoRA file from its hash-based ID

    Args:
        lora_id: Hash-based ID of the stored LoRA file

    Returns:
        Path: Path to the LoRA file if found, None otherwise
    """
    if not lora_id:
        return None

    lora_path = LORA_PATH / f"{lora_id}.safetensors"
    if lora_path.exists():
        return lora_path

    return None
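
# Example: get_lora_file_path("<sha256 hex>") resolves to
# LORA_PATH / "<sha256 hex>.safetensors", or None when no such upload exists.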

def get_or_create_pipeline(
    enable_cpu_offload: bool = True,
    flow_shift: float = 3.0
) -> WanPipeline:
    """Get the existing pipeline or create a new one if necessary

    Args:
        enable_cpu_offload: Whether to enable CPU offload
        flow_shift: Flow shift parameter for the scheduler

    Returns:
        WanPipeline: The pipeline for generation
    """
    global pipe

    if pipe is None:
        logger.info("Creating new pipeline")

        # The VAE is kept in float32 for decode quality; the rest of the model
        # runs in bfloat16.
        vae = AutoencoderKLWan.from_pretrained(MODEL_VERSION, subfolder="vae", torch_dtype=torch.float32)
        pipe = WanPipeline.from_pretrained(MODEL_VERSION, vae=vae, torch_dtype=torch.bfloat16)

        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config,
            flow_shift=flow_shift
        )

        # enable_model_cpu_offload manages device placement itself, so only move
        # the pipeline to CUDA manually when offload is disabled.
        if enable_cpu_offload:
            logger.info("Enabling CPU offload")
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")
    else:
        # Reuse the cached pipeline, updating the scheduler if flow_shift changed.
        if pipe.scheduler.config.flow_shift != flow_shift:
            logger.info(f"Updating scheduler flow_shift from {pipe.scheduler.config.flow_shift} to {flow_shift}")
            pipe.scheduler = UniPCMultistepScheduler.from_config(
                pipe.scheduler.config,
                flow_shift=flow_shift
            )

    return pipe
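
# Note: enable_cpu_offload only takes effect when the pipeline is first created;
# toggling it afterwards has no effect on the cached pipeline.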

def manage_lora_weights(pipe: WanPipeline, lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
    """Manage LoRA weights, loading/unloading only when necessary

    Args:
        pipe: The pipeline to manage LoRA weights for
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of the LoRA contribution

    Returns:
        Tuple[bool, Optional[Path]]: (Is using LoRA, Path to LoRA file)
    """
    global current_lora_id

    using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0

    # No LoRA requested: make sure nothing stays loaded from a previous run.
    if not using_lora and current_lora_id is not None:
        logger.info(f"Unloading current LoRA with ID {current_lora_id}")
        try:
            pipe.unload_lora_weights()
            current_lora_id = None
        except Exception as e:
            logger.error(f"Error unloading LoRA weights: {e}")
        return False, None

    if using_lora:
        lora_path = get_lora_file_path(lora_id)

        if not lora_path:
            logger.warning(f"LoRA file with ID {lora_id} not found. Using base model instead.")

            # Drop whatever is currently loaded so the base model is actually used.
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA with ID {current_lora_id}")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")
                current_lora_id = None

            return False, None

        if lora_id != current_lora_id:
            # Switching adapters: unload the old one before loading the new one.
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA with ID {current_lora_id}")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")

            logger.info(f"Loading LoRA weights from {lora_path}")
            try:
                pipe.load_lora_weights(str(lora_path.parent), weight_name=lora_path.name, adapter_name="default")
                current_lora_id = lora_id
            except Exception as e:
                logger.error(f"Error loading LoRA weights: {e}")
                return False, None
        else:
            logger.info(f"Using currently loaded LoRA with ID {lora_id}")

        # Apply the requested LoRA strength; without this the slider value would
        # never reach the pipeline.
        pipe.set_adapters(["default"], adapter_weights=[lora_weight])

        return True, lora_path

    return False, None
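
# Example: manage_lora_weights(pipe, "<sha256 hex>", 0.7) loads
# LORA_PATH / "<sha256 hex>.safetensors" (if present) and returns (True, path);
# a blank ID or a zero weight returns (False, None).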

def generate_video(
    prompt: str,
    negative_prompt: str,
    prompt_prefix: str,
    width: int,
    height: int,
    num_frames: int,
    guidance_scale: float,
    flow_shift: float,
    lora_id: Optional[str],
    lora_weight: float,
    inference_steps: int,
    fps: int = 16,
    seed: int = -1,
    enable_cpu_offload: bool = True,
    conditioning_image: Optional[str] = None,
    progress=gr.Progress()
) -> str:
    """Generate a video using the Wan model with optional LoRA weights

    Args:
        prompt: Text prompt for generation
        negative_prompt: Negative text prompt
        prompt_prefix: Prefix to add to all prompts
        width: Output video width
        height: Output video height
        num_frames: Number of frames to generate
        guidance_scale: Classifier-free guidance scale
        flow_shift: Flow shift parameter for the scheduler
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of the LoRA contribution
        inference_steps: Number of inference steps
        fps: Frames per second for the output video
        seed: Random seed (-1 for random)
        enable_cpu_offload: Whether to enable CPU offload for VRAM optimization
        conditioning_image: Path to a conditioning image for image-to-video (not used in this app)
        progress: Gradio progress callback

    Returns:
        str: Path to the generated video
    """
    global pipe, current_lora_id

    try:
        progress(0.00, desc="Initializing generation")

        progress(0.02, desc="Processing prompt")
        if prompt_prefix and not prompt.startswith(prompt_prefix):
            full_prompt = f"{prompt_prefix}{prompt}"
        else:
            full_prompt = prompt
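
        # Snap the frame count down to the nearest value of the form 8*k + 1
        # (the grid the frame slider uses), e.g. 50 -> 49, 57 -> 57, 9 -> 9.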
        adjusted_num_frames = ((num_frames - 1) // 8) * 8 + 1
        if adjusted_num_frames != num_frames:
            logger.info(f"Adjusted number of frames from {num_frames} to {adjusted_num_frames} to match model requirements")
            num_frames = adjusted_num_frames

        progress(0.03, desc="Setting up random seed")
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
            logger.info(f"Using randomly generated seed: {seed}")

        # Seed every RNG in play so runs are reproducible for a given seed.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        generator = torch.Generator(device="cuda")
        generator = generator.manual_seed(seed)

        progress(0.05, desc="Preparing model")
        pipe = get_or_create_pipeline(enable_cpu_offload, flow_shift)

        progress(0.25, desc="Managing LoRA weights")
        using_lora, lora_path = manage_lora_weights(pipe, lora_id, lora_weight)

        # Reserve a temp file for the raw export.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            output_path = temp_file.name

        progress(0.40, desc="Starting video generation")

        # Time the diffusion run with CUDA events (elapsed_time is reported in ms).
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()

        progress(0.45, desc="Running diffusion process")

        output = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=inference_steps,
            generator=generator,
        ).frames[0]

        progress(0.90, desc="Generation complete")

        end_time.record()
        torch.cuda.synchronize()
        generation_time = start_time.elapsed_time(end_time) / 1000
        logger.info(f"Video generation completed in {format_time(generation_time)}")

        progress(0.90, desc="Exporting video")
        export_to_video(output, output_path, fps=fps)

        progress(0.95, desc="Saving video")

        # Keep a persistent copy under a fresh UUID and return that path, so the
        # result survives temp-file cleanup.
        output_id = str(uuid.uuid4())
        saved_output_path = OUTPUT_PATH / f"{output_id}.mp4"
        shutil.copy(output_path, saved_output_path)
        logger.info(f"Saved video with ID: {output_id}")

        progress(0.98, desc="Cleaning up resources")
        os.unlink(output_path)

        progress(1.0, desc="Generation complete")

        return str(saved_output_path)

    except Exception as e:
        import traceback
        error_msg = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)

        # Best-effort cleanup so a failed run does not leave a broken pipeline
        # or stale LoRA weights behind.
        if pipe is not None:
            if current_lora_id is not None:
                try:
                    pipe.unload_lora_weights()
                    current_lora_id = None
                except Exception:
                    pass

            try:
                pipe = None
                torch.cuda.empty_cache()
            except Exception:
                pass

        raise gr.Error(f"Error generating video: {str(e)}")


with gr.Blocks(title="Video Generation API") as app:

    with gr.Tabs():

        with gr.TabItem("1️⃣ Upload LoRA"):
            gr.Markdown("## Upload LoRA Weights")
            gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")

            with gr.Row():
                lora_file = gr.File(label="LoRA File (safetensors format)")

            with gr.Row():
                lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)

        with gr.TabItem("2️⃣ Generate Video"):

            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )

                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter negative prompt here...",
                        lines=3,
                        value="worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"
                    )

                    prompt_prefix = gr.Textbox(
                        label="Prompt Prefix",
                        placeholder="Prefix to add to all prompts",
                        value=DEFAULT_PROMPT_PREFIX
                    )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1280,
                            step=8,
                            value=1280
                        )

                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=720,
                            step=8,
                            value=720
                        )

                    with gr.Row():
                        num_frames = gr.Slider(
                            label="Number of Frames",
                            minimum=9,
                            maximum=257,
                            step=8,
                            value=49
                        )

                        fps = gr.Slider(
                            label="FPS",
                            minimum=1,
                            maximum=60,
                            step=1,
                            value=16
                        )

                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0
                        )

                        flow_shift = gr.Slider(
                            label="Flow Shift",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.0
                        )

                    lora_id = gr.Textbox(
                        label="LoRA ID (from upload tab)",
                        placeholder="Enter your LoRA ID here...",
                    )

                    with gr.Row():
                        lora_weight = gr.Slider(
                            label="LoRA Weight",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.7
                        )

                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=30
                        )

                    seed = gr.Slider(
                        label="Generation Seed (-1 for random)",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        value=-1
                    )

                    enable_cpu_offload = gr.Checkbox(
                        label="Enable Model CPU Offload (for low-VRAM GPUs)",
                        value=False
                    )

                    generate_btn = gr.Button(
                        "Generate Video",
                        variant="primary"
                    )

                with gr.Column(scale=1):
                    preview_video = gr.Video(
                        label="Generated Video",
                        interactive=False
                    )

    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            prompt_prefix,
            width,
            height,
            num_frames,
            guidance_scale,
            flow_shift,
            lora_id,
            lora_weight,
            inference_steps,
            fps,
            seed,
            enable_cpu_offload
        ],
        outputs=[
            preview_video
        ]
    )

    lora_file.change(
        fn=upload_lora_file,
        inputs=[lora_file],
        outputs=[lora_id_output, lora_id]
    )
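
# A minimal client-side sketch (assumes the app is running locally on Gradio's
# default port and that the gradio_client package is installed; endpoint names
# follow Gradio's fn-name convention and may differ in practice):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   lora_id, _ = client.predict(handle_file("my_lora.safetensors"), api_name="/upload_lora_file")
#   video = client.predict("a cat surfing a wave", "", "", 832, 480, 49, 5.0, 3.0,
#                          lora_id, 0.7, 30, 16, -1, True, api_name="/generate_video")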

if __name__ == "__main__":
    app.launch()