""" ╔══════════════════════════════════════════════════════════╗ ║ AI VIDEO ENHANCER — ULTIMATE EDITION ║ ║ RealESRGAN 4K + GFPGAN Face Restoration ║ ║ Production-Grade | Memory Optimized | ZeroGPU ║ ╚══════════════════════════════════════════════════════════╝ """ import os import sys import gc import json import shutil import subprocess import tempfile import time import logging from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple # ────────────────────────────────────────────── # 🔧 COMPATIBILITY PATCHES (must run before imports) # ────────────────────────────────────────────── def _apply_compat_patches() -> None: """Apply forward-compatibility shims for older third-party libraries.""" # Patch 1: torchvision ≥ 0.17 removed functional_tensor module try: import torchvision.transforms.functional as F sys.modules.setdefault("torchvision.transforms.functional_tensor", F) except ImportError: pass # Patch 2: NumPy 2.x removed legacy type aliases (np.int, np.float, …) import numpy as _np for _alias, _builtin in (("int", int), ("float", float), ("bool", bool), ("complex", complex)): if not hasattr(_np, _alias): setattr(_np, _alias, _builtin) _apply_compat_patches() # ────────────────────────────────────────────── # 📦 STANDARD IMPORTS # ────────────────────────────────────────────── import numpy as np import cv2 import torch import gradio as gr import spaces # ────────────────────────────────────────────── # 🤖 AI LIBRARY IMPORTS (optional — graceful degradation) # ────────────────────────────────────────────── try: from basicsr.archs.rrdbnet_arch import RRDBNet from gfpgan import GFPGANer HAVE_ENHANCERS = True except ImportError as exc: logging.warning("AI libraries unavailable (%s). 
Install via requirements.txt.", exc) HAVE_ENHANCERS = False RRDBNet = object # type: ignore[assignment,misc] GFPGANer = object # type: ignore[assignment,misc] # ────────────────────────────────────────────── # ⚙️ CONFIGURATION # ────────────────────────────────────────────── logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) MAX_DURATION_SECONDS: int = 120 # Maximum accepted video length FRAME_DOWNSCALE_WIDTH: int = 1280 # Pre-processing cap to protect VRAM BATCH_SIZE: int = 1 # Frames processed per GPU pass (safe for ZeroGPU) FFMPEG_CRF: int = 20 # Output quality (lower = better, larger file) FFMPEG_PRESET: str = "medium" # Encoding speed/quality trade-off AUDIO_BITRATE: str = "192k" TEMP_DIR: Path = Path(tempfile.gettempdir()) / "ai_video_enhancer" WEIGHTS_DIR: Path = Path("weights") TEMP_DIR.mkdir(parents=True, exist_ok=True) WEIGHTS_DIR.mkdir(parents=True, exist_ok=True) MODEL_URLS: dict[str, str] = { "RealESRGAN_x4plus.pth": ( "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" ), "GFPGANv1.4.pth": ( "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" ), } # ────────────────────────────────────────────── # 🗂️ DATA CLASSES # ────────────────────────────────────────────── @dataclass(frozen=True) class VideoMeta: duration: float width: int height: int fps: float @dataclass class ProcessingResult: status: str video_path: Optional[str] = None comparison_path: Optional[str] = None success: bool = False # ────────────────────────────────────────────── # 🛠️ UTILITIES # ────────────────────────────────────────────── def download_weights() -> None: """Download model weights if they are not already present.""" for filename, url in MODEL_URLS.items(): dest = WEIGHTS_DIR / filename if dest.exists(): log.info("✔ Weights already cached: %s", filename) continue log.info("📥 Downloading %s …", filename) try: torch.hub.download_url_to_file(url, 
str(dest)) log.info("✔ Downloaded: %s", filename) except Exception as exc: raise RuntimeError(f"Failed to download {filename}: {exc}") from exc def _require_ffmpeg() -> None: """Raise an informative error when FFmpeg is missing from PATH.""" try: subprocess.run( ["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, ) except (FileNotFoundError, subprocess.CalledProcessError) as exc: raise RuntimeError( "FFmpeg is not installed or not found in PATH. " "Please install it: https://ffmpeg.org/download.html" ) from exc def run_ffmpeg(*args: str) -> None: """Execute an FFmpeg command, raising on non-zero exit.""" _require_ffmpeg() cmd = ["ffmpeg", "-y", *args] result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if result.returncode != 0: raise RuntimeError( f"FFmpeg failed (exit {result.returncode}):\n{result.stderr.decode(errors='replace')}" ) def probe_video(video_path: str) -> VideoMeta: """Return essential metadata for a video file using ffprobe.""" cmd = [ "ffprobe", "-v", "error", "-select_streams", "v:0", "-print_format", "json", "-show_entries", "stream=width,height,duration,r_frame_rate", video_path, ] try: raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode() except subprocess.CalledProcessError as exc: raise RuntimeError(f"ffprobe failed on '{video_path}'") from exc stream = json.loads(raw)["streams"][0] num, den = map(int, stream["r_frame_rate"].split("/")) return VideoMeta( duration=float(stream.get("duration", 0)), width=int(stream["width"]), height=int(stream["height"]), fps=num / den if den else 30.0, ) def has_audio_stream(video_path: str) -> bool: """Return True when the file contains at least one audio stream.""" result = subprocess.run( [ "ffprobe", "-v", "error", "-select_streams", "a", "-show_entries", "stream=codec_type", "-of", "default=noprint_wrappers=1", video_path, ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) return bool(result.stdout.decode().strip()) # 
# ──────────────────────────────────────────────
# 🎬 VIDEO I/O
# ──────────────────────────────────────────────

# Scale factor shared by the RealESRGAN network, GFPGAN and the fallback path.
_UPSCALE_FACTOR = 4
_JPEG_WRITE_PARAMS = [cv2.IMWRITE_JPEG_QUALITY, 95]
# Share of the overall progress bar that the GPU stage occupies; matches the
# 0.10 / 0.90 checkpoints reported by process_video around that stage.
_ENHANCE_PROGRESS_START = 0.10
_ENHANCE_PROGRESS_END = 0.90


def extract_frames(video_path: str, frames_dir: Path) -> None:
    """Extract every frame from a video into JPEG files (high quality)."""
    frames_dir.mkdir(parents=True, exist_ok=True)
    run_ffmpeg(
        "-i", video_path,
        "-vsync", "0",
        "-q:v", "2",  # JPEG quality: 1–31, lower = better
        str(frames_dir / "%08d.jpg"),
    )
    log.info("✔ Extracted frames → %s", frames_dir)


def reassemble_video(frames_dir: Path, audio_src: str, out_path: str, fps: float) -> None:
    """
    Reassemble enhanced frames into an MP4, optionally muxing the original audio.

    Uses a two-pass approach: encode video first, then mux audio if present.
    The intermediate video is removed even when one of the passes fails.
    """
    tmp_video = str(frames_dir.parent / "tmp_video.mp4")
    try:
        # Pass 1 — encode video (no audio)
        run_ffmpeg(
            "-framerate", str(fps),
            "-i", str(frames_dir / "%08d.jpg"),
            "-c:v", "libx264",
            "-preset", FFMPEG_PRESET,
            "-crf", str(FFMPEG_CRF),
            "-pix_fmt", "yuv420p",
            tmp_video,
        )

        # Pass 2 — mux audio (if present)
        if has_audio_stream(audio_src):
            run_ffmpeg(
                "-i", tmp_video,
                "-i", audio_src,
                "-c:v", "copy",
                "-c:a", "aac",
                "-b:a", AUDIO_BITRATE,
                "-map", "0:v:0",
                "-map", "1:a:0",
                "-shortest",
                out_path,
            )
        else:
            shutil.move(tmp_video, out_path)
    finally:
        # Already gone after a successful move; otherwise clean up here.
        if os.path.exists(tmp_video):
            os.remove(tmp_video)

    log.info("✔ Final video → %s", out_path)


# ──────────────────────────────────────────────
# 🖼️ COMPARISON IMAGE
# ──────────────────────────────────────────────

def create_comparison_card(
    before_path: str,
    after_img: np.ndarray,
    save_path: str,
) -> Optional[str]:
    """
    Render a side-by-side BEFORE/AFTER comparison card.

    Args:
        before_path: Path to the original (un-enhanced) frame.
        after_img: Already-enhanced frame as a NumPy BGR array.
        save_path: Where to write the comparison JPEG.

    Returns:
        save_path on success, None on failure.
    """
    before_img = cv2.imread(before_path)
    if before_img is None or after_img is None:
        log.warning("Comparison skipped — could not read source images.")
        return None

    # Bring the original up to the enhanced resolution so both halves align.
    h, w = after_img.shape[:2]
    before_resized = cv2.resize(before_img, (w, h), interpolation=cv2.INTER_CUBIC)
    canvas = np.hstack((before_resized, after_img))

    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, thickness_fg, thickness_shadow = 2.5, 4, 15
    shadow_color, before_color, after_color = (0, 0, 0), (255, 255, 255), (0, 255, 80)
    y_pos = 100

    labels = (
        ("ORIGINAL", 50, before_color),
        ("ENHANCED (AI)", w + 50, after_color),
    )
    for text, x_offset, fg_color in labels:
        origin = (x_offset, y_pos)
        # Thick dark pass first, so the coloured text stays legible on any frame.
        cv2.putText(canvas, text, origin, font, scale, shadow_color, thickness_shadow, cv2.LINE_AA)
        cv2.putText(canvas, text, origin, font, scale, fg_color, thickness_fg, cv2.LINE_AA)

    # imwrite signals failure through its return value rather than an exception.
    if not cv2.imwrite(save_path, canvas, _JPEG_WRITE_PARAMS):
        log.warning("Comparison skipped — could not write %s.", save_path)
        return None
    return save_path


# ──────────────────────────────────────────────
# 🧠 GPU PROCESSING
# ──────────────────────────────────────────────

class _RealESRGANWrapper:
    """Minimal wrapper to make the raw RRDBNet compatible with GFPGANer."""

    def __init__(
        self,
        model: torch.nn.Module,
        dev: torch.device,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        self._model = model
        self._dev = dev
        self._dtype = dtype

    def enhance(
        self,
        img: np.ndarray,
        outscale: Optional[float] = None,
        **_kwargs,
    ) -> Tuple[np.ndarray, None]:
        """Upscale a HWC BGR uint8 image; *outscale* is accepted but ignored."""
        # HWC BGR uint8 → NCHW RGB tensor in [0, 1]
        rgb = np.ascontiguousarray(img[:, :, ::-1])
        tensor = (
            torch.from_numpy(rgb)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device=self._dev, dtype=self._dtype)
            .div(255.0)
        )
        with torch.no_grad():
            out = self._model(tensor)

        # NCHW → HWC. Only the batch axis is squeezed so that an image with a
        # 1-pixel side keeps all three of its remaining dimensions.
        out_np = (
            out.squeeze(0)
            .float()
            .cpu()
            .clamp_(0, 1)
            .permute(1, 2, 0)
            .numpy()
        )
        # RGB → BGR, back to uint8
        return (out_np[:, :, ::-1] * 255.0).round().astype(np.uint8), None


@contextmanager
def _managed_models(device: torch.device):
    """
    Context manager that loads AI models, yields them, then releases all VRAM.
    Ensures cleanup even when an exception is raised mid-processing.

    Yields:
        A GFPGANer whose background upsampler is the RealESRGAN network.
    """
    download_weights()
    device = torch.device(device)
    log.info("🔄 Loading AI models on %s …", device)

    # Half precision is only used on CUDA; the CPU fallback runs in float32.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    model_bg = RRDBNet(
        num_in_ch=3, num_out_ch=3,
        num_feat=64, num_block=23,
        num_grow_ch=32, scale=_UPSCALE_FACTOR,
    )
    checkpoint = torch.load(
        str(WEIGHTS_DIR / "RealESRGAN_x4plus.pth"),
        map_location=device,
    )
    # Use the "params_ema" entry when the checkpoint has one, else "params".
    state_dict = checkpoint.get("params_ema") or checkpoint["params"]
    model_bg.load_state_dict(state_dict, strict=True)
    model_bg.eval().to(device=device, dtype=dtype)

    upsampler = _RealESRGANWrapper(model_bg, device, dtype)
    face_enhancer = GFPGANer(
        model_path=str(WEIGHTS_DIR / "GFPGANv1.4.pth"),
        upscale=_UPSCALE_FACTOR,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )

    try:
        yield face_enhancer
    finally:
        log.info("🧹 Releasing model memory …")
        del face_enhancer, upsampler, model_bg
        # Collect first so the dropped tensors are really freed, then hand
        # the cached blocks back to the driver.
        gc.collect()
        torch.cuda.empty_cache()


def _enhance_frame(enhancer, frame_path: str) -> Optional[np.ndarray]:
    """Enhance one JPEG in place; return the enhanced image, or None on failure.

    A failure is logged and never raised. When the frame itself was readable,
    a plain bicubic upscale is written instead so that every frame handed to
    the encoder has the same dimensions.
    """
    img = None
    try:
        img = cv2.imread(frame_path, cv2.IMREAD_COLOR)
        if img is None:
            log.warning("Skipped unreadable frame: %s", frame_path)
            return None

        # Downscale oversized frames to protect VRAM
        width = img.shape[1]
        if width > FRAME_DOWNSCALE_WIDTH:
            ratio = FRAME_DOWNSCALE_WIDTH / width
            img = cv2.resize(img, (0, 0), fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)

        _, _, enhanced = enhancer.enhance(
            img,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )
        if not cv2.imwrite(frame_path, enhanced, _JPEG_WRITE_PARAMS):
            raise OSError(f"could not write {frame_path}")
        return enhanced
    except Exception as exc:
        log.error("Error processing frame %s: %s", frame_path, exc, exc_info=True)

    if img is not None:
        try:
            fallback = cv2.resize(
                img, (0, 0),
                fx=_UPSCALE_FACTOR, fy=_UPSCALE_FACTOR,
                interpolation=cv2.INTER_CUBIC,
            )
            cv2.imwrite(frame_path, fallback, _JPEG_WRITE_PARAMS)
        except Exception as exc:
            log.error("Fallback upscale failed for %s: %s", frame_path, exc)
    return None


@spaces.GPU(duration=500)
def process_frames_on_gpu(
    frames_dir: Path,
    progress: gr.Progress = gr.Progress(),
) -> Optional[str]:
    """
    Enhance all JPEG frames in *frames_dir* using RealESRGAN + GFPGAN.

    Writes enhanced frames back in-place and returns a comparison image path
    (None when the comparison frame could not be enhanced or written).

    Raises:
        RuntimeError: if the AI libraries are missing, no frames exist, or
            not a single frame could be enhanced.
    """
    if not HAVE_ENHANCERS:
        raise RuntimeError(
            "AI enhancement libraries are unavailable. "
            "Please install them via requirements.txt."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_frames = sorted(frames_dir.glob("*.jpg"))
    total = len(all_frames)
    if total == 0:
        raise RuntimeError("No frames were extracted. The video may be corrupt or empty.")

    comparison_source_path = str(all_frames[total // 2])
    # Frames are overwritten in place, so the untouched original has to be
    # set aside now; reading it after the loop would yield the enhanced frame.
    before_snapshot = str(frames_dir.parent / "comparison_before.jpg")
    shutil.copy2(comparison_source_path, before_snapshot)

    comparison_enhanced: Optional[np.ndarray] = None
    failed = 0
    span = _ENHANCE_PROGRESS_END - _ENHANCE_PROGRESS_START

    try:
        with _managed_models(device) as enhancer:
            # Every frame is visited exactly once, one frame per GPU pass.
            for idx, frame in enumerate(all_frames):
                frame_path = str(frame)
                progress(
                    _ENHANCE_PROGRESS_START + span * idx / total,
                    desc=f"Enhancing frame {idx + 1}/{total}",
                )
                try:
                    enhanced = _enhance_frame(enhancer, frame_path)
                finally:
                    # Per-frame VRAM sweep
                    gc.collect()
                    torch.cuda.empty_cache()

                if enhanced is None:
                    failed += 1  # Non-fatal: continue with remaining frames
                elif frame_path == comparison_source_path:
                    comparison_enhanced = enhanced

        if failed == total:
            raise RuntimeError("AI enhancement failed for every frame; see the log for details.")
        if failed:
            log.warning("%d of %d frames could not be enhanced.", failed, total)

        # Build comparison card
        if comparison_enhanced is None:
            return None
        comp_path = str(frames_dir.parent / "comparison.jpg")
        return create_comparison_card(before_snapshot, comparison_enhanced, comp_path)
    finally:
        if os.path.exists(before_snapshot):
            os.remove(before_snapshot)


# ──────────────────────────────────────────────
# 🚀 MAIN PIPELINE
# ──────────────────────────────────────────────

def process_video(
    video_file,
    progress: gr.Progress = gr.Progress(),
) -> Tuple[str, Optional[str], Optional[str]]:
    """
    Full enhancement pipeline:
      1. Validate & probe input video
      2. Extract frames
      3. AI enhancement (GPU)
      4. Reassemble + mux audio
      5. Return status, output path, comparison image

    Errors are reported through the status message instead of being raised.

    Returns:
        (status_message, output_video_path, comparison_image_path)
    """
    if video_file is None:
        return "⚠️ Please upload a video file.", None, None

    input_path = video_file.name if hasattr(video_file, "name") else str(video_file)
    # mkdtemp guarantees a fresh directory even for simultaneous requests.
    job_dir = Path(tempfile.mkdtemp(prefix="job_", dir=str(TEMP_DIR)))
    frames_dir = job_dir / "frames"
    succeeded = False

    try:
        # ── Step 1: Probe ──────────────────────────────
        progress(0.00, desc="📊 Analysing video …")
        meta = probe_video(input_path)
        log.info("Video: %.1fs | %dx%d | %.2f fps", meta.duration, meta.width, meta.height, meta.fps)

        if meta.duration <= 0:
            log.warning("Video duration is unknown; the length limit cannot be enforced.")
        if meta.duration > MAX_DURATION_SECONDS:
            return (
                f"❌ Video is too long ({meta.duration:.1f}s). "
                f"Maximum allowed: {MAX_DURATION_SECONDS}s.",
                None,
                None,
            )

        # ── Step 2: Extract frames ─────────────────────
        progress(0.05, desc="🎞️ Extracting frames …")
        extract_frames(input_path, frames_dir)

        # ── Step 3: AI Enhancement ─────────────────────
        progress(0.10, desc="🤖 Starting AI enhancement …")
        comparison_path = process_frames_on_gpu(frames_dir, progress=progress)

        # ── Step 4: Reassemble ─────────────────────────
        progress(0.90, desc="🎬 Encoding final video …")
        out_path = str(job_dir / "enhanced_output.mp4")
        reassemble_video(frames_dir, input_path, out_path, meta.fps)

        progress(1.00, desc="✅ Done!")
        succeeded = True
        return "✅ Enhancement complete! Your 4K video is ready.", out_path, comparison_path

    except Exception as exc:
        log.exception("Pipeline failed: %s", exc)
        return f"❌ Error: {exc}", None, None

    finally:
        # Always remove the (potentially large) frames directory. Without a
        # result to serve, nothing else in the job directory is needed either.
        shutil.rmtree(frames_dir if succeeded else job_dir, ignore_errors=True)
        log.info("🧹 Cleaned up frames directory.")


# ──────────────────────────────────────────────
# 🎨 GRADIO UI
# ──────────────────────────────────────────────

# The stylesheet text is kept exactly as found, including its line breaks.
_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Rajdhani:wght@400;600;700&family=Space+Mono:ital@0;1&display=swap'); :root { --bg-deep: #070a0f; --bg-card: #0d1117; --bg-panel: #111820; --accent-cyan: #00e5ff; --accent-lime: #b2ff59; --accent-mag: #ff4081; --text-main: #e8f4f8; --text-muted: #5a7080; --border: rgba(0,229,255,0.12); --radius: 12px; --glow-cyan: 0 0 30px rgba(0,229,255,0.25); --glow-lime: 0 0 30px rgba(178,255,89,0.2); } /* ── Base ── */ body, .gradio-container { background: var(--bg-deep) !important; font-family: 'Rajdhani', sans-serif !important; color: var(--text-main) !important; min-height: 100vh; } /* ── Animated scanline grid background ── */ .gradio-container::before { content: ""; position: fixed; inset: 0; background-image: linear-gradient(rgba(0,229,255,0.03) 1px, transparent 1px), linear-gradient(90deg, rgba(0,229,255,0.03) 1px, transparent 1px); background-size: 40px 40px; pointer-events: none; z-index: 0; } /* ── Hero header ── */ .hero-wrap { position: relative; text-align: center; padding: 3rem 1rem 2rem; z-index: 1; } .hero-title { font-family: 'Rajdhani', sans-serif; font-weight: 700; font-size: clamp(2rem, 6vw, 4.5rem); letter-spacing: 0.12em; text-transform: uppercase; background: linear-gradient(90deg, var(--accent-cyan) 0%, var(--accent-lime) 55%, var(--accent-mag) 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; line-height: 1.1; margin: 0 0 0.5rem; text-shadow: none; filter: drop-shadow(0 0 20px rgba(0,229,255,0.4)); animation: pulse-glow 3s
ease-in-out infinite alternate; } @keyframes pulse-glow { from { filter: drop-shadow(0 0 15px rgba(0,229,255,0.3)); } to { filter: drop-shadow(0 0 35px rgba(178,255,89,0.5)); } } .hero-subtitle { font-family: 'Space Mono', monospace; font-size: 0.78rem; letter-spacing: 0.3em; color: var(--text-muted); text-transform: uppercase; } /* ── Cards ── */ .card { background: var(--bg-card); border: 1px solid var(--border); border-radius: var(--radius); padding: 1.5rem; position: relative; overflow: hidden; z-index: 1; } .card::before { content: ""; position: absolute; inset: 0; background: linear-gradient(135deg, rgba(0,229,255,0.04) 0%, transparent 60%); pointer-events: none; } .card-label { font-family: 'Space Mono', monospace; font-size: 0.65rem; letter-spacing: 0.25em; color: var(--accent-cyan); text-transform: uppercase; margin-bottom: 0.8rem; display: flex; align-items: center; gap: 8px; } .card-label::before { content: ""; display: inline-block; width: 20px; height: 2px; background: var(--accent-cyan); box-shadow: 0 0 8px var(--accent-cyan); } /* ── Upload area ── */ .upload-zone { border: 2px dashed rgba(0,229,255,0.25) !important; border-radius: var(--radius) !important; background: rgba(0,229,255,0.02) !important; transition: all 0.3s ease; min-height: 200px; } .upload-zone:hover { border-color: var(--accent-cyan) !important; background: rgba(0,229,255,0.06) !important; box-shadow: var(--glow-cyan); } /* ── Enhance button ── */ .enhance-btn { background: linear-gradient(135deg, #003d4d 0%, #001a26 100%) !important; border: 1px solid var(--accent-cyan) !important; color: var(--accent-cyan) !important; font-family: 'Rajdhani', sans-serif !important; font-weight: 700 !important; font-size: 1.1rem !important; letter-spacing: 0.2em !important; text-transform: uppercase !important; padding: 1rem 2rem !important; border-radius: var(--radius) !important; cursor: pointer; box-shadow: var(--glow-cyan); transition: all 0.3s ease; width: 100%; position: relative; overflow:
hidden; } .enhance-btn::after { content: ""; position: absolute; top: -50%; left: -50%; width: 200%; height: 200%; background: linear-gradient( 45deg, transparent 30%, rgba(0,229,255,0.15) 50%, transparent 70% ); transform: translateX(-100%); transition: transform 0.6s ease; } .enhance-btn:hover::after { transform: translateX(100%); } .enhance-btn:hover { background: linear-gradient(135deg, #005566 0%, #002233 100%) !important; box-shadow: 0 0 40px rgba(0,229,255,0.45) !important; transform: translateY(-2px); } /* ── Status log ── */ .status-log textarea { background: #050810 !important; border: 1px solid var(--border) !important; color: var(--accent-lime) !important; font-family: 'Space Mono', monospace !important; font-size: 0.78rem !important; border-radius: 8px !important; min-height: 60px !important; } /* ── Spec badges ── */ .specs { display: flex; flex-wrap: wrap; gap: 8px; justify-content: center; margin: 1.2rem 0 0; z-index: 1; position: relative; } .spec-badge { font-family: 'Space Mono', monospace; font-size: 0.65rem; letter-spacing: 0.15em; color: var(--text-muted); border: 1px solid rgba(255,255,255,0.08); padding: 4px 12px; border-radius: 999px; background: rgba(255,255,255,0.02); text-transform: uppercase; } /* ── Gradio component overrides ── */ .gr-box, .gr-form { background: transparent !important; border: none !important; } label span { color: var(--text-muted) !important; font-family: 'Space Mono', monospace !important; font-size: 0.72rem !important; letter-spacing: 0.1em !important; } """

# NOTE(review): the header and the gr.HTML() snippets below contain plain text
# only, while the stylesheet defines .hero-wrap, .hero-title, .hero-subtitle,
# .specs, .spec-badge and .card-label, none of which is referenced anywhere in
# this file. The wrapping markup was presumably lost at some point — TODO
# confirm against the original and restore it.
_HEADER_HTML = """
AI Video Enhancer
RealESRGAN 4K Upscale  ·  GFPGAN Face Restoration  ·  ZeroGPU Accelerated
4× Upscale Face Restoration Audio Preserved Max 2 Minutes CUDA FP16
"""

with gr.Blocks(title="AI Video Enhancer", css=_CSS, theme=gr.themes.Base()) as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row(equal_height=False):
        # ── Left column: Input ─────────────────────
        with gr.Column(scale=1, elem_classes=["card"]):
            gr.HTML('\nInput Video\n')
            video_input = gr.Video(
                label="",
                sources=["upload"],
                elem_classes=["upload-zone"],
            )
            gr.HTML('\n')
            enhance_btn = gr.Button(
                "⚡ Enhance Video",
                variant="primary",
                elem_classes=["enhance-btn"],
            )
            gr.HTML('\n')
            status_log = gr.Textbox(
                label="STATUS",
                interactive=False,
                placeholder="Waiting for input …",
                elem_classes=["status-log"],
            )

        # ── Right column: Output ───────────────────
        with gr.Column(scale=2, elem_classes=["card"]):
            gr.HTML('\nBefore / After Preview\n')
            comparison_img = gr.Image(label="", type="filepath", show_label=False)
            gr.HTML('\nEnhanced Output\n')
            video_output = gr.Video(label="", show_label=False)

    # The output order matches the tuple returned by process_video.
    enhance_btn.click(
        fn=process_video,
        inputs=[video_input],
        outputs=[status_log, video_output, comparison_img],
    )


# ──────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()