MogensR committed (verified)
Commit c0d6cda · Parent: 2f1a576

Update app.py

Files changed (1):
  1. app.py +0 -1738
app.py CHANGED
@@ -1,1738 +0,0 @@
#!/usr/bin/env python3
# =============================================================================
# CHAPTER 0: INTRO & OVERVIEW
# =============================================================================
"""
Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds)
- Strict tensor shapes for MatAnyone (image: 3xHxW, first-frame prob mask: 1xHxW)
- First frame uses PROB path (no idx_mask / objects) to avoid assertion
- Memory management & cleanup
- SDXL / Playground / OpenAI backgrounds
- Gradio UI with "CHAPTER" dividers
- FIXED: Enhanced positioning with debug logging and coordinate precision
- REVERTED: Back to original MatAnyone repository (removed matanyone_fixed)
"""
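
# Illustrative sketch (hypothetical helper, not called anywhere in this file):
# builds dummy inputs matching the first-frame shape contract listed above.
def _shape_contract_example():
    """Return (image, prob) shaped (3, H, W) and (1, H, W), floats in [0, 1]."""
    import torch  # local import: module-level torch is only imported in CHAPTER 2
    img = torch.rand(3, 480, 640)   # image: (3, H, W)
    prob = torch.rand(1, 480, 640)  # first-frame soft probability mask: (1, H, W)
    return img, prob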

# =============================================================================
# CHAPTER 1: IMPORTS & GLOBALS
# =============================================================================
import os
import sys
import gc
import cv2
import psutil
import time
import json
import base64
import random
import shutil
import logging
import traceback
import subprocess
import tempfile
import threading
from dataclasses import dataclass
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Tuple, List
import numpy as np
from PIL import Image
import gradio as gr
from moviepy.editor import VideoFileClip

# EMERGENCY STOP: abort at import time so the Space stays paused for debugging.
# Remove this line to re-enable the app; everything below is unreachable until then.
sys.exit("EMERGENCY STOP - Space paused for debugging")
# =============================================================================
# CHAPTER 2: DYNAMIC CONFIGURATION
# =============================================================================

# Base configuration
BASE_DIR = Path(__file__).resolve().parent
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
APP_TITLE = os.environ.get("APP_TITLE", "Enhanced Video Background Replacement")

# Logging configuration
logging.basicConfig(level=getattr(logging, LOG_LEVEL), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("bgx")

# Environment tuning (all configurable; setdefault keeps any value already set)
os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("BFX_QUALITY", "max")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,roundup_power2_divisions:16")
os.environ.setdefault("HYDRA_FULL_ERROR", "1")
os.environ.setdefault("OMP_NUM_THREADS", "2")

# Repository and model URLs (all configurable)
SAM2_REPO_URL = os.environ.get("SAM2_REPO_URL", "https://github.com/facebookresearch/segment-anything-2.git")
SAM2_CHECKPOINT_URL = os.environ.get("SAM2_CHECKPOINT_URL", "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt")
SAM2_CHECKPOINT_NAME = os.environ.get("SAM2_CHECKPOINT_NAME", "sam2.1_hiera_large.pt")
SAM2_CONFIG_NAME = os.environ.get("SAM2_CONFIG_NAME", "sam2.1/sam2.1_hiera_l.yaml")

# REVERTED: Back to original MatAnyone repository
MATANYONE_REPO_URL = os.environ.get("MATANYONE_REPO_URL", "https://github.com/pq-yang/MatAnyone.git")
MATANYONE_CHECKPOINT_URL = os.environ.get("MATANYONE_CHECKPOINT_URL", "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth")
MATANYONE_CHECKPOINT_NAME = os.environ.get("MATANYONE_CHECKPOINT_NAME", "matanyone.pth")

# Diffusion model configurations
DEFAULT_DIFFUSION_MODELS = {
    "SDXL": os.environ.get("SDXL_MODEL", "stabilityai/stable-diffusion-xl-base-1.0"),
    "Playground": os.environ.get("PLAYGROUND_MODEL", "playgroundai/playground-v2.5-1024px-aesthetic"),
    "SD15": os.environ.get("SD15_MODEL", "runwayml/stable-diffusion-v1-5")
}

# OpenAI configuration
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1/images/generations")
OPENAI_DEFAULT_MODEL = os.environ.get("OPENAI_DEFAULT_MODEL", "gpt-image-1")
OPENAI_DEFAULT_SIZE = os.environ.get("OPENAI_DEFAULT_SIZE", "1024x1024")
OPENAI_DEFAULT_QUALITY = os.environ.get("OPENAI_DEFAULT_QUALITY", "high")

# Server configuration
SERVER_HOST = os.environ.get("SERVER_HOST", "0.0.0.0")
SERVER_PORT = int(os.environ.get("SERVER_PORT", "7860"))
SERVER_SHARE = os.environ.get("SERVER_SHARE", "False").lower() == "true"
SERVER_INBROWSER = os.environ.get("SERVER_INBROWSER", "False").lower() == "true"

# Memory and performance settings
CUDA_MEMORY_FRACTION = float(os.environ.get("CUDA_MEMORY_FRACTION", "0.8"))
GRADIO_QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "3"))
CHUNK_SIZE_FRAMES = int(os.environ.get("CHUNK_SIZE_FRAMES", "60"))

# File handling
DOWNLOAD_TIMEOUT = int(os.environ.get("DOWNLOAD_TIMEOUT", "300"))
DOWNLOAD_CHUNK_SIZE = int(os.environ.get("DOWNLOAD_CHUNK_SIZE", "8192"))

# FFmpeg settings
FFMPEG_PRESET = os.environ.get("FFMPEG_PRESET", "medium")
FFMPEG_DEFAULT_CRF = int(os.environ.get("FFMPEG_DEFAULT_CRF", "18"))
FFMPEG_PIXEL_FORMAT = os.environ.get("FFMPEG_PIXEL_FORMAT", "yuv420p")
FFMPEG_PROFILE = os.environ.get("FFMPEG_PROFILE", "high")
FFMPEG_AUDIO_CODEC = os.environ.get("FFMPEG_AUDIO_CODEC", "aac")
FFMPEG_AUDIO_BITRATE = os.environ.get("FFMPEG_AUDIO_BITRATE", "192k")

# Paths (configurable base directories)
CHECKPOINTS_DIR = os.environ.get("CHECKPOINTS_DIR", "checkpoints")
TEMP_DIR_NAME = os.environ.get("TEMP_DIR_NAME", "temp")
OUTPUT_DIR_NAME = os.environ.get("OUTPUT_DIR_NAME", "outputs")
BACKGROUND_DIR_NAME = os.environ.get("BACKGROUND_DIR_NAME", "backgrounds")

CHECKPOINTS = BASE_DIR / CHECKPOINTS_DIR
TEMP_DIR = BASE_DIR / TEMP_DIR_NAME
OUT_DIR = BASE_DIR / OUTPUT_DIR_NAME
BACKGROUND_DIR = OUT_DIR / BACKGROUND_DIR_NAME

for p in (CHECKPOINTS, TEMP_DIR, OUT_DIR, BACKGROUND_DIR):
    p.mkdir(parents=True, exist_ok=True)

# Torch/device configuration
try:
    import torch
    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
    DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
    try:
        if torch.backends.cuda.is_built():
            torch.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch.backends, "cudnn"):
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
        if CUDA_AVAILABLE:
            torch.cuda.set_per_process_memory_fraction(CUDA_MEMORY_FRACTION)
    except Exception:
        pass
except Exception:
    TORCH_AVAILABLE = False
    CUDA_AVAILABLE = False
    DEVICE = "cpu"

# =============================================================================
# CHAPTER 3: UI CONSTANTS & UTILS (Made configurable)
# =============================================================================
# Gradient presets (configurable via JSON env var)
DEFAULT_GRADIENT_PRESETS = {
    "Blue Fade": ((128, 64, 0), (255, 128, 0)),
    "Sunset": ((255, 128, 0), (255, 0, 128)),
    "Green Field": ((64, 128, 64), (160, 255, 160)),
    "Slate": ((40, 40, 48), (96, 96, 112)),
    "Ocean": ((255, 140, 0), (255, 215, 0)),
    "Forest": ((34, 139, 34), (144, 238, 144)),
    "Sunset Pink": ((255, 182, 193), (255, 105, 180)),
    "Cool Blue": ((173, 216, 230), (0, 191, 255)),
}

try:
    GRADIENT_PRESETS = json.loads(os.environ.get("GRADIENT_PRESETS", "{}"))
    if not GRADIENT_PRESETS:
        GRADIENT_PRESETS = DEFAULT_GRADIENT_PRESETS
except (json.JSONDecodeError, TypeError):
    GRADIENT_PRESETS = DEFAULT_GRADIENT_PRESETS
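
# Example (illustrative): presets can be overridden from the environment, e.g.
#   GRADIENT_PRESETS='{"Mono": [[32, 32, 32], [224, 224, 224]]}'
# json.loads yields nested lists rather than tuples; the gradient helper below
# only indexes the two color triples, so either form works.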

# AI prompt suggestions (configurable)
DEFAULT_AI_PROMPT_SUGGESTIONS = [
    "Custom (write your own)",
    "modern minimalist office with soft lighting, clean desk, blurred background",
    "elegant conference room with large windows and city view",
    "contemporary workspace with plants and natural light",
    "luxury hotel lobby with marble floors and warm ambient lighting",
    "professional studio with clean white background and soft lighting",
    "modern corporate meeting room with glass walls and city skyline",
    "sophisticated home office with bookshelf and warm wood tones",
    "sleek coworking space with industrial design elements",
    "abstract geometric patterns in blue and gold, modern art style",
    "soft watercolor texture with pastel colors, dreamy atmosphere",
]

try:
    AI_PROMPT_SUGGESTIONS = json.loads(os.environ.get("AI_PROMPT_SUGGESTIONS", "[]"))
    if not AI_PROMPT_SUGGESTIONS:
        AI_PROMPT_SUGGESTIONS = DEFAULT_AI_PROMPT_SUGGESTIONS
except (json.JSONDecodeError, TypeError):
    AI_PROMPT_SUGGESTIONS = DEFAULT_AI_PROMPT_SUGGESTIONS

def _make_vertical_gradient(width: int, height: int, c1, c2) -> np.ndarray:
    width = max(1, int(width))
    height = max(1, int(height))
    top = np.array(c1, dtype=np.float32)
    bot = np.array(c2, dtype=np.float32)
    rows = np.linspace(top, bot, num=height, dtype=np.float32)
    grad = np.repeat(rows[:, None, :], repeats=width, axis=1)
    return np.clip(grad, 0, 255).astype(np.uint8)
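
# Usage sketch (illustrative only, not wired into the UI): a 4x2 strip fading
# between the two "Blue Fade" colors; the result has shape (height, width, 3)
# and is treated as BGR by the callers in this file.
#   strip = _make_vertical_gradient(2, 4, *DEFAULT_GRADIENT_PRESETS["Blue Fade"])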

def run_ffmpeg(args: list, fail_ok=False) -> bool:
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error"] + args
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except Exception as e:
        if not fail_ok:
            logger.error(f"ffmpeg failed: {e}")
        return False
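
# Usage sketch (illustrative): the wrapper prepends "ffmpeg -y -hide_banner
# -loglevel error", so callers pass only their own arguments, e.g. a remux:
#   run_ffmpeg(["-i", "in.mp4", "-c", "copy", "out.mp4"], fail_ok=True)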

def write_video_h264(clip, path: str, fps: Optional[int] = None, crf: Optional[int] = None, preset: Optional[str] = None):
    fps = fps or max(1, int(round(getattr(clip, "fps", None) or 24)))
    crf = crf or FFMPEG_DEFAULT_CRF
    preset = preset or FFMPEG_PRESET

    clip.write_videofile(
        path,
        audio=False,
        fps=fps,
        codec="libx264",
        preset=preset,
        ffmpeg_params=["-crf", str(crf), "-pix_fmt", FFMPEG_PIXEL_FORMAT, "-profile:v", FFMPEG_PROFILE, "-movflags", "+faststart"],
        logger=None,
        verbose=False,
    )

def download_file(url: str, dest: Path, name: str) -> bool:
    if dest.exists():
        logger.info(f"{name} already exists")
        return True
    try:
        import requests
        logger.info(f"Downloading {name} from {url}")
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as r:
            r.raise_for_status()
            with open(dest, "wb") as f:
                for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    if chunk:
                        f.write(chunk)
        logger.info(f"{name} downloaded successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to download {name}: {e}")
        if dest.exists():
            try: dest.unlink()
            except Exception: pass
        return False

def ensure_repo(repo_name: str, git_url: str) -> Optional[Path]:
    repo_path = CHECKPOINTS / f"{repo_name}_repo"
    if not repo_path.exists():
        try:
            logger.info(f"Cloning {repo_name} from {git_url}")
            subprocess.run(["git", "clone", "--depth", "1", git_url, str(repo_path)],
                           check=True, timeout=DOWNLOAD_TIMEOUT, capture_output=True)
            logger.info(f"{repo_name} cloned successfully")
        except Exception as e:
            logger.error(f"Failed to clone {repo_name}: {e}")
            return None
    repo_str = str(repo_path)
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_path

def _reset_hydra():
    try:
        from hydra.core.global_hydra import GlobalHydra
        # Use the singleton accessor; a bare GlobalHydra() is a fresh instance
        # that never reports as initialized.
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()
    except Exception:
        pass

# =============================================================================
# CHAPTER 3A: STARTUP CLEANUP (safe, configurable)
# =============================================================================
# Controls:
#   CLEAN_ON_START = off | light | deep (default: light)
#     off   -> do nothing
#     light -> clear ./temp, ./outputs (incl. ./outputs/backgrounds)
#     deep  -> light + /tmp/gradio + HF caches (NOT the model hub unless CLEAR_MODELS=1)
#   CLEAR_MODELS = 0|1 (deep only): also clears ~/.cache/huggingface/hub and ./checkpoints
#   CLEAR_PLATFORM_TMP = 0|1: additionally clear /tmp/gradio even in light mode
CLEAN_ON_START = os.environ.get("CLEAN_ON_START", "light").lower().strip()
CLEAR_MODELS = os.environ.get("CLEAR_MODELS", "0") == "1"
CLEAR_PLATFORM_TMP = os.environ.get("CLEAR_PLATFORM_TMP", "0") == "1"
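
# Example (illustrative): a deep clean that also wipes model caches would be
# launched as
#   CLEAN_ON_START=deep CLEAR_MODELS=1 python app.py
# which forces the SAM2/MatAnyone checkpoints and the HF hub to be re-downloaded.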

def _safe_rmtree(path: Path) -> bool:
    """Delete a directory only if it is our project dir or an allowed /tmp path."""
    try:
        path = path.resolve()
        if not path.exists():
            return False
        # allow: anything under project BASE_DIR
        if path.is_dir() and path.is_relative_to(BASE_DIR):
            shutil.rmtree(path, ignore_errors=True)
            return True
        # allow: /tmp/gradio* when platform uses temp uploads
        if str(path).startswith("/tmp/gradio"):
            shutil.rmtree(path, ignore_errors=True)
            return True
    except Exception:
        pass
    return False
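
# Illustrative only: the allowlist means project-relative paths are deletable
# while arbitrary system paths are not.
#   _safe_rmtree(TEMP_DIR / "scratch")  # True if it existed under BASE_DIR
#   _safe_rmtree(Path("/etc"))          # False: outside BASE_DIR and /tmp/gradio*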

def startup_cleanup():
    """Clear caches/files so each run starts fresh."""
    if CLEAN_ON_START not in ("light", "deep"):
        logger.info(f"Startup cleanup skipped (CLEAN_ON_START={CLEAN_ON_START})")
        return

    to_clear: List[Path] = []

    # Always clear our temp/output trees
    to_clear += [TEMP_DIR, OUT_DIR, BACKGROUND_DIR]

    # /tmp/gradio holds uploads & intermediates on many hosts
    if CLEAN_ON_START == "deep" or CLEAR_PLATFORM_TMP:
        gradio_tmp = Path(os.environ.get("GRADIO_TEMP_DIR", "/tmp/gradio"))
        to_clear += [gradio_tmp]

    if CLEAN_ON_START == "deep":
        # Clear non-model HF caches
        hf_base = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
        to_clear += [hf_base / "diffusers", hf_base / "datasets"]
        # Optionally clear model hub + checkpoints (will redownload!)
        if CLEAR_MODELS:
            to_clear += [hf_base / "hub", CHECKPOINTS]

    # Perform deletions
    cleared = []
    for d in to_clear:
        if isinstance(d, Path) and d.exists():
            if _safe_rmtree(d):
                cleared.append(str(d))

    # Re-create our working directories
    for p in (CHECKPOINTS, TEMP_DIR, OUT_DIR, BACKGROUND_DIR):
        p.mkdir(parents=True, exist_ok=True)

    # Free CUDA cache too
    if TORCH_AVAILABLE and CUDA_AVAILABLE:
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

    logger.info(
        f"Startup cleanup done (mode={CLEAN_ON_START}, clear_models={int(CLEAR_MODELS)}). "
        f"Cleared: {', '.join(cleared) if cleared else 'nothing'}"
    )

# =============================================================================
# CHAPTER 4: MEMORY MANAGER
# =============================================================================
@dataclass
class MemoryStats:
    cpu_percent: float
    cpu_memory_mb: float
    gpu_memory_mb: float = 0.0
    gpu_memory_reserved_mb: float = 0.0
    temp_files_count: int = 0
    temp_files_size_mb: float = 0.0

class MemoryManager:
    def __init__(self):
        self.temp_files: List[str] = []
        self.cleanup_lock = threading.Lock()
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE

    def get_memory_stats(self) -> MemoryStats:
        process = psutil.Process()
        cpu_percent = psutil.cpu_percent(interval=0.1)
        cpu_memory_mb = process.memory_info().rss / (1024 * 1024)
        gpu_memory_mb = 0.0
        gpu_memory_reserved_mb = 0.0
        if self.torch_available and self.cuda_available:
            try:
                import torch
                gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
                gpu_memory_reserved_mb = torch.cuda.memory_reserved() / (1024 * 1024)
            except Exception:
                pass

        temp_count, temp_size_mb = 0, 0.0
        for tf in self.temp_files:
            if os.path.exists(tf):
                temp_count += 1
                try:
                    temp_size_mb += os.path.getsize(tf) / (1024 * 1024)
                except Exception:
                    pass
        return MemoryStats(cpu_percent, cpu_memory_mb, gpu_memory_mb, gpu_memory_reserved_mb, temp_count, temp_size_mb)

    def register_temp_file(self, path: str):
        with self.cleanup_lock:
            if path not in self.temp_files:
                self.temp_files.append(path)

    def cleanup_temp_files(self):
        with self.cleanup_lock:
            cleaned = 0
            for tf in self.temp_files[:]:
                try:
                    if os.path.isdir(tf):
                        shutil.rmtree(tf, ignore_errors=True)
                    elif os.path.exists(tf):
                        os.unlink(tf)
                    cleaned += 1
                except Exception as e:
                    logger.warning(f"Failed to cleanup {tf}: {e}")
                finally:
                    try: self.temp_files.remove(tf)
                    except Exception: pass
            if cleaned:
                logger.info(f"Cleaned {cleaned} temp paths")

    def aggressive_cleanup(self):
        logger.info("Aggressive cleanup...")
        gc.collect()
        if self.torch_available and self.cuda_available:
            try:
                import torch
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            except Exception:
                pass
        self.cleanup_temp_files()
        gc.collect()

    @contextmanager
    def mem_context(self, name="op"):
        stats = self.get_memory_stats()
        logger.info(f"Start {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
        try:
            yield self
        finally:
            self.aggressive_cleanup()
            stats = self.get_memory_stats()
            logger.info(f"End {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")

memory_manager = MemoryManager()
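
# Usage sketch (illustrative): wrap any heavy step so temp files and the CUDA
# cache are reclaimed even on failure.
#   with memory_manager.mem_context("demo"):
#       buf = np.zeros((2048, 2048, 3), dtype=np.float32)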

# =============================================================================
# CHAPTER 5: SYSTEM STATE
# =============================================================================
class SystemState:
    def __init__(self):
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE
        self.device = DEVICE
        self.sam2_ready = False
        self.matanyone_ready = False
        self.sam2_error = None
        self.matanyone_error = None

    def status_text(self) -> str:
        stats = memory_manager.get_memory_stats()
        return (
            "=== SYSTEM STATUS ===\n"
            f"PyTorch: {'✅' if self.torch_available else '❌'}\n"
            f"CUDA: {'✅' if self.cuda_available else '❌'}\n"
            f"Device: {self.device}\n"
            f"SAM2: {'✅' if self.sam2_ready else ('❌' if self.sam2_error else '⏳')}\n"
            f"MatAnyone: {'✅' if self.matanyone_ready else ('❌' if self.matanyone_error else '⏳')}\n\n"
            "=== MEMORY ===\n"
            f"CPU: {stats.cpu_percent:.1f}% ({stats.cpu_memory_mb:.1f} MB)\n"
            f"GPU: {stats.gpu_memory_mb:.1f} MB (Reserved {stats.gpu_memory_reserved_mb:.1f} MB)\n"
            f"Temp: {stats.temp_files_count} files ({stats.temp_files_size_mb:.1f} MB)\n\n"
            "=== CONFIGURATION ===\n"
            f"SAM2 Model: {SAM2_CHECKPOINT_NAME}\n"
            f"MatAnyone Model: {MATANYONE_CHECKPOINT_NAME}\n"
            f"Server: {SERVER_HOST}:{SERVER_PORT}\n"
            f"CUDA Memory: {CUDA_MEMORY_FRACTION:.1%}\n"
        )

state = SystemState()

# =============================================================================
# CHAPTER 6: SAM2 HANDLER (CUDA-only) - Dynamic URLs
# =============================================================================
class SAM2Handler:
    def __init__(self):
        self.predictor = None
        self.initialized = False

    def initialize(self) -> bool:
        if not (TORCH_AVAILABLE and CUDA_AVAILABLE):
            state.sam2_error = "SAM2 requires CUDA"
            return False

        with memory_manager.mem_context("SAM2 init"):
            try:
                _reset_hydra()
                repo_path = ensure_repo("sam2", SAM2_REPO_URL)
                if not repo_path:
                    state.sam2_error = "Clone failed"
                    return False

                ckpt = CHECKPOINTS / SAM2_CHECKPOINT_NAME
                if not download_file(SAM2_CHECKPOINT_URL, ckpt, "SAM2 Large"):
                    state.sam2_error = "SAM2 ckpt download failed"
                    return False

                from hydra.core.global_hydra import GlobalHydra
                from hydra import initialize_config_dir
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                config_dir = (repo_path / "sam2" / "configs").as_posix()
                if GlobalHydra.instance().is_initialized():
                    GlobalHydra.instance().clear()
                initialize_config_dir(config_dir=config_dir, version_base=None)

                model = build_sam2(SAM2_CONFIG_NAME, str(ckpt), device="cuda")
                self.predictor = SAM2ImagePredictor(model)

                # Smoke test
                test = np.zeros((64, 64, 3), dtype=np.uint8)
                self.predictor.set_image(test)
                masks, scores, _ = self.predictor.predict(
                    point_coords=np.array([[32, 32]]),
                    point_labels=np.ones(1, dtype=np.int64),
                    multimask_output=True,
                )
                ok = masks is not None and len(masks) > 0
                self.initialized = ok
                state.sam2_ready = ok
                if not ok:
                    state.sam2_error = "SAM2 verify failed"
                return ok

            except Exception as e:
                state.sam2_error = f"SAM2 init error: {e}"
                return False

    def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
        if not self.initialized:
            return None
        with memory_manager.mem_context("SAM2 mask"):
            try:
                self.predictor.set_image(image_rgb)
                h, w = image_rgb.shape[:2]
                strategies = [
                    np.array([[w // 2, h // 2]]),
                    np.array([[w // 2, h // 3]]),
                    np.array([[w // 2, h // 3], [w // 2, (2 * h) // 3]]),
                ]
                best, best_score = None, -1.0
                for pc in strategies:
                    masks, scores, _ = self.predictor.predict(
                        point_coords=pc,
                        point_labels=np.ones(len(pc), dtype=np.int64),
                        multimask_output=True,
                    )
                    if masks is not None and len(masks) > 0:
                        i = int(np.argmax(scores))
                        sc = float(scores[i])
                        if sc > best_score:
                            best_score, best = sc, masks[i]

                if best is None:
                    return None

                mask_u8 = (best * 255).astype(np.uint8)
                k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
                mask_clean = cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, k)
                mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_OPEN, k)
                mask_clean = cv2.GaussianBlur(mask_clean, (3, 3), 1.0)
                return mask_clean
            except Exception as e:
                logger.error(f"SAM2 mask error: {e}")
                return None

# =============================================================================
# CHAPTER 7: MATANYONE HANDLER (runtime patch for group dim + safe shapes)
# =============================================================================
class MatAnyoneHandler:
    """
    MatAnyone loader + inference adapter with a runtime patch:

    - Robust import (supports both 'inference.*' and 'matanyone.inference.*')
    - Loads checkpoint via get_matanyone_model
    - Monkey-patches GroupDistributor.forward to align dims:
      x: [B,T,C,H,W] -> [B,T,1,C,H,W] before cat with g: [B,T,G,C,H,W]
    - First frame: pass soft prob (1,H,W) or (H,W) if needed
    - Subsequent frames: image only
    """
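
    # Dim-alignment sketch (illustrative, mirrors the docstring above): the
    # image tensor gains a singleton group axis so it can concatenate with a
    # grouped tensor along dim=2.
    #   x = torch.zeros(1, 1, 3, 8, 8)           # [B, T, C, H, W]
    #   g = torch.zeros(1, 1, 2, 3, 8, 8)        # [B, T, G, C, H, W]
    #   torch.cat([x.unsqueeze(2), g], 2).shape  # -> [1, 1, 3, 3, 8, 8]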
    def __init__(self):
        self.core = None
        self.initialized = False

    # ---------- helpers: tensor conversions ----------
    def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
        assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
        t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float()  # (3,H,W)
        t = torch.clamp(t, 0.0, 1.0)
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_hw_from_mask(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        if mask_u8.shape[:2] != (h, w):
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)
        t = torch.from_numpy(prob).contiguous().float()  # (H,W)
        t = torch.clamp(t, 0.0, 1.0)
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_1hw_from_mask(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        if mask_u8.shape[:2] != (h, w):
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)[None, ...]  # (1,H,W)
        t = torch.from_numpy(prob).contiguous().float()
        t = torch.clamp(t, 0.0, 1.0)
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
        if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
            alpha_like = alpha_like[1]
        if isinstance(alpha_like, torch.Tensor):
            t = alpha_like.detach().float().cpu()
            a = t.numpy()
        else:
            a = np.asarray(alpha_like, dtype=np.float32)
        a = np.squeeze(a)
        if a.ndim == 3 and a.shape[0] >= 1:
            a = a[0]
        if a.ndim != 2:
            raise ValueError(f"Alpha must be HxW; got {a.shape}")
        a = np.clip(a, 0.0, 1.0)
        return (a * 255.0 + 0.5).astype(np.uint8)

    # ---------- import + runtime patch ----------
    def _import_matanyone(self):
        """
        Try both import layouts (repo as top-level vs packaged).
        Returns (InferenceCore, get_matanyone_model, group_modules_module).
        """
        # Ensure repo on sys.path
        repo_path = ensure_repo("matanyone", MATANYONE_REPO_URL)
        if not repo_path:
            raise ImportError("MatAnyone repo clone failed")
        repo_str = str(repo_path)
        if repo_str not in sys.path:
            sys.path.insert(0, repo_str)

        # Try packaged import first
        try:
            from matanyone.inference.inference_core import InferenceCore as IC
            from matanyone.utils.get_default_model import get_matanyone_model as get_model
            from matanyone.model import group_modules as gm
            return IC, get_model, gm
        except Exception:
            pass

        # Try repo-root import (modules are at repo root)
        try:
            from inference.inference_core import InferenceCore as IC  # type: ignore
            from utils.get_default_model import get_matanyone_model as get_model  # type: ignore
            import model.group_modules as gm  # type: ignore
            return IC, get_model, gm
        except Exception as e:
            raise ImportError(f"MatAnyone import failed from both paths: {e}")

    def _patch_group_distributor(self, gm_module) -> bool:
        """
        Patch GroupDistributor.forward so it promotes x to 6D with a singleton
        group dim (G=1) when needed, avoiding 5-vs-6D concat errors.
        """
        if not hasattr(gm_module, "GroupDistributor"):
            logger.warning("MatAnyone: GroupDistributor not found; skip patch")
            return False

        cls = gm_module.GroupDistributor
        if getattr(cls, "_bgx_patched", False):
            return True

        orig_forward = cls.forward

        def _wrapped_forward(self, x, g):
            # Normalize x to 6D: [B,T,G,C,H,W] with G=1 if missing
            def to_6d_for_x(t):
                while t.dim() < 5:  # e.g., (C,H,W) -> (1,1,C,H,W)
                    t = t.unsqueeze(0)
                if t.dim() == 5:  # [B,T,C,H,W] -> insert G at dim=2
                    t = t.unsqueeze(2)
                return t

            def to_6d_for_g(t):
                while t.dim() < 6:
                    t = t.unsqueeze(0)
                return t

            try:
                if x.dim() != 6:
                    x = to_6d_for_x(x)
                if g.dim() != 6:
                    g = to_6d_for_g(g)
                # Now both 6D, cat along dim=2 as repo expects
                return orig_forward(self, x, g)
            except Exception:
                # Last resort: force x's group dim to 1 and retry
                if x.dim() == 6 and x.size(2) != 1:
                    x = x.narrow(2, 0, 1)  # force G=1
                return orig_forward(self, x, g)

        cls.forward = _wrapped_forward
        cls._bgx_patched = True
        logger.info("MatAnyone: GroupDistributor.forward patched for group dim alignment")
        return True

    # ---------- initialization ----------
    def initialize(self) -> bool:
        if not TORCH_AVAILABLE:
            state.matanyone_error = "PyTorch required"
            return False
        with memory_manager.mem_context("MatAnyone init"):
            try:
                IC, get_model, gm = self._import_matanyone()

                ckpt = CHECKPOINTS / MATANYONE_CHECKPOINT_NAME
                if not ckpt.exists():
                    ok = download_file(MATANYONE_CHECKPOINT_URL, ckpt, "MatAnyone")
                    if not ok:
                        state.matanyone_error = "Checkpoint download failed"
                        return False

                # Build network & core
                net = get_model(str(ckpt), device=DEVICE)
                self.core = IC(net)

                # Patch distributor at runtime
                self._patch_group_distributor(gm)

                self.initialized = True
                state.matanyone_ready = True
                logger.info("MatAnyone initialized and patched")
                return True
            except Exception as e:
                state.matanyone_error = f"MatAnyone init error: {e}"
                logger.error(f"MatAnyone init error: {e}")
                return False

    # ---------- safe call variants into InferenceCore ----------
    def _call_step_seed(self, img_chw: "torch.Tensor", prob_hw: "torch.Tensor", prob_1hw: "torch.Tensor"):
        """
        Try a few input shapes the core accepts, *without* adding fake time dims.
        Order:
          1) (3,H,W) + (H,W)
          2) (3,H,W) + (1,H,W)
          3) (3,H,W) only (fallback)
        """
        last_err = None
        trials = [
            ("chw_hw",  lambda: self.core.step(img_chw, prob_hw)),
            ("chw_1hw", lambda: self.core.step(img_chw, prob_1hw)),
            ("chw",     lambda: self.core.step(img_chw)),
        ]
        for name, fn in trials:
            try:
                out = fn()
                logger.debug(f"MatAnyone.step seed variant ok: {name}")
                return out
            except Exception as e:
                last_err = e
                logger.debug(f"MatAnyone.step seed variant failed: {name} -> {e}")
        raise last_err or RuntimeError("MatAnyone step (seed, CHW-only) failed for all tried shapes")

    def _call_step_noseed(self, img_chw: "torch.Tensor"):
        # Subsequent frames carry no seed mask; the core is called image-only.
        return self.core.step(img_chw)

    # ---------- main video processing ----------
    def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
        if not self.initialized or self.core is None:
            raise RuntimeError("MatAnyone not initialized")

        out_dir = Path(output_path)
        out_dir.mkdir(parents=True, exist_ok=True)
        alpha_path = out_dir / "alpha.mp4"

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError("Could not open input video")

        fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if seed_mask is None:
            cap.release()
            raise RuntimeError("Seed mask read failed")

        prob_hw = self._prob_hw_from_mask(seed_mask, w, h)    # (H,W)
        prob_1hw = self._prob_1hw_from_mask(seed_mask, w, h)  # (1,H,W)

        tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        memory_manager.register_temp_file(str(tmp_dir))

        frame_idx = 0

        # First frame (with soft prob)
        ok, frame_bgr = cap.read()
        if not ok or frame_bgr is None:
            cap.release()
            raise RuntimeError("Empty first frame")
        frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_chw = self._to_chw_float(frame_rgb01)

        with torch.no_grad():
            out_prob = self._call_step_seed(img_chw, prob_hw, prob_1hw)

        alpha_u8 = self._alpha_to_u8_hw(out_prob)
        cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
        frame_idx += 1

        # Remaining frames (no mask)
        while True:
            ok, frame_bgr = cap.read()
            if not ok or frame_bgr is None:
                break
            frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            img_chw = self._to_chw_float(frame_rgb01)

            with torch.no_grad():
                out_prob = self._call_step_noseed(img_chw)

            alpha_u8 = self._alpha_to_u8_hw(out_prob)
            cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
            frame_idx += 1

        cap.release()

        # Encode PNGs -> mp4 (gray)
        list_file = tmp_dir / "list.txt"
        with open(list_file, "w") as f:
            for i in range(frame_idx):
                f.write(f"file '{(tmp_dir / f'{i:06d}.png').as_posix()}'\n")

        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-f", "concat", "-safe", "0",
            "-r", f"{fps:.6f}",
            "-i", str(list_file),
            "-vf", f"format=gray,scale={w}:{h}:flags=area",
            "-pix_fmt", FFMPEG_PIXEL_FORMAT,
            "-c:v", "libx264", "-preset", FFMPEG_PRESET, "-crf", str(FFMPEG_DEFAULT_CRF),
            str(alpha_path),
        ]
        subprocess.run(cmd, check=True)
        return str(alpha_path)
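
# Usage sketch (illustrative): the handler writes <output_path>/alpha.mp4, a
# grayscale matte video aligned with the input.
#   ma = MatAnyoneHandler()
#   if ma.initialize():
#       alpha = ma.process_video("in.mp4", "seed_mask.png", "out_dir")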

# =============================================================================
# CHAPTER 8: AI BACKGROUNDS - Dynamic model names and URLs
# =============================================================================
def _maybe_enable_xformers(pipe):
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

def _setup_memory_efficient_pipeline(pipe, require_gpu: bool):
    _maybe_enable_xformers(pipe)
    if not require_gpu:
        try:
            if hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing("auto")
            # Prefer model-level offload; sequential offload is the slower,
            # more aggressive fallback and should not be stacked on top of it.
            if hasattr(pipe, "enable_model_cpu_offload"):
                pipe.enable_model_cpu_offload()
            elif hasattr(pipe, "enable_sequential_cpu_offload"):
                pipe.enable_sequential_cpu_offload()
        except Exception:
            pass

def generate_sdxl_background(width: int, height: int, prompt: str, steps: int = 30, guidance: float = 7.0,
                             seed: Optional[int] = None, require_gpu: bool = False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SDXL")
    with memory_manager.mem_context("SDXL background"):
        try:
            from diffusers import StableDiffusionXLPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        pipe = StableDiffusionXLPipeline.from_pretrained(
            DEFAULT_DIFFUSION_MODELS["SDXL"],
            torch_dtype=torch_dtype,
            add_watermarker=False,
        ).to(device)

        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional studio lighting, high detail, clean composition"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_playground_v25_background(width: int, height: int, prompt: str, steps: int = 30, guidance: float = 7.0,
                                       seed: Optional[int] = None, require_gpu: bool = False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for Playground v2.5")
    with memory_manager.mem_context("Playground v2.5 background"):
        try:
            from diffusers import DiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        pipe = DiffusionPipeline.from_pretrained(DEFAULT_DIFFUSION_MODELS["Playground"], torch_dtype=torch_dtype).to(device)
        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional quality, soft light, minimal distractions"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_sd15_background(width: int, height: int, prompt: str, steps: int = 25, guidance: float = 7.5,
                             seed: Optional[int] = None, require_gpu: bool = False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SD 1.5")
    with memory_manager.mem_context("SD1.5 background"):
        try:
            from diffusers import StableDiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        pipe = StableDiffusionPipeline.from_pretrained(
            DEFAULT_DIFFUSION_MODELS["SD15"],
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)

        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional background, clean composition"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_openai_background(width: int, height: int, prompt: str, api_key: str, model: Optional[str] = None) -> str:
    if not api_key or not isinstance(api_key, str) or len(api_key) < 10:
        raise RuntimeError("Missing or invalid OpenAI API key")

    model = model or OPENAI_DEFAULT_MODEL

    with memory_manager.mem_context("OpenAI background"):
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        body = {
            "model": model,
            "prompt": f"{prompt}, professional background, studio lighting, minimal distractions, high detail",
            "size": OPENAI_DEFAULT_SIZE,
            "n": 1,
            "quality": OPENAI_DEFAULT_QUALITY
        }
        import requests
        r = requests.post(OPENAI_BASE_URL, headers=headers, data=json.dumps(body), timeout=120)
        if r.status_code != 200:
            raise RuntimeError(f"OpenAI API error: {r.status_code} {r.text}")
        data = r.json()
        b64 = data["data"][0]["b64_json"]
        raw = base64.b64decode(b64)
        tmp_png = TEMP_DIR / f"openai_raw_{int(time.time())}_{random.randint(1000,9999)}.png"
        with open(tmp_png, "wb") as f:
            f.write(raw)
        img = Image.open(tmp_png).convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
        out = TEMP_DIR / f"openai_bg_{int(time.time())}_{random.randint(1000,9999)}.jpg"
        img.save(out, quality=95, optimize=True)
        try: os.unlink(tmp_png)
        except Exception: pass
        memory_manager.register_temp_file(str(out))
        return str(out)

def generate_ai_background_router(width: int, height: int, prompt: str, model: str = "SDXL",
                                  steps: int = 30, guidance: float = 7.0, seed: Optional[int] = None,
                                  openai_key: Optional[str] = None, require_gpu: bool = False) -> str:
    try:
        if model == "OpenAI (gpt-image-1)":
            if not openai_key:
                raise RuntimeError("OpenAI API key not provided")
            return generate_openai_background(width, height, prompt, openai_key, OPENAI_DEFAULT_MODEL)
        elif model == "Playground v2.5":
            return generate_playground_v25_background(width, height, prompt, steps, guidance, seed, require_gpu)
        elif model == "SDXL":
            return generate_sdxl_background(width, height, prompt, steps, guidance, seed, require_gpu)
        else:
            return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu)
    except Exception as e:
        logger.warning(f"{model} generation failed: {e}; falling back to SD1.5/gradient")
        try:
            return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu=False)
        except Exception:
            grad = _make_vertical_gradient(width, height, (235, 240, 245), (210, 220, 230))
            out = TEMP_DIR / f"bg_fallback_{int(time.time())}.jpg"
            cv2.imwrite(str(out), grad)
            memory_manager.register_temp_file(str(out))
            return str(out)
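
# Usage sketch (illustrative): generate a 1280x720 backdrop; on failure the
# router degrades to SD1.5 and finally to a flat gradient JPEG.
#   bg_path = generate_ai_background_router(1280, 720, "modern office", model="SDXL")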

# =============================================================================
# CHAPTER 9: CHUNKED PROCESSOR - Configurable chunk size
# =============================================================================
class ChunkedVideoProcessor:
    def __init__(self, chunk_size_frames: Optional[int] = None):
        self.chunk_size = int(chunk_size_frames or CHUNK_SIZE_FRAMES)

    def _extract_chunk(self, video_path: str, start_frame: int, end_frame: int, fps: float) -> str:
        chunk_path = str(TEMP_DIR / f"chunk_{start_frame}_{end_frame}_{random.randint(1000,9999)}.mp4")
        start_time = start_frame / fps
        duration = max(0.001, (end_frame - start_frame) / fps)
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-ss", f"{start_time:.6f}", "-i", video_path,
            "-t", f"{duration:.6f}",
            "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-c:v", "libx264", "-preset", "veryfast", "-crf", "20",
            "-an", chunk_path
        ]
        subprocess.run(cmd, check=True)
        return chunk_path

    def _merge_chunks(self, chunk_paths: List[str], fps: float, width: int, height: int) -> str:
        if not chunk_paths:
            raise ValueError("No chunks to merge")
        if len(chunk_paths) == 1:
            return chunk_paths[0]
        concat_file = TEMP_DIR / f"concat_{random.randint(1000,9999)}.txt"
        with open(concat_file, "w") as f:
            for c in chunk_paths:
                f.write(f"file '{c}'\n")
        out = TEMP_DIR / f"merged_{random.randint(1000,9999)}.mp4"
        cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
               "-f", "concat", "-safe", "0", "-i", str(concat_file),
               "-c", "copy", str(out)]
        subprocess.run(cmd, check=True)
        return str(out)

    def process_video_chunks(self, video_path: str, processor_func, **kwargs) -> str:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        processed: List[str] = []
        for start in range(0, total, self.chunk_size):
            end = min(start + self.chunk_size, total)
            with memory_manager.mem_context(f"chunk {start}-{end}"):
                ch = self._extract_chunk(video_path, start, end, fps)
                memory_manager.register_temp_file(ch)
                out = processor_func(ch, **kwargs)
                memory_manager.register_temp_file(out)
                processed.append(out)
        return self._merge_chunks(processed, fps, width, height)
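
# Usage sketch (illustrative): run any per-chunk function over a long video and
# get back a single concatenated result.
#   chunker = ChunkedVideoProcessor(chunk_size_frames=60)
#   merged = chunker.process_video_chunks("in.mp4", lambda p, **_: p)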
1134
-
1135
- # =============================================================================
1136
- # CHAPTER 10: MAIN PIPELINE - Enhanced positioning (unchanged logic)
1137
- # =============================================================================
1138
- def process_video_main(
1139
- video_path: str,
1140
- background_path: Optional[str] = None,
1141
- trim_duration: Optional[float] = None,
1142
- crf: int = None,
1143
- preserve_audio_flag: bool = True,
1144
- placement: Optional[dict] = None,
1145
- use_chunked_processing: bool = False,
1146
- progress: gr.Progress = gr.Progress(track_tqdm=True),
1147
- ) -> Tuple[Optional[str], str]:
1148
-
1149
- messages: List[str] = []
1150
- crf = crf or FFMPEG_DEFAULT_CRF
1151
-
1152
- with memory_manager.mem_context("Pipeline"):
1153
- try:
1154
- progress(0, desc="Initializing models")
1155
- sam2 = SAM2Handler()
1156
- matanyone = MatAnyoneHandler()
1157
-
1158
- if not sam2.initialize():
1159
- return None, f"SAM2 init failed: {state.sam2_error}"
1160
- if not matanyone.initialize():
1161
- return None, f"MatAnyone init failed: {state.matanyone_error}"
1162
- messages.append("✅ SAM2 & MatAnyone initialized")
1163
-
1164
- progress(0.1, desc="Preparing video")
1165
- input_video = video_path
1166
-
1167
- # Optional trim
1168
- if trim_duration and float(trim_duration) > 0:
1169
- trimmed = TEMP_DIR / f"trimmed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
1170
- memory_manager.register_temp_file(str(trimmed))
1171
- with VideoFileClip(video_path) as clip:
1172
- d = min(float(trim_duration), float(clip.duration or trim_duration))
1173
- sub = clip.subclip(0, d)
1174
- write_video_h264(sub, str(trimmed), crf=int(crf))
1175
- sub.close()
1176
- input_video = str(trimmed)
1177
- messages.append(f"✂️ Trimmed to {d:.1f}s")
1178
- else:
1179
- with VideoFileClip(video_path) as clip:
1180
- messages.append(f"🎞️ Full video: {clip.duration:.1f}s")
1181
-
1182
- progress(0.2, desc="Creating SAM2 mask")
1183
- cap = cv2.VideoCapture(input_video)
1184
- ret, first_frame = cap.read()
1185
- cap.release()
1186
- if not ret or first_frame is None:
1187
- return None, "Could not read video"
1188
- h, w = first_frame.shape[:2]
1189
- rgb0 = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
1190
- mask = sam2.create_mask(rgb0)
1191
- if mask is None:
1192
- return None, "SAM2 mask failed"
1193
-
1194
- mask_path = TEMP_DIR / f"mask_{int(time.time())}_{random.randint(1000,9999)}.png"
1195
- memory_manager.register_temp_file(str(mask_path))
1196
- cv2.imwrite(str(mask_path), mask)
1197
- messages.append("✅ Person mask created")
1198
-
1199
- progress(0.35, desc="Matting video")
1200
- if use_chunked_processing:
1201
- chunker = ChunkedVideoProcessor(chunk_size_frames=CHUNK_SIZE_FRAMES)
1202
- alpha_video = chunker.process_video_chunks(
1203
- input_video,
1204
- lambda chunk_path, **_k: matanyone.process_video(
1205
- input_path=chunk_path,
1206
- mask_path=str(mask_path),
1207
- output_path=str(TEMP_DIR / f"matanyone_chunk_{int(time.time())}_{random.randint(1000,9999)}")
1208
- )
1209
- )
1210
- memory_manager.register_temp_file(alpha_video)
1211
- else:
1212
- out_dir = TEMP_DIR / f"matanyone_out_{int(time.time())}_{random.randint(1000,9999)}"
1213
- out_dir.mkdir(parents=True, exist_ok=True)
1214
- memory_manager.register_temp_file(str(out_dir))
1215
- alpha_video = matanyone.process_video(
1216
- input_path=input_video,
1217
- mask_path=str(mask_path),
1218
- output_path=str(out_dir)
1219
- )
1220
-
1221
- if not alpha_video or not os.path.exists(alpha_video):
1222
- return None, "MatAnyone did not produce alpha video"
1223
- messages.append("✅ Alpha video generated")
1224
-
1225
- progress(0.55, desc="Preparing background")
1226
- original_clip = VideoFileClip(input_video)
1227
- alpha_clip = VideoFileClip(alpha_video)
1228
-
1229
- if background_path and os.path.exists(background_path):
1230
- messages.append("🖼️ Using background file")
1231
- bg_bgr = cv2.imread(background_path)
1232
- bg_bgr = cv2.resize(bg_bgr, (w, h))
1233
- bg_rgb = cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
1234
- else:
1235
- messages.append("🖼️ Using gradient background")
1236
- grad = _make_vertical_gradient(w, h, (200, 205, 215), (160, 170, 190))
1237
- bg_rgb = cv2.cvtColor(grad, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
1238
-
1239
- # FIXED: Enhanced placement parameters with validation and debugging
1240
- placement = placement or {}
1241
- px = max(0.0, min(1.0, float(placement.get("x", 0.5))))
1242
- py = max(0.0, min(1.0, float(placement.get("y", 0.75))))
1243
- ps = max(0.3, min(2.0, float(placement.get("scale", 1.0))))
1244
- feather_px = max(0, min(50, int(placement.get("feather", 3))))
1245
-
1246
- # Debug logging for placement parameters
1247
- logger.info(f"POSITIONING DEBUG: px={px:.3f}, py={py:.3f}, ps={ps:.3f}, feather={feather_px}")
1248
- logger.info(f"VIDEO DIMENSIONS: {w}x{h}")
1249
- logger.info(f"TARGET CENTER: ({int(px * w)}, {int(py * h)})")
1250
-
1251
- frame_count = 0
1252
- def composite_frame(get_frame, t):
1253
- nonlocal frame_count
1254
- frame_count += 1
1255
-
1256
- # Get original frame
1257
- frame = get_frame(t).astype(np.float32) / 255.0
1258
- hh, ww = frame.shape[:2]
1259
-
1260
- # FIXED: Better alpha temporal synchronization
1261
- alpha_duration = getattr(alpha_clip, 'duration', None)
1262
- if alpha_duration and alpha_duration > 0:
1263
- # Ensure we don't go beyond alpha video duration
1264
- alpha_t = min(t, alpha_duration - 0.01)
1265
- alpha_t = max(0.0, alpha_t)
1266
- else:
1267
- alpha_t = 0.0
1268
-
1269
- try:
1270
- a = alpha_clip.get_frame(alpha_t)
1271
- # Handle multi-channel alpha
1272
- if a.ndim == 3:
1273
- a = a[:, :, 0]
1274
- a = a.astype(np.float32) / 255.0
1275
-
1276
- # FIXED: Ensure alpha matches frame dimensions exactly
1277
- if a.shape != (hh, ww):
1278
- logger.warning(f"Alpha size mismatch: {a.shape} vs {(hh, ww)}, resizing...")
1279
- a = cv2.resize(a, (ww, hh), interpolation=cv2.INTER_LINEAR)
1280
-
1281
- except Exception as e:
1282
- logger.error(f"Alpha frame error at t={t:.3f}: {e}")
1283
- return (bg_rgb * 255).astype(np.uint8)
1284
-
1285
- # FIXED: Calculate scaled dimensions with better rounding
1286
- sw = max(1, round(ww * ps)) # Use round instead of int for better precision
1287
- sh = max(1, round(hh * ps))
1288
-
1289
- # FIXED: Scale both frame and alpha consistently
1290
- try:
1291
- fg_scaled = cv2.resize(frame, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR)
1292
- a_scaled = cv2.resize(a, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR)
1293
- except Exception as e:
1294
- logger.error(f"Scaling error: {e}")
1295
- return (bg_rgb * 255).astype(np.uint8)
1296
-
1297
- # Create canvases
1298
- fg_canvas = np.zeros_like(frame, dtype=np.float32)
1299
- a_canvas = np.zeros((hh, ww), dtype=np.float32)
1300
-
1301
- # FIXED: More precise center calculations
1302
- cx = round(px * ww)
1303
- cy = round(py * hh)
1304
-
1305
- # FIXED: Use floor division for consistent centering
1306
- x0 = cx - sw // 2
1307
- y0 = cy - sh // 2
1308
-
1309
- # Debug logging for first few frames
1310
- if frame_count <= 3:
1311
- logger.info(f"FRAME {frame_count}: scaled_size=({sw}, {sh}), center=({cx}, {cy}), top_left=({x0}, {y0})")
1312
-
1313
- # FIXED: Robust bounds checking with edge case handling
1314
- xs0 = max(0, x0)
1315
- ys0 = max(0, y0)
1316
- xs1 = min(ww, x0 + sw)
1317
- ys1 = min(hh, y0 + sh)
1318
-
1319
- # Check for valid placement region
1320
- if xs1 <= xs0 or ys1 <= ys0:
1321
- if frame_count <= 3:
1322
- logger.warning(f"Subject outside bounds: dest=({xs0},{ys0})-({xs1},{ys1})")
1323
- return (bg_rgb * 255).astype(np.uint8)
1324
-
1325
- # FIXED: Calculate source region with bounds validation
1326
- src_x0 = xs0 - x0 # Will be 0 if x0 >= 0, positive if x0 < 0
1327
- src_y0 = ys0 - y0 # Will be 0 if y0 >= 0, positive if y0 < 0
1328
- src_x1 = src_x0 + (xs1 - xs0)
1329
- src_y1 = src_y0 + (ys1 - ys0)
1330
-
1331
- # Validate source bounds
1332
- if (src_x1 > sw or src_y1 > sh or src_x0 < 0 or src_y0 < 0 or
1333
- src_x1 <= src_x0 or src_y1 <= src_y0):
1334
- if frame_count <= 3:
1335
- logger.error(f"Invalid source region: ({src_x0},{src_y0})-({src_x1},{src_y1}) for {sw}x{sh} scaled")
1336
- return (bg_rgb * 255).astype(np.uint8)
1337
-
1338
- # FIXED: Safe canvas placement with error handling
1339
- try:
1340
- fg_canvas[ys0:ys1, xs0:xs1, :] = fg_scaled[src_y0:src_y1, src_x0:src_x1, :]
1341
- a_canvas[ys0:ys1, xs0:xs1] = a_scaled[src_y0:src_y1, src_x0:src_x1]
1342
- except Exception as e:
1343
- logger.error(f"Canvas placement failed: {e}")
1344
- logger.error(f"Dest: [{ys0}:{ys1}, {xs0}:{xs1}], Src: [{src_y0}:{src_y1}, {src_x0}:{src_x1}]")
1345
- return (bg_rgb * 255).astype(np.uint8)
1346
-
1347
- # FIXED: Apply feathering with bounds checking
1348
- if feather_px > 0:
1349
- kernel_size = max(3, feather_px * 2 + 1)
1350
- if kernel_size % 2 == 0:
1351
- kernel_size += 1 # Ensure odd kernel size
1352
- try:
1353
- a_canvas = cv2.GaussianBlur(a_canvas, (kernel_size, kernel_size), feather_px / 3.0)
1354
- except Exception as e:
1355
- logger.warning(f"Feathering failed: {e}")
1356
-
1357
- # FIXED: Composite with proper alpha handling
1358
- a3 = np.expand_dims(a_canvas, axis=2) # More explicit than [:, :, None]
1359
- comp = a3 * fg_canvas + (1.0 - a3) * bg_rgb
1360
- result = np.clip(comp * 255, 0, 255).astype(np.uint8)
1361
-
1362
- return result
1363
-
1364
- progress(0.7, desc="Compositing")
1365
- final_clip = original_clip.fl(composite_frame)
1366
-
1367
- output_path = OUT_DIR / f"processed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
1368
- temp_video_path = TEMP_DIR / f"temp_video_{int(time.time())}_{random.randint(1000,9999)}.mp4"
1369
- memory_manager.register_temp_file(str(temp_video_path))
1370
-
1371
- write_video_h264(final_clip, str(temp_video_path), crf=int(crf))
1372
- original_clip.close(); alpha_clip.close(); final_clip.close()
1373
-
1374
- progress(0.85, desc="Merging audio")
1375
- if preserve_audio_flag:
1376
- success = run_ffmpeg([
1377
- "-i", str(temp_video_path),
1378
- "-i", video_path,
1379
- "-map", "0:v:0",
1380
- "-map", "1:a:0?",
1381
- "-c:v", "copy",
1382
- "-c:a", FFMPEG_AUDIO_CODEC,
1383
- "-b:a", FFMPEG_AUDIO_BITRATE,
1384
- "-shortest",
1385
- str(output_path)
1386
- ], fail_ok=True)
1387
- if success:
1388
- messages.append("🔊 Original audio preserved")
1389
- else:
1390
- shutil.copy2(str(temp_video_path), str(output_path))
1391
- messages.append("⚠️ Audio merge failed, saved w/o audio")
1392
- else:
1393
- shutil.copy2(str(temp_video_path), str(output_path))
1394
- messages.append("🔇 Saved without audio")
1395
-
-         messages.append("✅ Done")
-         stats = memory_manager.get_memory_stats()
-         messages.append(f"📊 CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
-         messages.append(f"🎯 Processed {frame_count} frames with placement ({px:.2f}, {py:.2f}) @ {ps:.2f}x scale")
-         progress(1.0, desc="Done")
-         return str(output_path), "\n".join(messages)
-
-     except Exception as e:
-         err = f"Processing failed: {str(e)}\n\n{traceback.format_exc()}"
-         return None, err
-
- # =============================================================================
- # CHAPTER 11: GRADIO UI - Dynamic titles and configurations
- # =============================================================================
- def create_interface():
-     def diag():
-         return state.status_text()
-
-     def cleanup():
-         memory_manager.aggressive_cleanup()
-         s = memory_manager.get_memory_stats()
-         return f"🧹 Cleanup\nCPU: {s.cpu_memory_mb:.1f}MB\nGPU: {s.gpu_memory_mb:.1f}MB\nTemp: {s.temp_files_count} files"
-
-     def preload(ai_model, openai_key, force_gpu, progress=gr.Progress()):
-         try:
-             progress(0, desc="Preloading...")
-             msg = ""
-             if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"):
-                 try:
-                     if ai_model == "SDXL":
-                         _ = generate_sdxl_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
-                     elif ai_model == "Playground v2.5":
-                         _ = generate_playground_v25_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
-                     else:
-                         _ = generate_sd15_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
-                     msg += f"{ai_model} preloaded.\n"
-                 except Exception as e:
-                     msg += f"{ai_model} preload failed: {e}\n"
-
-             _reset_hydra()
-             s, m = SAM2Handler(), MatAnyoneHandler()
-             ok_s = s.initialize()
-             _reset_hydra()
-             ok_m = m.initialize()
-             progress(1.0, desc="Preload complete")
-             return f"✅ Preload\n{msg}SAM2: {'ready' if ok_s else 'failed'}\nMatAnyone: {'ready' if ok_m else 'failed'}"
-         except Exception as e:
-             return f"❌ Preload error: {e}"
-
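The preload path warms each diffusion backend with a throwaway 64x64, two-step generation, so weight downloads and kernel compilation happen before the first real request rather than during it. The generic shape of that trick, assuming a diffusers-style pipeline object (warm_up is illustrative, not from app.py):

    def warm_up(pipe):
        # Tiny, low-step render: the output is discarded; the side effects
        # (weights resident, kernels compiled) are the point.
        pipe("plain", width=64, height=64, num_inference_steps=2)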
-     def generate_background_safe(video_file, ai_prompt, ai_steps, ai_guidance, ai_seed,
-                                  ai_model, openai_key, force_gpu, progress=gr.Progress()):
-         if not video_file:
-             return None, "Upload a video first", gr.update(visible=False), None
-         with memory_manager.mem_context("Background generation"):
-             try:
-                 video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)
-                 if not os.path.exists(video_path):
-                     return None, "Video not found", gr.update(visible=False), None
-                 cap = cv2.VideoCapture(video_path)
-                 if not cap.isOpened():
-                     return None, "Could not open video", gr.update(visible=False), None
-                 ret, frame = cap.read()
-                 cap.release()
-                 if not ret or frame is None:
-                     return None, "Could not read frame", gr.update(visible=False), None
-                 h, w = int(frame.shape[0]), int(frame.shape[1])
-
-                 steps = max(1, min(50, int(ai_steps or 30)))
-                 guidance = max(1.0, min(15.0, float(ai_guidance or 7.0)))
-                 try:
-                     seed_val = int(ai_seed) if ai_seed and str(ai_seed).strip() else None
-                 except Exception:
-                     seed_val = None
-
-                 progress(0.1, desc=f"Generating {ai_model}")
-                 bg_path = generate_ai_background_router(
-                     width=w, height=h, prompt=str(ai_prompt or "professional office background").strip(),
-                     model=str(ai_model or "SDXL"), steps=steps, guidance=guidance,
-                     seed=seed_val, openai_key=openai_key, require_gpu=bool(force_gpu)
-                 )
-                 progress(1.0, desc="Background ready")
-                 if bg_path and os.path.exists(bg_path):
-                     return bg_path, f"AI background generated with {ai_model}", gr.update(visible=True), bg_path
-                 else:
-                     return None, "No output file", gr.update(visible=False), None
-             except Exception as e:
-                 logger.error(f"Background generation error: {e}")
-                 return None, f"Background generation failed: {str(e)}", gr.update(visible=False), None
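generate_background_safe sizes the background to the video by decoding only the first frame, then clamps user inputs into safe ranges before they reach the generator. The probe in isolation ("input.mp4" is a placeholder path):

    import cv2

    cap = cv2.VideoCapture("input.mp4")
    ok, frame = cap.read()        # decode just the first frame
    cap.release()
    if ok and frame is not None:
        h, w = frame.shape[:2]    # background is generated at the video's WxH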
-
-     def approve_background(bg_path):
-         try:
-             if not bg_path or not (isinstance(bg_path, str) and os.path.exists(bg_path)):
-                 return None, "Generate a background first", gr.update(visible=False)
-             ext = os.path.splitext(bg_path)[1].lower() or ".jpg"
-             safe_name = f"approved_{int(time.time())}_{random.randint(1000,9999)}{ext}"
-             dest = BACKGROUND_DIR / safe_name
-             shutil.copy2(bg_path, dest)
-             return str(dest), f"✅ Background approved → {dest.name}", gr.update(visible=False)
-         except Exception as e:
-             return None, f"⚠️ Approve failed: {e}", gr.update(visible=False)
-
- css = """
1498
- .gradio-container { font-size: 16px !important; }
1499
- label { font-size: 18px !important; font-weight: 600 !important; color: #2d3748 !important; }
1500
- .process-button { font-size: 20px !important; font-weight: 700 !important; padding: 16px 28px !important; }
1501
- .memory-info { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; }
1502
- """
1503
-
1504
-     with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft(), css=css) as interface:
-         gr.Markdown(f"# 🎬 {APP_TITLE}")
-         gr.Markdown("_SAM2 + MatAnyone + AI Backgrounds — with strict tensor shapes & memory management_")
-
-         gr.HTML(f"""
-         <div class='memory-info'>
-             <strong>Device:</strong> {DEVICE} &nbsp;&nbsp;
-             <strong>PyTorch:</strong> {'✅' if TORCH_AVAILABLE else '❌'} &nbsp;&nbsp;
-             <strong>CUDA:</strong> {'✅' if CUDA_AVAILABLE else '❌'} &nbsp;&nbsp;
-             <strong>Host:</strong> {SERVER_HOST}:{SERVER_PORT}
-         </div>
-         """)
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 video_input = gr.Video(label="Input Video")
-
-                 gr.Markdown("### Background")
-                 bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"],
-                                      value="AI Generated", label="Background Method")
-
-                 # Upload group (hidden by default)
-                 with gr.Group(visible=False) as upload_group:
-                     upload_img = gr.Image(label="Background Image", type="filepath")
-
-                 # Gradient group (hidden by default)
-                 with gr.Group(visible=False) as gradient_group:
-                     gradient_choice = gr.Dropdown(label="Gradient Style",
-                                                   choices=list(GRADIENT_PRESETS.keys()),
-                                                   value=list(GRADIENT_PRESETS.keys())[0])
-
-                 # AI group (visible by default)
-                 with gr.Group(visible=True) as ai_group:
-                     prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration",
-                                                      choices=AI_PROMPT_SUGGESTIONS,
-                                                      value=AI_PROMPT_SUGGESTIONS[0])
-                     ai_prompt = gr.Textbox(label="Background Description",
-                                            value="professional office background", lines=3)
-                     ai_model = gr.Radio(["SDXL", "Playground v2.5", "SD 1.5 (fallback)", "OpenAI (gpt-image-1)"],
-                                         value="SDXL", label="AI Model")
-                     with gr.Accordion("Connect services (optional)", open=False):
-                         openai_api_key = gr.Textbox(label="OpenAI API Key", type="password",
-                                                     placeholder="sk-... (kept only in this session)")
-                     with gr.Row():
-                         ai_steps = gr.Slider(10, 50, value=30, step=1, label="Quality (steps)")
-                         ai_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.1, label="Guidance")
-                     ai_seed = gr.Number(label="Seed (optional)", precision=0)
-                     force_gpu_ai = gr.Checkbox(value=True, label="Force GPU for AI background")
-                     preload_btn = gr.Button("📦 Preload Models")
-                     preload_status = gr.Textbox(label="Preload Status", lines=4)
-                     generate_bg_btn = gr.Button("Generate AI Background", variant="primary")
-                     ai_generated_bg = gr.Image(label="Generated Background", type="filepath")
-                     approve_bg_btn = gr.Button("✅ Approve Background", visible=False)
-                     approved_background_path = gr.State(value=None)
-                     last_generated_bg = gr.State(value=None)
-                     ai_status = gr.Textbox(label="Generation Status", lines=2)
-
-                 gr.Markdown("### Processing")
-                 with gr.Row():
-                     trim_enabled = gr.Checkbox(label="Trim Video", value=False)
-                     trim_seconds = gr.Number(label="Trim Duration (seconds)", value=5, precision=1)
-                 with gr.Row():
-                     crf_value = gr.Slider(0, 30, value=FFMPEG_DEFAULT_CRF, step=1, label="Quality (CRF - lower=better)")
-                     audio_enabled = gr.Checkbox(label="Preserve Audio", value=True)
-                 with gr.Row():
-                     use_chunked = gr.Checkbox(label="Use Chunked Processing", value=False)
-
-                 gr.Markdown("### Subject Placement")
-                 with gr.Row():
-                     place_x = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Horizontal")
-                     place_y = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Vertical")
-                 with gr.Row():
-                     place_scale = gr.Slider(0.3, 2.0, value=1.0, step=0.01, label="Scale")
-                     place_feather = gr.Slider(0, 15, value=3, step=1, label="Edge feather (px)")
-
-                 process_btn = gr.Button("🚀 Process Video", variant="primary", elem_classes=["process-button"])
-
-                 gr.Markdown("### System")
-                 with gr.Row():
-                     diagnostics_btn = gr.Button("📊 System Diagnostics")
-                     cleanup_btn = gr.Button("🧹 Memory Cleanup")
-                 diagnostics_output = gr.Textbox(label="System Status", lines=10)
-
-             with gr.Column(scale=1):
-                 output_video = gr.Video(label="Processed Video")
-                 download_file = gr.File(label="Download Processed Video")
-                 status_output = gr.Textbox(label="Processing Status", lines=20)
-
-         # --- Wiring ---
-         def update_background_visibility(method):
-             return (
-                 gr.update(visible=(method == "Upload Image")),
-                 gr.update(visible=(method == "Gradients")),
-                 gr.update(visible=(method == "AI Generated")),
-             )
-
-         def update_prompt_from_suggestion(suggestion):
-             if suggestion == AI_PROMPT_SUGGESTIONS[0]:  # "Custom (write your own)"
-                 return gr.update(value="", placeholder="Describe the background you want...")
-             return gr.update(value=suggestion)
-
-         bg_method.change(
-             update_background_visibility,
-             inputs=[bg_method],
-             outputs=[upload_group, gradient_group, ai_group]
-         )
-         prompt_suggestions.change(update_prompt_from_suggestion, inputs=[prompt_suggestions], outputs=[ai_prompt])
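The wiring above is the usual Gradio show/hide pattern: the radio's change event returns one gr.update(visible=...) per output group, in the same order as the outputs list. A minimal runnable version of the same pattern (standalone sketch):

    import gradio as gr

    with gr.Blocks() as demo:
        choice = gr.Radio(["A", "B"], value="A")
        with gr.Group(visible=True) as group_a:
            gr.Textbox(label="Only for A")
        with gr.Group(visible=False) as group_b:
            gr.Textbox(label="Only for B")
        choice.change(
            lambda c: (gr.update(visible=c == "A"), gr.update(visible=c == "B")),
            inputs=choice,
            outputs=[group_a, group_b],
        )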
-
-         preload_btn.click(
-             preload,
-             inputs=[ai_model, openai_api_key, force_gpu_ai],
-             outputs=[preload_status],
-             show_progress=True
-         )
-
-         generate_bg_btn.click(
-             generate_background_safe,
-             inputs=[video_input, ai_prompt, ai_steps, ai_guidance, ai_seed, ai_model, openai_api_key, force_gpu_ai],
-             outputs=[ai_generated_bg, ai_status, approve_bg_btn, last_generated_bg],
-             show_progress=True
-         )
-         approve_bg_btn.click(
-             approve_background,
-             inputs=[ai_generated_bg],
-             outputs=[approved_background_path, ai_status, approve_bg_btn]
-         )
-
-         diagnostics_btn.click(diag, outputs=[diagnostics_output])
-         cleanup_btn.click(cleanup, outputs=[diagnostics_output])
-
-         def process_video(
-             video_file,
-             bg_method,
-             upload_img,
-             gradient_choice,
-             approved_background_path,
-             last_generated_bg,
-             trim_enabled, trim_seconds, crf_value, audio_enabled,
-             use_chunked,
-             place_x, place_y, place_scale, place_feather,
-             progress=gr.Progress(track_tqdm=True),
-         ):
-             try:
-                 if not video_file:
-                     return None, None, "Please upload a video file"
-                 video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)
-
-                 # Resolve background
-                 bg_path = None
-                 try:
-                     if bg_method == "Upload Image" and upload_img:
-                         bg_path = upload_img if isinstance(upload_img, str) else getattr(upload_img, "name", None)
-                     elif bg_method == "Gradients":
-                         cap = cv2.VideoCapture(video_path)
-                         ret, frame = cap.read()
-                         cap.release()
-                         if ret and frame is not None:
-                             h, w = frame.shape[:2]
-                             if gradient_choice in GRADIENT_PRESETS:
-                                 grad = _make_vertical_gradient(w, h, *GRADIENT_PRESETS[gradient_choice])
-                                 tmp_bg = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, dir=TEMP_DIR).name
-                                 cv2.imwrite(tmp_bg, grad)
-                                 memory_manager.register_temp_file(tmp_bg)
-                                 bg_path = tmp_bg
-                     else:  # AI Generated
-                         if approved_background_path:
-                             bg_path = approved_background_path
-                         elif last_generated_bg and isinstance(last_generated_bg, str) and os.path.exists(last_generated_bg):
-                             bg_path = last_generated_bg
-                 except Exception as e:
-                     logger.error(f"Background setup error: {e}")
-                     return None, None, f"Background setup failed: {str(e)}"
-
-                 result_path, status = process_video_main(
-                     video_path=video_path,
-                     background_path=bg_path,
-                     trim_duration=float(trim_seconds) if (trim_enabled and float(trim_seconds) > 0) else None,
-                     crf=int(crf_value),
-                     preserve_audio_flag=bool(audio_enabled),
-                     placement=dict(x=float(place_x), y=float(place_y), scale=float(place_scale), feather=int(place_feather)),
-                     use_chunked_processing=bool(use_chunked),
-                     progress=progress,
-                 )
-
-                 if result_path and os.path.exists(result_path):
-                     return result_path, result_path, f"✅ Success\n\n{status}"
-                 else:
-                     return None, None, f"❌ Failed\n\n{status or 'Unknown error'}"
-             except Exception as e:
-                 tb = traceback.format_exc()
-                 return None, None, f"❌ Crash: {e}\n\n{tb}"
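For the "Gradients" branch, _make_vertical_gradient (defined earlier in app.py) renders a top-to-bottom color ramp at the video's resolution. A generic sketch of that kind of helper (hypothetical signature; BGR color order to match cv2.imwrite):

    import numpy as np

    def vertical_gradient(w, h, top_bgr, bottom_bgr):
        t = np.linspace(0.0, 1.0, h, dtype=np.float32)[:, None, None]  # h x 1 x 1
        top = np.array(top_bgr, np.float32)
        bottom = np.array(bottom_bgr, np.float32)
        ramp = (1.0 - t) * top + t * bottom             # h x 1 x 3 by broadcasting
        return ramp.astype(np.uint8).repeat(w, axis=1)  # h x w x 3 image

    bg = vertical_gradient(640, 360, (64, 32, 16), (200, 160, 120))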
-
-         process_btn.click(
-             process_video,
-             inputs=[
-                 video_input,
-                 bg_method,
-                 upload_img,
-                 gradient_choice,
-                 approved_background_path, last_generated_bg,
-                 trim_enabled, trim_seconds, crf_value, audio_enabled,
-                 use_chunked,
-                 place_x, place_y, place_scale, place_feather,
-             ],
-             outputs=[output_video, download_file, status_output],
-             show_progress=True
-         )
-
-     return interface
-
- # =============================================================================
- # CHAPTER 12: MAIN - All dynamic configuration
- # =============================================================================
- def main():
-     logger.info(f"Starting {APP_TITLE}")
-     logger.info(f"Configuration: {SERVER_HOST}:{SERVER_PORT}, Share: {SERVER_SHARE}, InBrowser: {SERVER_INBROWSER}")
-     stats = memory_manager.get_memory_stats()
-     logger.info(f"Initial memory: CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
-
-     interface = create_interface()
-     interface.queue(max_size=GRADIO_QUEUE_MAX_SIZE)
-
-     try:
-         interface.launch(
-             server_name=SERVER_HOST,
-             server_port=SERVER_PORT,
-             share=SERVER_SHARE,
-             inbrowser=SERVER_INBROWSER,
-             show_error=True
-         )
-     finally:
-         logger.info("Shutting down - cleanup")
-         memory_manager.cleanup_temp_files()
-         memory_manager.aggressive_cleanup()
-
- if __name__ == "__main__":
-     main()
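Wrapping launch() in try/finally guarantees that temp files and GPU memory are released even when the server exits via Ctrl-C or an error inside Gradio. The same shape in miniature (standalone sketch; host/port values are illustrative):

    import gradio as gr

    def build():
        with gr.Blocks() as demo:
            gr.Markdown("hello")
        return demo

    if __name__ == "__main__":
        app = build()
        try:
            app.launch(server_name="0.0.0.0", server_port=7860)
        finally:
            # release temp files / CUDA cache here
            print("shutting down")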