"""
Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds)

- Strict tensor shapes for MatAnyone (image: 3xHxW, first-frame prob mask: 1xHxW)
- First frame uses the PROB path (no idx_mask / objects) to avoid an assertion
- Memory management & cleanup
- SDXL / Playground / OpenAI backgrounds
- Gradio UI with "CHAPTER" dividers
- FIXED: Enhanced positioning with debug logging and coordinate precision
"""
|
|
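# Running this file directly (python <this file>) starts the Gradio UI from
# main() on http://0.0.0.0:7860; the SAM2/MatAnyone checkpoints and repos are
# fetched into ./checkpoints on first initialization.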
|
|
|
|
|
|
|
import os |
|
import sys |
|
import gc |
|
import cv2 |
|
import psutil |
|
import time |
|
import json |
|
import base64 |
|
import random |
|
import shutil |
|
import logging |
|
import traceback |
|
import subprocess |
|
import tempfile |
|
import threading |
|
from dataclasses import dataclass |
|
from contextlib import contextmanager |
|
from pathlib import Path |
|
from typing import Optional, Tuple, List |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import gradio as gr |
|
from moviepy.editor import VideoFileClip |
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") |
|
logger = logging.getLogger("bgx") |
|
|
|
|
|
os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY") |
|
os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1") |
|
os.environ.setdefault("PYTHONUNBUFFERED", "1") |
|
os.environ.setdefault("MKL_NUM_THREADS", "4") |
|
os.environ.setdefault("BFX_QUALITY", "max") |
|
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,roundup_power2_divisions:16") |
|
os.environ.setdefault("HYDRA_FULL_ERROR", "1") |
|
os.environ["OMP_NUM_THREADS"] = "2"  # set unconditionally, unlike the setdefault calls above
|
|
|
|
|
BASE_DIR = Path(__file__).resolve().parent |
|
CHECKPOINTS = BASE_DIR / "checkpoints" |
|
TEMP_DIR = BASE_DIR / "temp" |
|
OUT_DIR = BASE_DIR / "outputs" |
|
BACKGROUND_DIR = OUT_DIR / "backgrounds" |
|
for p in (CHECKPOINTS, TEMP_DIR, OUT_DIR, BACKGROUND_DIR): |
|
p.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
try: |
|
import torch |
|
TORCH_AVAILABLE = True |
|
CUDA_AVAILABLE = torch.cuda.is_available() |
|
DEVICE = "cuda" if CUDA_AVAILABLE else "cpu" |
|
try: |
|
if torch.backends.cuda.is_built(): |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
if hasattr(torch.backends, "cudnn"): |
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cudnn.deterministic = False |
|
if CUDA_AVAILABLE: |
|
torch.cuda.set_per_process_memory_fraction(0.8) |
|
except Exception: |
|
pass |
|
except Exception: |
|
TORCH_AVAILABLE = False |
|
CUDA_AVAILABLE = False |
|
DEVICE = "cpu" |
|
|
|
|
|
|
|
|
|
# Each preset maps to (top_color, bottom_color) in OpenCV's BGR channel order,
# matching the cv2.imwrite / cv2.cvtColor(..., COLOR_BGR2RGB) usage downstream.
GRADIENT_PRESETS = {
|
"Blue Fade": ((128, 64, 0), (255, 128, 0)), |
|
"Sunset": ((255, 128, 0), (255, 0, 128)), |
|
"Green Field": ((64, 128, 64), (160, 255, 160)), |
|
"Slate": ((40, 40, 48), (96, 96, 112)), |
|
"Ocean": ((255, 140, 0), (255, 215, 0)), |
|
"Forest": ((34, 139, 34), (144, 238, 144)), |
|
"Sunset Pink": ((255, 182, 193), (255, 105, 180)), |
|
"Cool Blue": ((173, 216, 230), (0, 191, 255)), |
|
} |
|
|
|
AI_PROMPT_SUGGESTIONS = [ |
|
"Custom (write your own)", |
|
"modern minimalist office with soft lighting, clean desk, blurred background", |
|
"elegant conference room with large windows and city view", |
|
"contemporary workspace with plants and natural light", |
|
"luxury hotel lobby with marble floors and warm ambient lighting", |
|
"professional studio with clean white background and soft lighting", |
|
"modern corporate meeting room with glass walls and city skyline", |
|
"sophisticated home office with bookshelf and warm wood tones", |
|
"sleek coworking space with industrial design elements", |
|
"abstract geometric patterns in blue and gold, modern art style", |
|
"soft watercolor texture with pastel colors, dreamy atmosphere", |
|
] |
|
|
|
def _make_vertical_gradient(width: int, height: int, c1, c2) -> np.ndarray: |
|
width = max(1, int(width)) |
|
height = max(1, int(height)) |
|
top = np.array(c1, dtype=np.float32) |
|
bot = np.array(c2, dtype=np.float32) |
|
rows = np.linspace(top, bot, num=height, dtype=np.float32) |
|
grad = np.repeat(rows[:, None, :], repeats=width, axis=1) |
|
return np.clip(grad, 0, 255).astype(np.uint8) |
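
# Illustrative sketch (not wired into the pipeline): render a named preset to a
# JPEG the same way the gradient background paths do. The output file name is an
# assumption for demonstration only.
def _demo_render_gradient_preset(name: str = "Slate", path: str = "gradient_demo.jpg") -> str:
    c1, c2 = GRADIENT_PRESETS[name]
    grad = _make_vertical_gradient(640, 360, c1, c2)  # HxWx3 uint8, BGR
    cv2.imwrite(path, grad)  # cv2.imwrite expects BGR, matching the presets
    return path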
|
|
|
def run_ffmpeg(args: list, fail_ok: bool = False) -> bool:
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error"] + args
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except subprocess.CalledProcessError as e:
        if not fail_ok:
            stderr = (e.stderr or b"").decode(errors="ignore").strip()
            logger.error(f"ffmpeg failed: {stderr or e}")
        return False
    except Exception as e:
        if not fail_ok:
            logger.error(f"ffmpeg failed: {e}")
        return False
|
|
|
def write_video_h264(clip, path: str, fps: Optional[int] = None, crf: int = 18, preset: str = "medium"): |
|
fps = fps or max(1, int(round(getattr(clip, "fps", None) or 24))) |
|
clip.write_videofile( |
|
path, |
|
audio=False, |
|
fps=fps, |
|
codec="libx264", |
|
preset=preset, |
|
ffmpeg_params=["-crf", str(crf), "-pix_fmt", "yuv420p", "-profile:v", "high", "-movflags", "+faststart"], |
|
logger=None, |
|
verbose=False, |
|
) |
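
# Note: CRF 18 with yuv420p / high profile / +faststart targets visually
# lossless, widely playable MP4s; callers pass crf through from the UI slider.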
|
|
|
def download_file(url: str, dest: Path, name: str) -> bool: |
|
if dest.exists(): |
|
logger.info(f"{name} already exists") |
|
return True |
|
try: |
|
import requests |
|
logger.info(f"Downloading {name} ...") |
|
with requests.get(url, stream=True, timeout=300) as r: |
|
r.raise_for_status() |
|
with open(dest, "wb") as f: |
|
for chunk in r.iter_content(chunk_size=8192): |
|
if chunk: |
|
f.write(chunk) |
|
return True |
|
except Exception as e: |
|
logger.error(f"Failed to download {name}: {e}") |
|
        if dest.exists():
            try:
                dest.unlink()
            except Exception:
                pass
|
return False |
|
|
|
def ensure_repo(repo_name: str, git_url: str) -> Optional[Path]: |
|
repo_path = CHECKPOINTS / f"{repo_name}_repo" |
|
if not repo_path.exists(): |
|
try: |
|
subprocess.run(["git", "clone", "--depth", "1", git_url, str(repo_path)], |
|
check=True, timeout=300, capture_output=True) |
|
logger.info(f"{repo_name} cloned") |
|
except Exception as e: |
|
logger.error(f"Failed to clone {repo_name}: {e}") |
|
return None |
|
repo_str = str(repo_path) |
|
if repo_str not in sys.path: |
|
sys.path.insert(0, repo_str) |
|
return repo_path |
|
|
|
def _reset_hydra():
    try:
        from hydra.core.global_hydra import GlobalHydra
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()
    except Exception:
        pass
|
|
|
|
|
|
|
|
|
@dataclass |
|
class MemoryStats: |
|
cpu_percent: float |
|
cpu_memory_mb: float |
|
gpu_memory_mb: float = 0.0 |
|
gpu_memory_reserved_mb: float = 0.0 |
|
temp_files_count: int = 0 |
|
temp_files_size_mb: float = 0.0 |
|
|
|
class MemoryManager: |
|
def __init__(self): |
|
self.temp_files: List[str] = [] |
|
self.cleanup_lock = threading.Lock() |
|
self.torch_available = TORCH_AVAILABLE |
|
self.cuda_available = CUDA_AVAILABLE |
|
|
|
def get_memory_stats(self) -> MemoryStats: |
|
process = psutil.Process() |
|
cpu_percent = psutil.cpu_percent(interval=0.1) |
|
cpu_memory_mb = process.memory_info().rss / (1024 * 1024) |
|
gpu_memory_mb = 0.0 |
|
gpu_memory_reserved_mb = 0.0 |
|
if self.torch_available and self.cuda_available: |
|
try: |
|
import torch |
|
gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024) |
|
gpu_memory_reserved_mb = torch.cuda.memory_reserved() / (1024 * 1024) |
|
except Exception: |
|
pass |
|
|
|
temp_count, temp_size_mb = 0, 0.0 |
|
for tf in self.temp_files: |
|
if os.path.exists(tf): |
|
temp_count += 1 |
|
try: |
|
temp_size_mb += os.path.getsize(tf) / (1024 * 1024) |
|
except Exception: |
|
pass |
|
return MemoryStats(cpu_percent, cpu_memory_mb, gpu_memory_mb, gpu_memory_reserved_mb, temp_count, temp_size_mb) |
|
|
|
def register_temp_file(self, path: str): |
|
with self.cleanup_lock: |
|
if path not in self.temp_files: |
|
self.temp_files.append(path) |
|
|
|
    def cleanup_temp_files(self):
        with self.cleanup_lock:
            cleaned = 0
            for tf in self.temp_files[:]:
                try:
                    if os.path.isdir(tf):
                        shutil.rmtree(tf, ignore_errors=True)
                        cleaned += 1
                    elif os.path.exists(tf):
                        os.unlink(tf)
                        cleaned += 1
                except Exception as e:
                    logger.warning(f"Failed to cleanup {tf}: {e}")
                finally:
                    try:
                        self.temp_files.remove(tf)
                    except ValueError:
                        pass
            if cleaned:
                logger.info(f"Cleaned {cleaned} temp paths")
|
|
|
def aggressive_cleanup(self): |
|
logger.info("Aggressive cleanup...") |
|
gc.collect() |
|
if self.torch_available and self.cuda_available: |
|
try: |
|
import torch |
|
torch.cuda.empty_cache() |
|
torch.cuda.synchronize() |
|
except Exception: |
|
pass |
|
self.cleanup_temp_files() |
|
gc.collect() |
|
|
|
@contextmanager |
|
def mem_context(self, name="op"): |
|
stats = self.get_memory_stats() |
|
logger.info(f"Start {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB") |
|
try: |
|
yield self |
|
finally: |
|
self.aggressive_cleanup() |
|
stats = self.get_memory_stats() |
|
logger.info(f"End {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB") |
|
|
|
memory_manager = MemoryManager() |
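
# Usage sketch for MemoryManager.mem_context (hedged; not called anywhere):
# heavy stages below wrap their work like this so temp files and the CUDA
# cache are reclaimed even if the body raises.
def _demo_mem_context_usage() -> float:
    with memory_manager.mem_context("demo op"):
        buf = np.zeros((512, 512, 3), dtype=np.float32)  # stand-in workload
        return float(buf.mean())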
|
|
|
|
|
|
|
|
|
class SystemState: |
|
def __init__(self): |
|
self.torch_available = TORCH_AVAILABLE |
|
self.cuda_available = CUDA_AVAILABLE |
|
self.device = DEVICE |
|
self.sam2_ready = False |
|
self.matanyone_ready = False |
|
self.sam2_error = None |
|
self.matanyone_error = None |
|
|
|
def status_text(self) -> str: |
|
stats = memory_manager.get_memory_stats() |
|
return ( |
|
"=== SYSTEM STATUS ===\n" |
|
f"PyTorch: {'✅' if self.torch_available else '❌'}\n" |
|
f"CUDA: {'✅' if self.cuda_available else '❌'}\n" |
|
f"Device: {self.device}\n" |
|
f"SAM2: {'✅' if self.sam2_ready else ('❌' if self.sam2_error else '⏳')}\n" |
|
f"MatAnyone: {'✅' if self.matanyone_ready else ('❌' if self.matanyone_error else '⏳')}\n\n" |
|
"=== MEMORY ===\n" |
|
f"CPU: {stats.cpu_percent:.1f}% ({stats.cpu_memory_mb:.1f} MB)\n" |
|
f"GPU: {stats.gpu_memory_mb:.1f} MB (Reserved {stats.gpu_memory_reserved_mb:.1f} MB)\n" |
|
f"Temp: {stats.temp_files_count} files ({stats.temp_files_size_mb:.1f} MB)\n" |
|
) |
|
|
|
state = SystemState() |
|
|
|
|
|
|
|
|
|
class SAM2Handler: |
|
def __init__(self): |
|
self.predictor = None |
|
self.initialized = False |
|
|
|
def initialize(self) -> bool: |
|
if not (TORCH_AVAILABLE and CUDA_AVAILABLE): |
|
state.sam2_error = "SAM2 requires CUDA" |
|
return False |
|
|
|
with memory_manager.mem_context("SAM2 init"): |
|
try: |
|
_reset_hydra() |
|
repo_path = ensure_repo("sam2", "https://github.com/facebookresearch/segment-anything-2.git") |
|
if not repo_path: |
|
state.sam2_error = "Clone failed" |
|
return False |
|
|
|
ckpt = CHECKPOINTS / "sam2.1_hiera_large.pt" |
|
url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt" |
|
if not download_file(url, ckpt, "SAM2 Large"): |
|
state.sam2_error = "SAM2 ckpt download failed" |
|
return False |
|
|
|
from hydra.core.global_hydra import GlobalHydra |
|
from hydra import initialize_config_dir |
|
from sam2.build_sam import build_sam2 |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
config_dir = (repo_path / "sam2" / "configs").as_posix() |
|
                if GlobalHydra.instance().is_initialized():
                    GlobalHydra.instance().clear()
|
initialize_config_dir(config_dir=config_dir, version_base=None) |
|
|
|
model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(ckpt), device="cuda") |
|
self.predictor = SAM2ImagePredictor(model) |
|
|
|
|
|
test = np.zeros((64, 64, 3), dtype=np.uint8) |
|
self.predictor.set_image(test) |
|
masks, scores, _ = self.predictor.predict( |
|
point_coords=np.array([[32, 32]]), |
|
point_labels=np.ones(1, dtype=np.int64), |
|
multimask_output=True, |
|
) |
|
ok = masks is not None and len(masks) > 0 |
|
self.initialized = ok |
|
state.sam2_ready = ok |
|
if not ok: |
|
state.sam2_error = "SAM2 verify failed" |
|
return ok |
|
|
|
except Exception as e: |
|
state.sam2_error = f"SAM2 init error: {e}" |
|
return False |
|
|
|
def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]: |
|
if not self.initialized: |
|
return None |
|
with memory_manager.mem_context("SAM2 mask"): |
|
try: |
|
self.predictor.set_image(image_rgb) |
|
h, w = image_rgb.shape[:2] |
|
strategies = [ |
|
np.array([[w // 2, h // 2]]), |
|
np.array([[w // 2, h // 3]]), |
|
np.array([[w // 2, h // 3], [w // 2, (2 * h) // 3]]), |
|
] |
|
best, best_score = None, -1.0 |
|
for pc in strategies: |
|
masks, scores, _ = self.predictor.predict( |
|
point_coords=pc, |
|
point_labels=np.ones(len(pc), dtype=np.int64), |
|
multimask_output=True, |
|
) |
|
if masks is not None and len(masks) > 0: |
|
i = int(np.argmax(scores)) |
|
sc = float(scores[i]) |
|
if sc > best_score: |
|
best_score, best = sc, masks[i] |
|
|
|
if best is None: |
|
return None |
|
|
|
mask_u8 = (best * 255).astype(np.uint8) |
|
k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) |
|
mask_clean = cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, k) |
|
mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_OPEN, k) |
|
mask_clean = cv2.GaussianBlur(mask_clean, (3, 3), 1.0) |
|
return mask_clean |
|
except Exception as e: |
|
logger.error(f"SAM2 mask error: {e}") |
|
return None |
|
|
|
|
|
|
|
|
|
class MatAnyoneHandler:
    """
    MatAnyone handler built on the local `matanyone_fixed` package; enforces
    strict tensor shapes (image: 3xHxW, first-frame prob mask: HxW or 1xHxW).
    """
|
def __init__(self): |
|
self.core = None |
|
self.initialized = False |
|
|
|
|
|
def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor": |
|
"""img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch).""" |
|
assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}" |
|
t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() |
|
return t.to(DEVICE, non_blocking=CUDA_AVAILABLE) |
|
|
|
def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor": |
|
"""mask_u8: HxW -> torch float (H,W) in [0,1] on DEVICE (no batch, no channel).""" |
|
if mask_u8.shape[0] != h or mask_u8.shape[1] != w: |
|
mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST) |
|
prob = (mask_u8.astype(np.float32) / 255.0) |
|
t = torch.from_numpy(prob).contiguous().float() |
|
return t.to(DEVICE, non_blocking=CUDA_AVAILABLE) |
|
|
|
def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor": |
|
"""Optional: 1xHxW (channel-first, still unbatched).""" |
|
if mask_u8.shape[0] != h or mask_u8.shape[1] != w: |
|
mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST) |
|
prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] |
|
t = torch.from_numpy(prob).contiguous().float() |
|
return t.to(DEVICE, non_blocking=CUDA_AVAILABLE) |
|
|
|
def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray: |
|
""" |
|
Accepts torch / numpy / tuple(list) outputs. |
|
Returns uint8 HxW (0..255). Squeezes common shapes down to HxW. |
|
""" |
|
if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1: |
|
alpha_like = alpha_like[1] |
|
|
|
if isinstance(alpha_like, torch.Tensor): |
|
t = alpha_like.detach() |
|
if t.is_cuda: |
|
t = t.cpu() |
|
a = t.float().clamp(0, 1).numpy() |
|
else: |
|
a = np.asarray(alpha_like, dtype=np.float32) |
|
a = np.clip(a, 0, 1) |
|
|
|
a = np.squeeze(a) |
|
if a.ndim == 3 and a.shape[0] >= 1: |
|
a = a[0] |
|
if a.ndim != 2: |
|
raise ValueError(f"Alpha must be HxW; got {a.shape}") |
|
|
|
return np.clip(a * 255.0, 0, 255).astype(np.uint8) |
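
    # Shape note: torch or numpy alphas of shape (H, W), (1, H, W) or
    # (1, 1, H, W) in [0, 1] all normalize to the same HxW uint8 result above;
    # tuple/list returns are assumed to carry the alpha at index 1 (hence the
    # alpha_like[1] pick), matching how step() output is consumed below.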
|
|
|
    def initialize(self) -> bool:
        """
        Initialize MatAnyone from the local `matanyone_fixed` package
        (InferenceCore + default model loader), downloading the checkpoint if missing.
        """
|
if not TORCH_AVAILABLE: |
|
state.matanyone_error = "PyTorch required" |
|
return False |
|
|
|
with memory_manager.mem_context("MatAnyone init"): |
|
try: |
|
|
|
local_matanyone = BASE_DIR / "matanyone_fixed" |
|
|
|
if not local_matanyone.exists(): |
|
state.matanyone_error = "matanyone_fixed directory not found" |
|
return False |
|
|
|
|
|
matanyone_str = str(local_matanyone) |
|
if matanyone_str not in sys.path: |
|
sys.path.insert(0, matanyone_str) |
|
|
|
|
|
try: |
|
from inference.inference_core import InferenceCore |
|
from utils.get_default_model import get_matanyone_model |
|
except Exception as e: |
|
state.matanyone_error = f"Import error: {e}" |
|
return False |
|
|
|
|
|
ckpt = CHECKPOINTS / "matanyone.pth" |
|
                if not ckpt.exists():
                    url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
                    if not download_file(url, ckpt, "MatAnyone"):
                        logger.warning("MatAnyone checkpoint download failed; model load below may fail")
|
|
|
|
|
net = get_matanyone_model(str(ckpt), device=DEVICE) |
|
|
|
if net is None: |
|
state.matanyone_error = "Model creation failed" |
|
return False |
|
|
|
|
|
self.core = InferenceCore(net) |
|
self.initialized = True |
|
state.matanyone_ready = True |
|
|
|
logger.info("Fixed MatAnyone initialized successfully") |
|
return True |
|
|
|
except Exception as e: |
|
state.matanyone_error = f"MatAnyone init error: {e}" |
|
logger.error(f"MatAnyone initialization failed: {e}") |
|
return False |
|
|
|
    def _try_step_variants_seed(self,
                                img_chw_t: "torch.Tensor",
                                prob_hw_t: "torch.Tensor",
                                prob_1hw_t: "torch.Tensor"):
        """
        First-frame step: try core.step(image, HxW prob), then (image, 1xHxW prob),
        then step(image) alone, covering the step() signatures this handler may see.
        """
        try:
            return self.core.step(img_chw_t, prob_hw_t)
        except Exception:
            try:
                return self.core.step(img_chw_t, prob_1hw_t)
            except Exception:
                return self.core.step(img_chw_t)

    def _try_step_variants_noseed(self, img_chw_t: "torch.Tensor"):
        """Subsequent frames: propagate from memory with step(image) only."""
        return self.core.step(img_chw_t)
|
|
|
|
|
    def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
        """
        Produce a single-channel alpha .mp4 matching the input fps and size.

        First frame: pass a soft seed probability mask (HxW) alongside the image.
        Remaining frames: call step(image) only.
        """
|
if not self.initialized or self.core is None: |
|
raise RuntimeError("MatAnyone not initialized") |
|
|
|
out_dir = Path(output_path) |
|
out_dir.mkdir(parents=True, exist_ok=True) |
|
alpha_path = out_dir / "alpha.mp4" |
|
|
|
cap = cv2.VideoCapture(input_path) |
|
if not cap.isOpened(): |
|
raise RuntimeError("Could not open input video") |
|
|
|
fps = cap.get(cv2.CAP_PROP_FPS) or 24.0 |
|
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
|
|
|
seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) |
|
if seed_mask is None: |
|
cap.release() |
|
raise RuntimeError("Seed mask read failed") |
|
|
|
prob_hw_t = self._prob_hw_from_mask_u8(seed_mask, w, h) |
|
prob_1hw_t = self._prob_1hw_from_mask_u8(seed_mask, w, h) |
|
|
|
|
|
tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}" |
|
tmp_dir.mkdir(parents=True, exist_ok=True) |
|
memory_manager.register_temp_file(str(tmp_dir)) |
|
|
|
frame_idx = 0 |
|
|
|
|
|
ok, frame_bgr = cap.read() |
|
if not ok or frame_bgr is None: |
|
cap.release() |
|
raise RuntimeError("Empty first frame") |
|
frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
|
|
|
img_chw_t = self._to_chw_float(frame_rgb01) |
|
|
|
with torch.no_grad(): |
|
out_prob = self._try_step_variants_seed( |
|
img_chw_t, prob_hw_t, prob_1hw_t |
|
) |
|
|
|
alpha_u8 = self._alpha_to_u8_hw(out_prob) |
|
cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8) |
|
frame_idx += 1 |
|
|
|
|
|
while True: |
|
ok, frame_bgr = cap.read() |
|
if not ok or frame_bgr is None: |
|
break |
|
|
|
frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
|
img_chw_t = self._to_chw_float(frame_rgb01) |
|
|
|
with torch.no_grad(): |
|
out_prob = self._try_step_variants_noseed(img_chw_t) |
|
|
|
alpha_u8 = self._alpha_to_u8_hw(out_prob) |
|
cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8) |
|
frame_idx += 1 |
|
|
|
cap.release() |
|
|
|
|
|
list_file = tmp_dir / "list.txt" |
|
with open(list_file, "w") as f: |
|
for i in range(frame_idx): |
|
f.write(f"file '{(tmp_dir / f'{i:06d}.png').as_posix()}'\n") |
|
|
|
cmd = [ |
|
"ffmpeg", "-y", "-hide_banner", "-loglevel", "error", |
|
"-f", "concat", "-safe", "0", |
|
"-r", f"{fps:.6f}", |
|
"-i", str(list_file), |
|
"-vf", f"format=gray,scale={w}:{h}:flags=area", |
|
"-pix_fmt", "yuv420p", |
|
"-c:v", "libx264", "-preset", "medium", "-crf", "18", |
|
str(alpha_path) |
|
] |
|
subprocess.run(cmd, check=True) |
|
return str(alpha_path) |
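
# End-to-end matting sketch (hedged; file names are placeholders):
#     handler = MatAnyoneHandler()
#     if handler.initialize():
#         alpha_mp4 = handler.process_video("input.mp4", "seed_mask.png", "out_dir")
# process_video writes out_dir/alpha.mp4 at the input's fps and resolution.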
|
|
|
|
|
|
|
|
|
def _maybe_enable_xformers(pipe): |
|
try: |
|
pipe.enable_xformers_memory_efficient_attention() |
|
except Exception: |
|
pass |
|
|
|
def _setup_memory_efficient_pipeline(pipe, require_gpu: bool): |
|
_maybe_enable_xformers(pipe) |
|
if not require_gpu: |
|
try: |
|
if hasattr(pipe, "enable_attention_slicing"): |
|
pipe.enable_attention_slicing("auto") |
|
if hasattr(pipe, "enable_model_cpu_offload"): |
|
pipe.enable_model_cpu_offload() |
|
if hasattr(pipe, "enable_sequential_cpu_offload"): |
|
pipe.enable_sequential_cpu_offload() |
|
except Exception: |
|
pass |
|
|
|
def generate_sdxl_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0, |
|
seed:Optional[int]=None, require_gpu:bool=False) -> str: |
|
if not TORCH_AVAILABLE: |
|
raise RuntimeError("PyTorch required for SDXL") |
|
with memory_manager.mem_context("SDXL background"): |
|
try: |
|
from diffusers import StableDiffusionXLPipeline |
|
except ImportError as e: |
|
raise RuntimeError("Install diffusers/transformers/accelerate") from e |
|
|
|
if require_gpu and not CUDA_AVAILABLE: |
|
raise RuntimeError("Force GPU enabled but CUDA not available") |
|
|
|
device = "cuda" if CUDA_AVAILABLE else "cpu" |
|
torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32 |
|
|
|
generator = torch.Generator(device=device) |
|
if seed is None: |
|
seed = random.randint(0, 2**31 - 1) |
|
generator.manual_seed(int(seed)) |
|
|
|
pipe = StableDiffusionXLPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-xl-base-1.0", |
|
torch_dtype=torch_dtype, |
|
add_watermarker=False, |
|
).to(device) |
|
|
|
_setup_memory_efficient_pipeline(pipe, require_gpu) |
|
|
|
enhanced = f"{prompt}, professional studio lighting, high detail, clean composition" |
|
img = pipe( |
|
prompt=enhanced, |
|
height=int(height), |
|
width=int(width), |
|
num_inference_steps=int(steps), |
|
guidance_scale=float(guidance), |
|
generator=generator |
|
).images[0] |
|
|
|
out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed or 0:08d}.jpg" |
|
img.save(out, quality=95, optimize=True) |
|
memory_manager.register_temp_file(str(out)) |
|
del pipe, img |
|
return str(out) |
|
|
|
def generate_playground_v25_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0, |
|
seed:Optional[int]=None, require_gpu:bool=False) -> str: |
|
if not TORCH_AVAILABLE: |
|
raise RuntimeError("PyTorch required for Playground v2.5") |
|
with memory_manager.mem_context("Playground v2.5 background"): |
|
try: |
|
from diffusers import DiffusionPipeline |
|
except ImportError as e: |
|
raise RuntimeError("Install diffusers/transformers/accelerate") from e |
|
|
|
if require_gpu and not CUDA_AVAILABLE: |
|
raise RuntimeError("Force GPU enabled but CUDA not available") |
|
|
|
device = "cuda" if CUDA_AVAILABLE else "cpu" |
|
torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32 |
|
|
|
generator = torch.Generator(device=device) |
|
if seed is None: |
|
seed = random.randint(0, 2**31 - 1) |
|
generator.manual_seed(int(seed)) |
|
|
|
repo_id = "playgroundai/playground-v2.5-1024px-aesthetic" |
|
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device) |
|
_setup_memory_efficient_pipeline(pipe, require_gpu) |
|
|
|
enhanced = f"{prompt}, professional quality, soft light, minimal distractions" |
|
img = pipe( |
|
prompt=enhanced, |
|
height=int(height), |
|
width=int(width), |
|
num_inference_steps=int(steps), |
|
guidance_scale=float(guidance), |
|
generator=generator |
|
).images[0] |
|
|
|
out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed or 0:08d}.jpg" |
|
img.save(out, quality=95, optimize=True) |
|
memory_manager.register_temp_file(str(out)) |
|
del pipe, img |
|
return str(out) |
|
|
|
def generate_sd15_background(width:int, height:int, prompt:str, steps:int=25, guidance:float=7.5, |
|
seed:Optional[int]=None, require_gpu:bool=False) -> str: |
|
if not TORCH_AVAILABLE: |
|
raise RuntimeError("PyTorch required for SD 1.5") |
|
with memory_manager.mem_context("SD1.5 background"): |
|
try: |
|
from diffusers import StableDiffusionPipeline |
|
except ImportError as e: |
|
raise RuntimeError("Install diffusers/transformers/accelerate") from e |
|
|
|
if require_gpu and not CUDA_AVAILABLE: |
|
raise RuntimeError("Force GPU enabled but CUDA not available") |
|
|
|
device = "cuda" if CUDA_AVAILABLE else "cpu" |
|
torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32 |
|
|
|
generator = torch.Generator(device=device) |
|
if seed is None: |
|
seed = random.randint(0, 2**31 - 1) |
|
generator.manual_seed(int(seed)) |
|
|
|
pipe = StableDiffusionPipeline.from_pretrained( |
|
"runwayml/stable-diffusion-v1-5", |
|
torch_dtype=torch_dtype, |
|
safety_checker=None, |
|
requires_safety_checker=False |
|
).to(device) |
|
|
|
_setup_memory_efficient_pipeline(pipe, require_gpu) |
|
|
|
enhanced = f"{prompt}, professional background, clean composition" |
|
img = pipe( |
|
prompt=enhanced, |
|
height=int(height), |
|
width=int(width), |
|
num_inference_steps=int(steps), |
|
guidance_scale=float(guidance), |
|
generator=generator |
|
).images[0] |
|
|
|
out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed or 0:08d}.jpg" |
|
img.save(out, quality=95, optimize=True) |
|
memory_manager.register_temp_file(str(out)) |
|
del pipe, img |
|
return str(out) |
|
|
|
def generate_openai_background(width:int, height:int, prompt:str, api_key:str, model:str="gpt-image-1") -> str: |
|
if not api_key or not isinstance(api_key, str) or len(api_key) < 10: |
|
raise RuntimeError("Missing or invalid OpenAI API key") |
|
with memory_manager.mem_context("OpenAI background"): |
|
target = "1024x1024" |
|
url = "https://api.openai.com/v1/images/generations" |
|
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} |
|
body = {"model": model, "prompt": f"{prompt}, professional background, studio lighting, minimal distractions, high detail", |
|
"size": target, "n": 1, "quality": "high"} |
|
import requests |
|
r = requests.post(url, headers=headers, data=json.dumps(body), timeout=120) |
|
if r.status_code != 200: |
|
raise RuntimeError(f"OpenAI API error: {r.status_code} {r.text}") |
|
data = r.json() |
|
b64 = data["data"][0]["b64_json"] |
|
raw = base64.b64decode(b64) |
|
tmp_png = TEMP_DIR / f"openai_raw_{int(time.time())}_{random.randint(1000,9999)}.png" |
|
with open(tmp_png, "wb") as f: |
|
f.write(raw) |
|
img = Image.open(tmp_png).convert("RGB").resize((int(width), int(height)), Image.LANCZOS) |
|
out = TEMP_DIR / f"openai_bg_{int(time.time())}_{random.randint(1000,9999)}.jpg" |
|
img.save(out, quality=95, optimize=True) |
|
        try:
            os.unlink(tmp_png)
        except Exception:
            pass
|
memory_manager.register_temp_file(str(out)) |
|
return str(out) |
|
|
|
def generate_ai_background_router(width:int, height:int, prompt:str, model:str="SDXL", |
|
steps:int=30, guidance:float=7.0, seed:Optional[int]=None, |
|
openai_key:Optional[str]=None, require_gpu:bool=False) -> str: |
|
try: |
|
if model == "OpenAI (gpt-image-1)": |
|
if not openai_key: |
|
raise RuntimeError("OpenAI API key not provided") |
|
return generate_openai_background(width, height, prompt, openai_key, model="gpt-image-1") |
|
elif model == "Playground v2.5": |
|
return generate_playground_v25_background(width, height, prompt, steps, guidance, seed, require_gpu) |
|
elif model == "SDXL": |
|
return generate_sdxl_background(width, height, prompt, steps, guidance, seed, require_gpu) |
|
else: |
|
return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu) |
|
except Exception as e: |
|
logger.warning(f"{model} generation failed: {e}; falling back to SD1.5/gradient") |
|
try: |
|
return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu=False) |
|
except Exception: |
|
grad = _make_vertical_gradient(width, height, (235, 240, 245), (210, 220, 230)) |
|
out = TEMP_DIR / f"bg_fallback_{int(time.time())}.jpg" |
|
cv2.imwrite(str(out), grad) |
|
memory_manager.register_temp_file(str(out)) |
|
return str(out) |
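
# Router fallback order, as implemented above: requested model → SD 1.5 → flat
# gradient JPEG, so a background path is always returned. Example call (hedged):
#     bg = generate_ai_background_router(1280, 720, "modern office", model="SDXL")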
|
|
|
|
|
|
|
|
|
class ChunkedVideoProcessor: |
|
def __init__(self, chunk_size_frames: int = 60): |
|
self.chunk_size = int(chunk_size_frames) |
|
|
|
def _extract_chunk(self, video_path: str, start_frame: int, end_frame: int, fps: float) -> str: |
|
chunk_path = str(TEMP_DIR / f"chunk_{start_frame}_{end_frame}_{random.randint(1000,9999)}.mp4") |
|
start_time = start_frame / fps |
|
duration = max(0.001, (end_frame - start_frame) / fps) |
|
cmd = [ |
|
"ffmpeg", "-y", "-hide_banner", "-loglevel", "error", |
|
"-ss", f"{start_time:.6f}", "-i", video_path, |
|
"-t", f"{duration:.6f}", |
|
"-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", |
|
"-c:v", "libx264", "-preset", "veryfast", "-crf", "20", |
|
"-an", chunk_path |
|
] |
|
subprocess.run(cmd, check=True) |
|
return chunk_path |
|
|
|
def _merge_chunks(self, chunk_paths: List[str], fps: float, width: int, height: int) -> str: |
|
if not chunk_paths: |
|
raise ValueError("No chunks to merge") |
|
if len(chunk_paths) == 1: |
|
return chunk_paths[0] |
|
concat_file = TEMP_DIR / f"concat_{random.randint(1000,9999)}.txt" |
|
with open(concat_file, "w") as f: |
|
for c in chunk_paths: |
|
f.write(f"file '{c}'\n") |
|
out = TEMP_DIR / f"merged_{random.randint(1000,9999)}.mp4" |
|
cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error", |
|
"-f", "concat", "-safe", "0", "-i", str(concat_file), |
|
"-c", "copy", str(out)] |
|
subprocess.run(cmd, check=True) |
|
return str(out) |
|
|
|
def process_video_chunks(self, video_path: str, processor_func, **kwargs) -> str: |
|
cap = cv2.VideoCapture(video_path) |
|
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
fps = cap.get(cv2.CAP_PROP_FPS) or 24.0 |
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
cap.release() |
|
|
|
processed: List[str] = [] |
|
for start in range(0, total, self.chunk_size): |
|
end = min(start + self.chunk_size, total) |
|
with memory_manager.mem_context(f"chunk {start}-{end}"): |
|
ch = self._extract_chunk(video_path, start, end, fps) |
|
memory_manager.register_temp_file(ch) |
|
out = processor_func(ch, **kwargs) |
|
memory_manager.register_temp_file(out) |
|
processed.append(out) |
|
return self._merge_chunks(processed, fps, width, height) |
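
# Chunked-processing sketch (hedged; the per-chunk callable is illustrative):
#     chunker = ChunkedVideoProcessor(chunk_size_frames=60)
#     merged = chunker.process_video_chunks(
#         "input.mp4",
#         lambda chunk_path, **_kw: matte_one_chunk(chunk_path),  # hypothetical helper
#     )
# Each chunk is re-encoded with even dimensions, processed, then concat-merged.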
|
|
|
|
|
|
|
|
|
def process_video_main( |
|
video_path: str, |
|
background_path: Optional[str] = None, |
|
trim_duration: Optional[float] = None, |
|
crf: int = 18, |
|
preserve_audio_flag: bool = True, |
|
placement: Optional[dict] = None, |
|
use_chunked_processing: bool = False, |
|
progress: gr.Progress = gr.Progress(track_tqdm=True), |
|
) -> Tuple[Optional[str], str]: |
|
|
|
messages: List[str] = [] |
|
with memory_manager.mem_context("Pipeline"): |
|
try: |
|
progress(0, desc="Initializing models") |
|
sam2 = SAM2Handler() |
|
matanyone = MatAnyoneHandler() |
|
|
|
if not sam2.initialize(): |
|
return None, f"SAM2 init failed: {state.sam2_error}" |
|
if not matanyone.initialize(): |
|
return None, f"MatAnyone init failed: {state.matanyone_error}" |
|
messages.append("✅ SAM2 & MatAnyone initialized") |
|
|
|
progress(0.1, desc="Preparing video") |
|
input_video = video_path |
|
|
|
|
|
if trim_duration and float(trim_duration) > 0: |
|
trimmed = TEMP_DIR / f"trimmed_{int(time.time())}_{random.randint(1000,9999)}.mp4" |
|
memory_manager.register_temp_file(str(trimmed)) |
|
with VideoFileClip(video_path) as clip: |
|
d = min(float(trim_duration), float(clip.duration or trim_duration)) |
|
sub = clip.subclip(0, d) |
|
write_video_h264(sub, str(trimmed), crf=int(crf)) |
|
sub.close() |
|
input_video = str(trimmed) |
|
messages.append(f"✂️ Trimmed to {d:.1f}s") |
|
else: |
|
with VideoFileClip(video_path) as clip: |
|
messages.append(f"🎞️ Full video: {clip.duration:.1f}s") |
|
|
|
progress(0.2, desc="Creating SAM2 mask") |
|
cap = cv2.VideoCapture(input_video) |
|
ret, first_frame = cap.read() |
|
cap.release() |
|
if not ret or first_frame is None: |
|
return None, "Could not read video" |
|
h, w = first_frame.shape[:2] |
|
rgb0 = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) |
|
mask = sam2.create_mask(rgb0) |
|
if mask is None: |
|
return None, "SAM2 mask failed" |
|
|
|
mask_path = TEMP_DIR / f"mask_{int(time.time())}_{random.randint(1000,9999)}.png" |
|
memory_manager.register_temp_file(str(mask_path)) |
|
cv2.imwrite(str(mask_path), mask) |
|
messages.append("✅ Person mask created") |
|
|
|
progress(0.35, desc="Matting video") |
|
if use_chunked_processing: |
|
chunker = ChunkedVideoProcessor(chunk_size_frames=60) |
|
alpha_video = chunker.process_video_chunks( |
|
input_video, |
|
lambda chunk_path, **_k: matanyone.process_video( |
|
input_path=chunk_path, |
|
mask_path=str(mask_path), |
|
output_path=str(TEMP_DIR / f"matanyone_chunk_{int(time.time())}_{random.randint(1000,9999)}") |
|
) |
|
) |
|
memory_manager.register_temp_file(alpha_video) |
|
else: |
|
out_dir = TEMP_DIR / f"matanyone_out_{int(time.time())}_{random.randint(1000,9999)}" |
|
out_dir.mkdir(parents=True, exist_ok=True) |
|
memory_manager.register_temp_file(str(out_dir)) |
|
alpha_video = matanyone.process_video( |
|
input_path=input_video, |
|
mask_path=str(mask_path), |
|
output_path=str(out_dir) |
|
) |
|
|
|
if not alpha_video or not os.path.exists(alpha_video): |
|
return None, "MatAnyone did not produce alpha video" |
|
messages.append("✅ Alpha video generated") |
|
|
|
progress(0.55, desc="Preparing background") |
|
original_clip = VideoFileClip(input_video) |
|
alpha_clip = VideoFileClip(alpha_video) |
|
|
|
            if background_path and os.path.exists(background_path):
                messages.append("🖼️ Using background file")
                bg_bgr = cv2.imread(background_path)
                if bg_bgr is None:
                    original_clip.close(); alpha_clip.close()
                    return None, "Background image could not be read"
                bg_bgr = cv2.resize(bg_bgr, (w, h))
                bg_rgb = cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
else: |
|
messages.append("🖼️ Using gradient background") |
|
grad = _make_vertical_gradient(w, h, (200, 205, 215), (160, 170, 190)) |
|
bg_rgb = cv2.cvtColor(grad, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
|
|
|
|
|
placement = placement or {} |
|
px = max(0.0, min(1.0, float(placement.get("x", 0.5)))) |
|
py = max(0.0, min(1.0, float(placement.get("y", 0.75)))) |
|
ps = max(0.3, min(2.0, float(placement.get("scale", 1.0)))) |
|
feather_px = max(0, min(50, int(placement.get("feather", 3)))) |
|
|
|
|
|
logger.info(f"POSITIONING DEBUG: px={px:.3f}, py={py:.3f}, ps={ps:.3f}, feather={feather_px}") |
|
logger.info(f"VIDEO DIMENSIONS: {w}x{h}") |
|
logger.info(f"TARGET CENTER: ({int(px * w)}, {int(py * h)})") |
|
|
|
frame_count = 0 |
|
def composite_frame(get_frame, t): |
|
nonlocal frame_count |
|
frame_count += 1 |
|
|
|
|
|
frame = get_frame(t).astype(np.float32) / 255.0 |
|
hh, ww = frame.shape[:2] |
|
|
|
|
|
alpha_duration = getattr(alpha_clip, 'duration', None) |
|
if alpha_duration and alpha_duration > 0: |
|
|
|
alpha_t = min(t, alpha_duration - 0.01) |
|
alpha_t = max(0.0, alpha_t) |
|
else: |
|
alpha_t = 0.0 |
|
|
|
try: |
|
a = alpha_clip.get_frame(alpha_t) |
|
|
|
if a.ndim == 3: |
|
a = a[:, :, 0] |
|
a = a.astype(np.float32) / 255.0 |
|
|
|
|
|
if a.shape != (hh, ww): |
|
logger.warning(f"Alpha size mismatch: {a.shape} vs {(hh, ww)}, resizing...") |
|
a = cv2.resize(a, (ww, hh), interpolation=cv2.INTER_LINEAR) |
|
|
|
except Exception as e: |
|
logger.error(f"Alpha frame error at t={t:.3f}: {e}") |
|
return (bg_rgb * 255).astype(np.uint8) |
|
|
|
|
|
sw = max(1, round(ww * ps)) |
|
sh = max(1, round(hh * ps)) |
|
|
|
|
|
try: |
|
fg_scaled = cv2.resize(frame, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR) |
|
a_scaled = cv2.resize(a, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR) |
|
except Exception as e: |
|
logger.error(f"Scaling error: {e}") |
|
return (bg_rgb * 255).astype(np.uint8) |
|
|
|
|
|
fg_canvas = np.zeros_like(frame, dtype=np.float32) |
|
a_canvas = np.zeros((hh, ww), dtype=np.float32) |
|
|
|
|
|
cx = round(px * ww) |
|
cy = round(py * hh) |
|
|
|
|
|
x0 = cx - sw // 2 |
|
y0 = cy - sh // 2 |
|
|
|
|
|
if frame_count <= 3: |
|
logger.info(f"FRAME {frame_count}: scaled_size=({sw}, {sh}), center=({cx}, {cy}), top_left=({x0}, {y0})") |
|
|
|
|
|
xs0 = max(0, x0) |
|
ys0 = max(0, y0) |
|
xs1 = min(ww, x0 + sw) |
|
ys1 = min(hh, y0 + sh) |
|
|
|
|
|
if xs1 <= xs0 or ys1 <= ys0: |
|
if frame_count <= 3: |
|
logger.warning(f"Subject outside bounds: dest=({xs0},{ys0})-({xs1},{ys1})") |
|
return (bg_rgb * 255).astype(np.uint8) |
|
|
|
|
|
src_x0 = xs0 - x0 |
|
src_y0 = ys0 - y0 |
|
src_x1 = src_x0 + (xs1 - xs0) |
|
src_y1 = src_y0 + (ys1 - ys0) |
|
|
|
|
|
if (src_x1 > sw or src_y1 > sh or src_x0 < 0 or src_y0 < 0 or |
|
src_x1 <= src_x0 or src_y1 <= src_y0): |
|
if frame_count <= 3: |
|
logger.error(f"Invalid source region: ({src_x0},{src_y0})-({src_x1},{src_y1}) for {sw}x{sh} scaled") |
|
return (bg_rgb * 255).astype(np.uint8) |
|
|
|
|
|
try: |
|
fg_canvas[ys0:ys1, xs0:xs1, :] = fg_scaled[src_y0:src_y1, src_x0:src_x1, :] |
|
a_canvas[ys0:ys1, xs0:xs1] = a_scaled[src_y0:src_y1, src_x0:src_x1] |
|
except Exception as e: |
|
logger.error(f"Canvas placement failed: {e}") |
|
logger.error(f"Dest: [{ys0}:{ys1}, {xs0}:{xs1}], Src: [{src_y0}:{src_y1}, {src_x0}:{src_x1}]") |
|
return (bg_rgb * 255).astype(np.uint8) |
|
|
|
|
|
                if feather_px > 0:
                    # 2*feather_px + 1 is always odd, as cv2.GaussianBlur requires
                    kernel_size = max(3, feather_px * 2 + 1)
                    try:
                        a_canvas = cv2.GaussianBlur(a_canvas, (kernel_size, kernel_size), feather_px / 3.0)
                    except Exception as e:
                        logger.warning(f"Feathering failed: {e}")
|
|
|
|
|
a3 = np.expand_dims(a_canvas, axis=2) |
|
comp = a3 * fg_canvas + (1.0 - a3) * bg_rgb |
|
result = np.clip(comp * 255, 0, 255).astype(np.uint8) |
|
|
|
return result |
|
|
|
progress(0.7, desc="Compositing") |
|
final_clip = original_clip.fl(composite_frame) |
|
|
|
output_path = OUT_DIR / f"processed_{int(time.time())}_{random.randint(1000,9999)}.mp4" |
|
temp_video_path = TEMP_DIR / f"temp_video_{int(time.time())}_{random.randint(1000,9999)}.mp4" |
|
memory_manager.register_temp_file(str(temp_video_path)) |
|
|
|
write_video_h264(final_clip, str(temp_video_path), crf=int(crf)) |
|
original_clip.close(); alpha_clip.close(); final_clip.close() |
|
|
|
progress(0.85, desc="Merging audio") |
|
if preserve_audio_flag: |
|
success = run_ffmpeg([ |
|
"-i", str(temp_video_path), |
|
"-i", video_path, |
|
"-map", "0:v:0", |
|
"-map", "1:a:0?", |
|
"-c:v", "copy", |
|
"-c:a", "aac", |
|
"-b:a", "192k", |
|
"-shortest", |
|
str(output_path) |
|
], fail_ok=True) |
|
if success: |
|
messages.append("🔊 Original audio preserved") |
|
else: |
|
shutil.copy2(str(temp_video_path), str(output_path)) |
|
messages.append("⚠️ Audio merge failed, saved w/o audio") |
|
else: |
|
shutil.copy2(str(temp_video_path), str(output_path)) |
|
messages.append("🔇 Saved without audio") |
|
|
|
messages.append("✅ Done") |
|
stats = memory_manager.get_memory_stats() |
|
messages.append(f"📊 CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB") |
|
messages.append(f"🎯 Processed {frame_count} frames with placement ({px:.2f}, {py:.2f}) @ {ps:.2f}x scale") |
|
progress(1.0, desc="Done") |
|
return str(output_path), "\n".join(messages) |
|
|
|
except Exception as e: |
|
err = f"Processing failed: {str(e)}\n\n{traceback.format_exc()}" |
|
return None, err |
|
|
|
|
|
|
|
|
|
def create_interface(): |
|
def diag(): |
|
return state.status_text() |
|
|
|
def cleanup(): |
|
memory_manager.aggressive_cleanup() |
|
s = memory_manager.get_memory_stats() |
|
return f"🧹 Cleanup\nCPU: {s.cpu_memory_mb:.1f}MB\nGPU: {s.gpu_memory_mb:.1f}MB\nTemp: {s.temp_files_count} files" |
|
|
|
def preload(ai_model, openai_key, force_gpu, progress=gr.Progress()): |
|
try: |
|
progress(0, desc="Preloading...") |
|
msg = "" |
|
if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"): |
|
try: |
|
if ai_model == "SDXL": |
|
_ = generate_sdxl_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu)) |
|
elif ai_model == "Playground v2.5": |
|
_ = generate_playground_v25_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu)) |
|
else: |
|
_ = generate_sd15_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu)) |
|
msg += f"{ai_model} preloaded.\n" |
|
except Exception as e: |
|
msg += f"{ai_model} preload failed: {e}\n" |
|
|
|
_reset_hydra() |
|
s, m = SAM2Handler(), MatAnyoneHandler() |
|
ok_s = s.initialize() |
|
_reset_hydra() |
|
ok_m = m.initialize() |
|
progress(1.0, desc="Preload complete") |
|
return f"✅ Preload\n{msg}SAM2: {'ready' if ok_s else 'failed'}\nMatAnyone: {'ready' if ok_m else 'failed'}" |
|
except Exception as e: |
|
return f"❌ Preload error: {e}" |
|
|
|
def generate_background_safe(video_file, ai_prompt, ai_steps, ai_guidance, ai_seed, |
|
ai_model, openai_key, force_gpu, progress=gr.Progress()): |
|
if not video_file: |
|
return None, "Upload a video first", gr.update(visible=False), None |
|
with memory_manager.mem_context("Background generation"): |
|
try: |
|
video_path = video_file.name if hasattr(video_file, 'name') else str(video_file) |
|
if not os.path.exists(video_path): |
|
return None, "Video not found", gr.update(visible=False), None |
|
cap = cv2.VideoCapture(video_path) |
|
if not cap.isOpened(): |
|
return None, "Could not open video", gr.update(visible=False), None |
|
ret, frame = cap.read() |
|
cap.release() |
|
if not ret or frame is None: |
|
return None, "Could not read frame", gr.update(visible=False), None |
|
h, w = int(frame.shape[0]), int(frame.shape[1]) |
|
|
|
steps = max(1, min(50, int(ai_steps or 30))) |
|
guidance = max(1.0, min(15.0, float(ai_guidance or 7.0))) |
|
try: |
|
seed_val = int(ai_seed) if ai_seed and str(ai_seed).strip() else None |
|
except Exception: |
|
seed_val = None |
|
|
|
progress(0.1, desc=f"Generating {ai_model}") |
|
bg_path = generate_ai_background_router( |
|
width=w, height=h, prompt=str(ai_prompt or "professional office background").strip(), |
|
model=str(ai_model or "SDXL"), steps=steps, guidance=guidance, |
|
seed=seed_val, openai_key=openai_key, require_gpu=bool(force_gpu) |
|
) |
|
progress(1.0, desc="Background ready") |
|
if bg_path and os.path.exists(bg_path): |
|
return bg_path, f"AI background generated with {ai_model}", gr.update(visible=True), bg_path |
|
else: |
|
return None, "No output file", gr.update(visible=False), None |
|
except Exception as e: |
|
logger.error(f"Background generation error: {e}") |
|
return None, f"Background generation failed: {str(e)}", gr.update(visible=False), None |
|
|
|
def approve_background(bg_path): |
|
try: |
|
if not bg_path or not (isinstance(bg_path, str) and os.path.exists(bg_path)): |
|
return None, "Generate a background first", gr.update(visible=False) |
|
ext = os.path.splitext(bg_path)[1].lower() or ".jpg" |
|
safe_name = f"approved_{int(time.time())}_{random.randint(1000,9999)}{ext}" |
|
dest = BACKGROUND_DIR / safe_name |
|
shutil.copy2(bg_path, dest) |
|
return str(dest), f"✅ Background approved → {dest.name}", gr.update(visible=False) |
|
except Exception as e: |
|
return None, f"⚠️ Approve failed: {e}", gr.update(visible=False) |
|
|
|
css = """ |
|
.gradio-container { font-size: 16px !important; } |
|
label { font-size: 18px !important; font-weight: 600 !important; color: #2d3748 !important; } |
|
.process-button { font-size: 20px !important; font-weight: 700 !important; padding: 16px 28px !important; } |
|
.memory-info { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; } |
|
""" |
|
|
|
with gr.Blocks(title="Enhanced Video Background Replacement", theme=gr.themes.Soft(), css=css) as interface: |
|
gr.Markdown("# 🎬 Enhanced Video Background Replacement") |
|
gr.Markdown("_SAM2 + MatAnyone + AI Backgrounds — with strict tensor shapes & memory management_") |
|
|
|
gr.HTML(f""" |
|
<div class='memory-info'> |
|
<strong>Device:</strong> {DEVICE} |
|
<strong>PyTorch:</strong> {'✅' if TORCH_AVAILABLE else '❌'} |
|
<strong>CUDA:</strong> {'✅' if CUDA_AVAILABLE else '❌'} |
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
video_input = gr.Video(label="Input Video") |
|
|
|
gr.Markdown("### Background") |
|
bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"], |
|
value="AI Generated", label="Background Method") |
|
|
|
|
|
with gr.Group(visible=False) as upload_group: |
|
upload_img = gr.Image(label="Background Image", type="filepath") |
|
|
|
|
|
with gr.Group(visible=False) as gradient_group: |
|
gradient_choice = gr.Dropdown(label="Gradient Style", |
|
choices=list(GRADIENT_PRESETS.keys()), |
|
value="Slate") |
|
|
|
|
|
with gr.Group(visible=True) as ai_group: |
|
prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration", |
|
choices=AI_PROMPT_SUGGESTIONS, |
|
value="Custom (write your own)") |
|
ai_prompt = gr.Textbox(label="Background Description", |
|
value="professional office background", lines=3) |
|
ai_model = gr.Radio(["SDXL", "Playground v2.5", "SD 1.5 (fallback)", "OpenAI (gpt-image-1)"], |
|
value="SDXL", label="AI Model") |
|
with gr.Accordion("Connect services (optional)", open=False): |
|
openai_api_key = gr.Textbox(label="OpenAI API Key", type="password", |
|
placeholder="sk-... (kept only in this session)") |
|
with gr.Row(): |
|
ai_steps = gr.Slider(10, 50, value=30, step=1, label="Quality (steps)") |
|
ai_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.1, label="Guidance") |
|
ai_seed = gr.Number(label="Seed (optional)", precision=0) |
|
force_gpu_ai = gr.Checkbox(value=True, label="Force GPU for AI background") |
|
preload_btn = gr.Button("📦 Preload Models") |
|
preload_status = gr.Textbox(label="Preload Status", lines=4) |
|
generate_bg_btn = gr.Button("Generate AI Background", variant="primary") |
|
ai_generated_bg = gr.Image(label="Generated Background", type="filepath") |
|
approve_bg_btn = gr.Button("✅ Approve Background", visible=False) |
|
approved_background_path = gr.State(value=None) |
|
last_generated_bg = gr.State(value=None) |
|
ai_status = gr.Textbox(label="Generation Status", lines=2) |
|
|
|
gr.Markdown("### Processing") |
|
with gr.Row(): |
|
trim_enabled = gr.Checkbox(label="Trim Video", value=False) |
|
trim_seconds = gr.Number(label="Trim Duration (seconds)", value=5, precision=1) |
|
with gr.Row(): |
|
crf_value = gr.Slider(0, 30, value=18, step=1, label="Quality (CRF - lower=better)") |
|
audio_enabled = gr.Checkbox(label="Preserve Audio", value=True) |
|
with gr.Row(): |
|
use_chunked = gr.Checkbox(label="Use Chunked Processing", value=False) |
|
|
|
gr.Markdown("### Subject Placement") |
|
with gr.Row(): |
|
place_x = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Horizontal") |
|
place_y = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Vertical") |
|
with gr.Row(): |
|
place_scale = gr.Slider(0.3, 2.0, value=1.0, step=0.01, label="Scale") |
|
place_feather = gr.Slider(0, 15, value=3, step=1, label="Edge feather (px)") |
|
|
|
process_btn = gr.Button("🚀 Process Video", variant="primary", elem_classes=["process-button"]) |
|
|
|
gr.Markdown("### System") |
|
with gr.Row(): |
|
diagnostics_btn = gr.Button("📊 System Diagnostics") |
|
cleanup_btn = gr.Button("🧹 Memory Cleanup") |
|
diagnostics_output = gr.Textbox(label="System Status", lines=10) |
|
|
|
with gr.Column(scale=1): |
|
output_video = gr.Video(label="Processed Video") |
|
                download_out = gr.File(label="Download Processed Video")
|
status_output = gr.Textbox(label="Processing Status", lines=20) |
|
|
|
|
|
def update_background_visibility(method): |
|
return ( |
|
gr.update(visible=(method == "Upload Image")), |
|
gr.update(visible=(method == "Gradients")), |
|
gr.update(visible=(method == "AI Generated")), |
|
) |
|
|
|
def update_prompt_from_suggestion(suggestion): |
|
if suggestion == "Custom (write your own)": |
|
return gr.update(value="", placeholder="Describe the background you want...") |
|
return gr.update(value=suggestion) |
|
|
|
bg_method.change( |
|
update_background_visibility, |
|
inputs=[bg_method], |
|
outputs=[upload_group, gradient_group, ai_group] |
|
) |
|
prompt_suggestions.change(update_prompt_from_suggestion, inputs=[prompt_suggestions], outputs=[ai_prompt]) |
|
|
|
preload_btn.click(preload, |
|
inputs=[ai_model, openai_api_key, force_gpu_ai], |
|
outputs=[preload_status], |
|
show_progress=True |
|
) |
|
|
|
generate_bg_btn.click( |
|
generate_background_safe, |
|
inputs=[video_input, ai_prompt, ai_steps, ai_guidance, ai_seed, ai_model, openai_api_key, force_gpu_ai], |
|
outputs=[ai_generated_bg, ai_status, approve_bg_btn, last_generated_bg], |
|
show_progress=True |
|
) |
|
approve_bg_btn.click( |
|
approve_background, |
|
inputs=[ai_generated_bg], |
|
outputs=[approved_background_path, ai_status, approve_bg_btn] |
|
) |
|
|
|
diagnostics_btn.click(diag, outputs=[diagnostics_output]) |
|
cleanup_btn.click(cleanup, outputs=[diagnostics_output]) |
|
|
|
def process_video( |
|
video_file, |
|
bg_method, |
|
upload_img, |
|
gradient_choice, |
|
approved_background_path, |
|
last_generated_bg, |
|
trim_enabled, trim_seconds, crf_value, audio_enabled, |
|
use_chunked, |
|
place_x, place_y, place_scale, place_feather, |
|
progress=gr.Progress(track_tqdm=True), |
|
): |
|
try: |
|
if not video_file: |
|
return None, None, "Please upload a video file" |
|
video_path = video_file.name if hasattr(video_file, 'name') else str(video_file) |
|
|
|
|
|
bg_path = None |
|
try: |
|
if bg_method == "Upload Image" and upload_img: |
|
bg_path = upload_img if isinstance(upload_img, str) else getattr(upload_img, "name", None) |
|
elif bg_method == "Gradients": |
|
cap = cv2.VideoCapture(video_path) |
|
ret, frame = cap.read(); cap.release() |
|
if ret and frame is not None: |
|
h, w = frame.shape[:2] |
|
if gradient_choice in GRADIENT_PRESETS: |
|
grad = _make_vertical_gradient(w, h, *GRADIENT_PRESETS[gradient_choice]) |
|
tmp_bg = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, dir=TEMP_DIR).name |
|
cv2.imwrite(tmp_bg, grad) |
|
memory_manager.register_temp_file(tmp_bg) |
|
bg_path = tmp_bg |
|
else: |
|
if approved_background_path: |
|
bg_path = approved_background_path |
|
elif last_generated_bg and isinstance(last_generated_bg, str) and os.path.exists(last_generated_bg): |
|
bg_path = last_generated_bg |
|
except Exception as e: |
|
logger.error(f"Background setup error: {e}") |
|
return None, None, f"Background setup failed: {str(e)}" |
|
|
|
result_path, status = process_video_main( |
|
video_path=video_path, |
|
background_path=bg_path, |
|
trim_duration=float(trim_seconds) if (trim_enabled and float(trim_seconds) > 0) else None, |
|
crf=int(crf_value), |
|
preserve_audio_flag=bool(audio_enabled), |
|
placement=dict(x=float(place_x), y=float(place_y), scale=float(place_scale), feather=int(place_feather)), |
|
use_chunked_processing=bool(use_chunked), |
|
progress=progress, |
|
) |
|
|
|
if result_path and os.path.exists(result_path): |
|
return result_path, result_path, f"✅ Success\n\n{status}" |
|
else: |
|
return None, None, f"❌ Failed\n\n{status or 'Unknown error'}" |
|
except Exception as e: |
|
tb = traceback.format_exc() |
|
return None, None, f"❌ Crash: {e}\n\n{tb}" |
|
|
|
process_btn.click( |
|
process_video, |
|
inputs=[ |
|
video_input, |
|
bg_method, |
|
upload_img, |
|
gradient_choice, |
|
approved_background_path, last_generated_bg, |
|
trim_enabled, trim_seconds, crf_value, audio_enabled, |
|
use_chunked, |
|
place_x, place_y, place_scale, place_feather, |
|
], |
|
            outputs=[output_video, download_out, status_output],
|
show_progress=True |
|
) |
|
|
|
return interface |
|
|
|
|
|
|
|
|
|
def main(): |
|
logger.info("Starting Enhanced Background Replacement") |
|
stats = memory_manager.get_memory_stats() |
|
logger.info(f"Initial memory: CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB") |
|
interface = create_interface() |
|
interface.queue(max_size=3) |
|
try: |
|
interface.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=False, |
|
inbrowser=False, |
|
show_error=True |
|
) |
|
finally: |
|
logger.info("Shutting down - cleanup") |
|
memory_manager.cleanup_temp_files() |
|
memory_manager.aggressive_cleanup() |
|
|
|
if __name__ == "__main__": |
|
main() |