"""Gradio + FastAPI front end for Wan2.1 VACE / FLF2V generation.

Importing this module prepares the runtime environment (clones the upstream
repositories, patches ``attention.py`` and downloads the model weights), then
builds two REST endpoints and a two-tab Gradio UI mounted on the same app.
"""

import os
import subprocess
import sys
import tempfile
import uuid

import gradio as gr
import uvicorn
from fastapi import FastAPI, File, UploadFile
from PIL import Image

# (checkout directory, clone URL) for every repository needed at runtime.
REPOSITORIES = (
    ("Wan2.1", "https://github.com/Wan-Video/Wan2.1.git"),
    ("VACE", "https://github.com/ali-vilab/VACE.git"),
)

ATTENTION_FILE = "Wan2.1/wan/modules/attention.py"
FLASH_ATTN_ASSERT = "assert FLASH_ATTN_2_AVAILABLE"
FUNCTIONAL_IMPORT = "import torch.nn.functional as F"

# (Hugging Face repo id, local directory) for every model to download.
MODELS = (
    ("Wan-AI/Wan2.1-VACE-1.3B", "Wan2.1-VACE-1.3B"),
    ("Wan-AI/Wan2.1-FLF2V-14B-720P", "Wan2.1-FLF2V-14B-720P"),
)

# Files whose absence is reported as a warning after the download step.
CRITICAL_FILES = (
    "Wan2.1-VACE-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1-VACE-1.3B/diffusion_pytorch_model.safetensors",
    "Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model.safetensors.index.json",
)

# FLF2V-14B only supports 1280x720.
VIDEO_RESOLUTION = (1280, 720)

# Supported resolutions for the VACE model based on generate.py. Per the
# model's error message, only these two are actually supported.
IMAGE_RESOLUTIONS = (
    (832, 480),  # 16:9 (approx) landscape
    (480, 832),  # 9:16 (approx) portrait
)


def _clone_repo(name, url):
    """Clone *url* into the directory *name* unless that path already exists.

    Failures are reported as warnings and never raised, so start-up continues.
    """
    if os.path.exists(name):
        return
    print(f"Cloning {name} repository...")
    try:
        subprocess.run(["git", "clone", url], check=True)
    except (subprocess.SubprocessError, OSError) as e:
        # OSError covers a missing or non-executable ``git`` binary.
        print(f"Warning: Failed to clone {name}: {e}")


def _replace_flash_attn_assert(content):
    """Return *content* with the first flash-attn assert turned into a fallback.

    The first line containing ``assert FLASH_ATTN_2_AVAILABLE`` is replaced by
    an ``if not FLASH_ATTN_2_AVAILABLE:`` guard that returns
    ``F.scaled_dot_product_attention(...)``. The ``F`` import is prepended
    when the text does not already contain it. Text without the assert is
    returned unchanged.
    """
    if FLASH_ATTN_ASSERT not in content:
        return content

    # First, ensure F is imported.
    if FUNCTIONAL_IMPORT not in content:
        content = FUNCTIONAL_IMPORT + "\n" + content

    lines = content.split("\n")
    for i, line in enumerate(lines):
        if FLASH_ATTN_ASSERT not in line:
            continue
        # Reuse the width of the assert's indentation so the guard lands in
        # the same block.
        indent_str = " " * (len(line) - len(line.lstrip()))
        # NOTE(review): q, k, v are passed to SDPA exactly as they are named
        # at the assert site; confirm their layout matches what
        # scaled_dot_product_attention expects in the upstream function.
        lines[i:i + 1] = [
            f"{indent_str}if not FLASH_ATTN_2_AVAILABLE:",
            f"{indent_str} return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)",
        ]
        break
    return "\n".join(lines)


def _patch_attention():
    """Patch Wan2.1's attention.py so flash_attn is no longer mandatory.

    Does nothing when the file is absent. The file is rewritten only when the
    assert is found, so running this repeatedly is harmless. I/O and decoding
    errors are reported as warnings.
    """
    if not os.path.exists(ATTENTION_FILE):
        return
    print("Patching attention.py to disable flash_attn requirement...")
    try:
        with open(ATTENTION_FILE, "r", encoding="utf-8") as f:
            content = f.read()
        patched = _replace_flash_attn_assert(content)
        if patched != content:
            with open(ATTENTION_FILE, "w", encoding="utf-8") as f:
                f.write(patched)
            print("Successfully patched attention.py")
    except (OSError, UnicodeError) as e:
        print(f"Warning: Failed to patch attention.py: {e}")


def _download_models():
    """Download every model in ``MODELS`` whose local directory is missing.

    A failed download (non-zero exit status, or ``huggingface-cli`` not being
    runnable at all) is reported and the remaining models are still attempted.
    """
    for repo_id, local_dir in MODELS:
        if os.path.exists(local_dir):
            # NOTE(review): only the directory's existence is checked, so a
            # directory left behind by an interrupted download is skipped too.
            print(f"Model {local_dir} already exists, skipping download")
            continue

        print(f"Downloading {repo_id}...")
        try:
            # Use huggingface-cli to download all files.
            subprocess.run(
                [
                    "huggingface-cli", "download", repo_id,
                    "--local-dir", local_dir,
                    "--local-dir-use-symlinks", "False",
                ],
                check=True,
            )
        except (subprocess.SubprocessError, OSError) as e:
            # OSError: the CLI itself is missing; without this the exception
            # would propagate and abort the module import.
            print(f"ERROR: Failed to download {repo_id}: {e}")
            print("Please ensure you have sufficient disk space and network connectivity.")
        else:
            print(f"Successfully downloaded {repo_id}")


def _warn_about_missing_files():
    """Print a warning for each entry of ``CRITICAL_FILES`` that is absent."""
    for file_path in CRITICAL_FILES:
        if not os.path.exists(file_path):
            print(f"WARNING: Critical model file missing: {file_path}")
            print("The application may not work properly without this file.")


def initialize():
    """Initialize the application by downloading models and cloning repositories.

    Every step is best-effort: problems are printed, never raised.
    """
    print("Initializing Wan2.1 VACE environment...")
    for name, url in REPOSITORIES:
        _clone_repo(name, url)
    _patch_attention()
    _download_models()
    _warn_about_missing_files()


# Run initialization
initialize()

# Import after initialization: this import is deliberately placed after the
# repositories and models have been prepared.
from wan_runner import generate_video, generate_image  # noqa: E402

api = FastAPI()


@api.post("/generate_video")
async def api_generate_video(ref: UploadFile = File(...),
                             first: UploadFile = File(...),
                             last: UploadFile = File(...)):
    """Store the three uploads in a fresh directory and generate a video.

    Returns ``{"video_path": ...}``, a path inside a newly created,
    uuid-named directory relative to the working directory.
    """
    uid = uuid.uuid4().hex
    os.makedirs(uid, exist_ok=True)
    paths = [f"{uid}/{name}" for name in ["ref.png", "first.png", "last.png"]]
    for upload, path in zip([ref, first, last], paths):
        with open(path, "wb") as f:
            f.write(await upload.read())
    output = f"{uid}/output.mp4"
    # NOTE: called synchronously inside the coroutine, so the event loop is
    # occupied until generation returns.
    generate_video(*paths, output)
    return {"video_path": output}


@api.post("/generate_image")
async def api_generate_image(ref: UploadFile = File(...), prompt: str = ""):
    """Store the reference upload in a fresh directory and generate an image.

    Returns ``{"image_path": ...}``, a path inside a newly created,
    uuid-named directory relative to the working directory.
    """
    uid = uuid.uuid4().hex
    os.makedirs(uid, exist_ok=True)
    ref_path = f"{uid}/ref.png"
    with open(ref_path, "wb") as f:
        f.write(await ref.read())
    output = f"{uid}/output.png"
    # NOTE: called synchronously inside the coroutine (see api_generate_video).
    generate_image(ref_path, prompt, output)
    return {"image_path": output}


def _save_resized(image, resolution, path):
    """Resize *image* to *resolution* with LANCZOS and save it at *path*.

    The resize does not preserve the aspect ratio; it is forced to the exact
    resolution the caller asks for.
    """
    image.resize(resolution, Image.Resampling.LANCZOS).save(path)
    return path


def _closest_resolution(width, height):
    """Return the entry of ``IMAGE_RESOLUTIONS`` nearest in aspect ratio."""
    aspect_ratio = width / height
    return min(IMAGE_RESOLUTIONS,
               key=lambda res: abs((res[0] / res[1]) - aspect_ratio))


def video_ui(ref, first, last):
    """Gradio handler: build a video from reference, first and last images.

    Each argument is a PIL image, or ``None`` when nothing was uploaded.
    All three are resized to 1280x720 before generation. Returns the path of
    the generated ``.mp4`` in the working directory.

    Raises:
        gr.Error: when an image is missing or generation fails.
    """
    # Validate outside the try block so this message is not re-wrapped by the
    # generic handler below.
    if ref is None or first is None or last is None:
        raise gr.Error("参照画像・開始画像・終了画像をすべてアップロードしてください。")

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Note: this may change the aspect ratio, but the model requires
            # this exact resolution.
            ref_path = _save_resized(ref, VIDEO_RESOLUTION, f"{tmpdir}/ref.png")
            first_path = _save_resized(first, VIDEO_RESOLUTION, f"{tmpdir}/first.png")
            last_path = _save_resized(last, VIDEO_RESOLUTION, f"{tmpdir}/last.png")

            output = f"{uuid.uuid4().hex}.mp4"
            # FLF2V only supports 1280x720
            generate_video(ref_path, first_path, last_path, output, size="1280*720")
            return output
    except FileNotFoundError as e:
        raise gr.Error(str(e)) from e
    except Exception as e:
        raise gr.Error(f"動画生成エラー: {str(e)}") from e


def image_ui(ref, prompt):
    """Gradio handler: generate an image from a reference image and a prompt.

    *ref* is a PIL image, or ``None`` when nothing was uploaded. It is resized
    to whichever supported resolution is closest to its aspect ratio. Returns
    the path of the generated ``.png`` in the working directory.

    Raises:
        gr.Error: when the reference image is missing or generation fails.
    """
    # Validate outside the try block so this message is not re-wrapped by the
    # generic handler below.
    if ref is None:
        raise gr.Error("参照画像をアップロードしてください。")

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Find the best matching resolution based on the original aspect
            # ratio, then resize to it.
            orig_width, orig_height = ref.size
            best_resolution = _closest_resolution(orig_width, orig_height)
            ref_path = _save_resized(ref, best_resolution, f"{tmpdir}/ref.png")

            # The model takes the size as "<width>*<height>".
            size_param = f"{best_resolution[0]}*{best_resolution[1]}"

            output = f"{uuid.uuid4().hex}.png"
            # Pass size parameter to generate_image
            generate_image(ref_path, prompt, output, size=size_param)
            return output
    except FileNotFoundError as e:
        raise gr.Error(str(e)) from e
    except Exception as e:
        raise gr.Error(f"画像生成エラー: {str(e)}") from e


with gr.Blocks() as demo:
    with gr.Tab("動画生成"):
        gr.Markdown(
            "### FLF2V-14B 動画生成\n"
            "⚠️ このモデルは**1280×720 (16:9)のみ**サポートしています。"
            "アップロード画像が他のサイズでも**1280×720**に自動リサイズされます。\n"
            "⏱️ 生成される動画は**5秒間**です。"
        )
        ref_img = gr.Image(label="参照画像", type="pil")
        first_img = gr.Image(label="開始画像", type="pil")
        last_img = gr.Image(label="終了画像", type="pil")
        btn_video = gr.Button("動画を生成")
        output_video = gr.Video()
        btn_video.click(video_ui, [ref_img, first_img, last_img], output_video)

    with gr.Tab("画像生成"):
        gr.Markdown(
            "### VACE-1.3B 画像生成\n"
            "⚠️ このモデルは**832×480(横長)または480×832(縦長)のみ**サポートしています。"
            "アップロード画像が他のサイズでも対応解像度に自動リサイズされます。"
        )
        ref_img2 = gr.Image(label="参照画像", type="pil")
        prompt = gr.Textbox(label="画像プロンプト")
        btn_image = gr.Button("画像を生成")
        output_image = gr.Image()
        btn_image.click(image_ui, [ref_img2, prompt], output_image)

# Serve the Gradio UI at "/" on the same app as the REST endpoints.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)