MogensR committed on
Commit 22a6aa0 · verified · 1 Parent(s): 9661d53

Update app.py

Files changed (1)
  1. app.py +192 -425
app.py CHANGED
@@ -1,254 +1,239 @@
1
  #!/usr/bin/env python3
2
  # ========================= PRE-IMPORT ENV GUARDS =========================
3
  import os
4
- # Remove invalid OMP setting or tame thread counts BEFORE importing numpy/cv2/torch
5
- os.environ.pop("OMP_NUM_THREADS", None) # or set "1"
6
  os.environ.setdefault("MKL_NUM_THREADS", "1")
7
  os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
8
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
9
- # Optional CUDA allocator tuning
10
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:1024")
11
  os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
12
  # ========================================================================
13
 
14
  """
15
- High-Quality Video Background Replacement - MAIN APPLICATION
16
- Upload video Choose professional background Replace with cinema quality
17
- Features: SAM2 + MatAnyone with multi-fallback loading, professional backgrounds,
18
- cinema-quality processing, lazy loading, and enhanced stability
19
  """
20
 
21
  import sys
22
- import tempfile
23
  import cv2
24
  import numpy as np
25
  from pathlib import Path
26
- import gradio as gr
27
  import torch
28
- import requests
29
- from PIL import Image, ImageDraw, ImageFilter, ImageEnhance
30
- import json
31
  import traceback
32
  import time
33
  import shutil
34
  import gc
35
  import threading
36
- import queue
37
- from typing import Optional, Tuple, Dict, Any
38
  import logging
39
- import warnings
40
 
41
- # Import your utilities
42
- from utilities import * # must provide required helpers & PROFESSIONAL_BACKGROUNDS
43
 
44
- warnings.filterwarnings("ignore")
45
  logging.basicConfig(level=logging.INFO)
46
  logger = logging.getLogger(__name__)
47
 
48
  # ============================================================================ #
49
- # GRADIO MONKEY PATCH (BUG FIX for gradio>=4.44.0)
50
  # ============================================================================ #
51
- try:
52
- import gradio_client.utils as gc_utils
53
- original_get_type = gc_utils.get_type
54
- def patched_get_type(schema):
55
- if not isinstance(schema, dict):
56
- if isinstance(schema, bool):
57
- return "boolean"
58
- return "string"
59
- return original_get_type(schema)
60
- gc_utils.get_type = patched_get_type
61
- logger.info("Applied Gradio schema validation monkey patch.")
62
- except (ImportError, AttributeError) as e:
63
- logger.warning(f"Could not apply Gradio monkey patch: {e}")
64
 
65
- # ============================================================================ #
66
- # SAM2 LOADER (Hydra search path; pass STRING config name to build_sam2)
67
- # ============================================================================ #
68
- def load_sam2_predictor(device: str = "cuda", progress: Optional[gr.Progress] = None):
69
- """Loads SAM2 and returns SAM2ImagePredictor. Uses STRING config name for build_sam2."""
70
- import hydra
71
-
72
- sam_logger = logging.getLogger("SAM2Loader")
73
- configs_dir = os.path.abspath("Configs")
74
- sam_logger.info(f"Looking for SAM2 configs in absolute path: {configs_dir}")
75
-
76
- if not os.path.isdir(configs_dir):
77
- raise gr.Error(f"FATAL: SAM2 Configs directory not found at '{configs_dir}'")
78
-
79
- def _maybe_progress(pct: float, desc: str):
80
- if progress is not None:
81
- try: progress(pct, desc=desc)
82
- except Exception: pass
83
-
84
- def try_load(config_name_with_yaml: str, checkpoint_name: str):
85
- try:
86
- checkpoint_path = os.path.join("./checkpoints", checkpoint_name)
87
- sam_logger.info(f"Attempting to use checkpoint: {checkpoint_path}")
88
-
89
- if not os.path.exists(checkpoint_path):
90
- sam_logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
91
- _maybe_progress(0.1, f"Downloading {checkpoint_name}...")
92
- from huggingface_hub import hf_hub_download
93
- repo = f"facebook/{config_name_with_yaml.replace('.yaml','')}"
94
- checkpoint_path = hf_hub_download(
95
- repo_id=repo,
96
- filename=checkpoint_name,
97
- cache_dir="./checkpoints",
98
- local_dir_use_symlinks=False
99
- )
100
- sam_logger.info(f"Download complete: {checkpoint_path}")
101
-
102
- # Reset & init Hydra so its repo includes ./Configs
103
- if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
104
- hydra.core.global_hydra.GlobalHydra.instance().clear()
105
- hydra.initialize(
106
- version_base=None,
107
- config_path=os.path.relpath(configs_dir),
108
- job_name=f"sam2_load_{int(time.time())}"
109
- )
110
 
111
- # Pass STRING config name to build_sam2
112
- config_name = config_name_with_yaml.replace(".yaml", "")
113
-
114
- from sam2.build_sam import build_sam2
115
- from sam2.sam2_image_predictor import SAM2ImagePredictor
116
-
117
- sam_logger.info(f"Trying to load {config_name_with_yaml} on {device} with checkpoint {checkpoint_path}")
118
- _maybe_progress(0.3, f"Loading {config_name_with_yaml}...")
119
-
120
- sam2_model = build_sam2(config_name, checkpoint_path)
121
- sam2_model.to(device)
122
- predictor = SAM2ImagePredictor(sam2_model)
123
- sam_logger.info(f"Loaded {config_name_with_yaml} successfully on {device}")
124
- return predictor
 
125
 
126
- except Exception as e:
127
- err = f"Failed to load {config_name_with_yaml}: {e}\nTraceback: {traceback.format_exc()}"
128
- sam_logger.warning(err)
 
129
  return None
130
-
131
- predictor = try_load("sam2_hiera_large.yaml", "sam2_hiera_large.pt")
132
- if predictor is None:
133
- raise gr.Error("SAM2 loading failed for large model. Check configs/checkpoint.")
134
- return predictor
 
135
 
136
  # ============================================================================ #
137
- # MatAnyone LOADER (simple Hugging Face approach)
138
  # ============================================================================ #
139
- def load_matanyone(device: str):
140
- """
141
- Load MatAnyone using the simple Hugging Face approach as documented
142
- """
143
- ma_logger = logging.getLogger("MatAnyoneLoader")
144
 
145
  try:
146
- # Use the official approach from MatAnyone documentation
147
  from matanyone import InferenceCore
148
- ma_logger.info("MatAnyone package found, creating InferenceCore...")
149
-
150
- # Use the Hugging Face model as specified in the documentation
151
  processor = InferenceCore("PeiqingYang/MatAnyone")
152
- ma_logger.info("MatAnyone loaded successfully via Hugging Face")
153
  return processor
154
 
155
- except ImportError as e:
156
- ma_logger.error(f"MatAnyone package not found: {e}")
157
- raise RuntimeError(f"MatAnyone package not installed. Please install with: pip install git+https://github.com/pq-yang/MatAnyone")
158
  except Exception as e:
159
- ma_logger.error(f"Failed to create MatAnyone InferenceCore: {e}")
160
- raise RuntimeError(f"MatAnyone initialization failed: {e}")
161
 
162
  # ============================================================================ #
163
- # GLOBALS & MODEL SETUP
164
  # ============================================================================ #
165
  sam2_predictor = None
166
  matanyone_model = None
167
  models_loaded = False
168
  loading_lock = threading.Lock()
169
 
170
- def download_and_setup_models(progress: Optional[gr.Progress] = None):
171
- """Download and setup models. BOTH SAM2 and MatAnyone are REQUIRED."""
172
  global sam2_predictor, matanyone_model, models_loaded
173
 
174
  with loading_lock:
175
  if models_loaded:
176
- return "SAM2 + MatAnyone already loaded"
177
 
178
  try:
179
- logger.info("Starting ENHANCED model loading...")
180
  device = "cuda" if torch.cuda.is_available() else "cpu"
181
-
182
- # --- Load SAM2 (required) ---
183
- local_sam2 = load_sam2_predictor(device=device, progress=progress)
184
- sam2_predictor = local_sam2
185
-
186
- # --- Load MatAnyone (required) ---
187
- local_matanyone = load_matanyone(device)
188
- matanyone_model = local_matanyone
189
 
190
  models_loaded = True
191
- logger.info("--- All models loaded successfully (SAM2 + MatAnyone) ---")
192
- return "SAM2 + MatAnyone loaded successfully!"
193
  except Exception as e:
194
- logger.error(f"Enhanced loading failed: {str(e)}")
195
- logger.error(f"Full traceback: {traceback.format_exc()}")
196
- return f"Enhanced loading failed: {str(e)}"
197
 
198
  # ============================================================================ #
199
- # TWO-STAGE PROCESSING PIPELINE (uses your utilities' segmentation/compositing)
200
  # ============================================================================ #
201
- def create_green_screen_background(frame):
202
- """Create a pure green screen background for the frame"""
203
- return np.full_like(frame, (0, 255, 0), dtype=np.uint8)
204
-
205
- def process_video_hq(video_path, background_choice, custom_background_path, progress: Optional[gr.Progress] = None):
206
- """SINGLE-STAGE High-quality video processing: Original → Final Background"""
207
  if not models_loaded:
208
- return None, "Models not loaded. Click 'Load Models' first."
209
  if not video_path:
210
  return None, "No video file provided."
211
 
212
  def _prog(pct: float, desc: str):
213
- if progress is not None:
214
- try: progress(pct, desc=desc)
215
- except Exception: pass
216
 
217
  try:
218
- _prog(0.0, "Initializing SINGLE-STAGE processing...")
219
 
220
  if not os.path.exists(video_path):
221
  return None, f"Video file not found: {video_path}"
222
 
223
  cap = cv2.VideoCapture(video_path)
224
  if not cap.isOpened():
225
- return None, "Could not open video file. Please check the format."
226
 
227
  fps = cap.get(cv2.CAP_PROP_FPS)
228
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
229
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
230
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
231
- logger.info(f"Video properties: {frame_width}x{frame_height}, {fps}fps, {total_frames} frames")
232
 
233
  if total_frames == 0:
234
- return None, "Video appears to be empty or corrupted."
235
 
236
- # Prepare final background
237
  background = None
238
  background_name = ""
239
 
240
  if background_choice == "custom" and custom_background_path:
241
  background = cv2.imread(custom_background_path)
242
  if background is None:
243
- return None, "Could not read custom background image. Please check the file format."
244
  background_name = "Custom Image"
245
- logger.info("Using custom background image")
246
  else:
247
  if background_choice in PROFESSIONAL_BACKGROUNDS:
248
  bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
249
  background = create_professional_background(bg_config, frame_width, frame_height)
250
  background_name = bg_config["name"]
251
- logger.info(f"Using professional background: {background_name}")
252
  else:
253
  return None, f"Invalid background selection: {background_choice}"
254
 
@@ -258,43 +243,44 @@ def process_video_hq(video_path, background_choice, custom_background_path, prog
258
  timestamp = int(time.time())
259
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
260
 
261
- # SINGLE-STAGE: Original → Final Background
262
- _prog(0.1, f"SINGLE-STAGE: Replacing background with {background_name}...")
263
- final_path = f"/tmp/final_output_{timestamp}.mp4"
264
  final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
 
265
  if not final_writer.isOpened():
266
  return None, "Could not create output video file."
267
 
268
  frame_count = 0
269
- keyframe_interval = 3 # Process MatAnyone every 3rd frame
270
  last_refined_mask = None
271
 
272
  while True:
273
  ret, frame = cap.read()
274
  if not ret:
275
  break
 
276
  try:
277
- _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8, f"Processing frame {frame_count + 1}/{total_frames}")
 
278
 
279
- # Always run SAM2 segmentation
280
  mask = segment_person_hq(frame, sam2_predictor)
281
 
282
- # Run MatAnyone refinement only on keyframes
283
  if (frame_count % keyframe_interval == 0) or (last_refined_mask is None):
284
  refined_mask = refine_mask_hq(frame, mask, matanyone_model)
285
  last_refined_mask = refined_mask.copy()
286
- logger.info(f"MatAnyone refinement on frame {frame_count}")
287
  else:
288
- # Use SAM2 mask directly for intermediate frames
289
  refined_mask = mask
290
 
291
- # Direct background replacement
292
  result_frame = replace_background_hq(frame, refined_mask, background)
293
  final_writer.write(result_frame)
294
 
295
  except Exception as e:
296
  logger.warning(f"Error processing frame {frame_count}: {e}")
297
  final_writer.write(frame)
 
298
  frame_count += 1
299
  if frame_count % 100 == 0:
300
  gc.collect()
@@ -305,311 +291,92 @@ def process_video_hq(video_path, background_choice, custom_background_path, prog
305
  cap.release()
306
 
307
  if frame_count == 0:
308
- return None, "No frames were processed successfully."
309
 
310
- _prog(0.9, "Adding high-quality audio...")
311
- final_output = f"/tmp/final_output_hq_{timestamp}.mp4"
 
312
  try:
313
  audio_cmd = (
314
  f'ffmpeg -y -i "{final_path}" -i "{video_path}" '
315
- f'-c:v libx264 -crf 18 -preset medium -profile:v high -level:v 4.0 '
316
  f'-c:a aac -b:a 192k -ac 2 -ar 48000 '
317
  f'-map 0:v:0 -map 1:a:0? -shortest "{final_output}"'
318
  )
319
  result = os.system(audio_cmd)
320
  if result != 0 or not os.path.exists(final_output):
321
- logger.warning("Audio merging failed, using video without audio")
322
  shutil.copy2(final_path, final_output)
323
  except Exception as e:
324
- logger.warning(f"Audio processing error: {e}, using video without audio")
325
- try: shutil.copy2(final_path, final_output)
326
- except Exception as e2:
327
- logger.error(f"Failed to copy video file: {e2}")
328
- return None, f"Failed to finalize video: {str(e2)}"
329
 
330
- # Save to MyAvatar/My Videos directory
331
  try:
332
  myavatar_path = "/tmp/MyAvatar/My_Videos/"
333
  os.makedirs(myavatar_path, exist_ok=True)
334
- saved_filename = f"single_stage_bg_replaced_{timestamp}.mp4"
335
  saved_path = os.path.join(myavatar_path, saved_filename)
336
  shutil.copy2(final_output, saved_path)
337
- logger.info(f"Video saved to: {saved_path}")
338
  except Exception as e:
339
- logger.warning(f"Could not save to MyAvatar directory: {e}")
340
  saved_filename = os.path.basename(final_output)
341
 
 
342
  try:
343
  if os.path.exists(final_path):
344
  os.remove(final_path)
345
- except Exception:
346
  pass
347
 
348
- _prog(1.0, "SINGLE-STAGE processing complete!")
 
349
  success_message = (
350
- f"SINGLE-STAGE Success!\n"
351
- f"Direct background replacement: {background_name}\n"
352
  f"Processed: {frame_count} frames\n"
353
- f"Saved: MyAvatar/My Videos/{saved_filename}\n"
354
- f"Quality: Cinema-grade with SAM2 + MatAnyone\n"
355
- f"Method: Optimized single-stage processing"
356
  )
 
357
  return final_output, success_message
358
 
359
  except Exception as e:
360
- logger.error(f"Video processing error: {traceback.format_exc()}")
361
- return None, f"SINGLE-STAGE Processing Error: {str(e)}"
362
 
363
  # ============================================================================ #
364
- # GRADIO UI
365
- # ============================================================================ #
366
- def create_interface():
367
- def extract_video_path(v):
368
- if isinstance(v, (tuple, list)) and len(v) > 0:
369
- return v[0]
370
- return v
371
-
372
- with gr.Blocks(
373
- title="ENHANCED High-Quality Video Background Replacement",
374
- theme=gr.themes.Soft(),
375
- css="""
376
- .gradio-container { max-width: 1200px !important; }
377
- .progress-bar { background: linear-gradient(90deg, #3498db, #2ecc71) !important; }
378
- """
379
- ) as demo:
380
- gr.Markdown("# Cinema-Quality Video Background Replacement")
381
- gr.Markdown("**Upload a video → Choose a background → Get professional results with AI**")
382
- gr.Markdown("*Powered by SAM2 + MatAnyone with multi-fallback loading for maximum reliability*")
383
- gr.Markdown("---")
384
-
385
- with gr.Row():
386
- with gr.Column(scale=1):
387
- gr.Markdown("### Step 1: Upload Your Video")
388
- gr.Markdown("*Supports MP4, MOV, AVI, and other common formats*")
389
- video_input = gr.Video(label="Drop your video here", height=300)
390
-
391
- gr.Markdown("### Step 2: Choose Background Method")
392
- gr.Markdown("*Select your preferred background creation method*")
393
- background_method = gr.Radio(
394
- choices=["upload", "professional", "colors", "ai"],
395
- value="professional",
396
- label="Background Method"
397
- )
398
- gr.Markdown(
399
- "- **upload** = Upload Image \n"
400
- "- **professional** = Professional Presets \n"
401
- "- **colors** = Colors/Gradients \n"
402
- "- **ai** = AI Generated"
403
- )
404
-
405
- with gr.Group(visible=False) as upload_group:
406
- gr.Markdown("**Upload Your Background Image**")
407
- custom_background = gr.Image(label="Drop your background image here", type="filepath")
408
-
409
- with gr.Group(visible=True) as professional_group:
410
- gr.Markdown("**Professional Background Presets**")
411
- professional_choice = gr.Dropdown(
412
- choices=list(PROFESSIONAL_BACKGROUNDS.keys()),
413
- value="office_modern",
414
- label="Select Professional Background"
415
- )
416
-
417
- with gr.Group(visible=False) as colors_group:
418
- gr.Markdown("**Custom Colors & Gradients**")
419
- gradient_type = gr.Dropdown(
420
- choices=["solid", "vertical", "horizontal", "diagonal", "radial", "soft_radial"],
421
- value="vertical",
422
- label="Gradient Type"
423
- )
424
- with gr.Row():
425
- color1 = gr.ColorPicker(label="Color 1", value="#3498db")
426
- color2 = gr.ColorPicker(label="Color 2", value="#2ecc71")
427
- with gr.Row():
428
- color3 = gr.ColorPicker(label="Color 3", value="#e74c3c")
429
- use_third_color = gr.Checkbox(label="Use 3rd color", value=False)
430
-
431
- with gr.Group(visible=False) as ai_group:
432
- gr.Markdown("**AI Generated Background**")
433
- ai_prompt = gr.Textbox(
434
- label="Describe your background",
435
- placeholder="e.g., 'modern office with plants', 'sunset over mountains', 'abstract tech pattern'",
436
- lines=2
437
- )
438
- ai_style = gr.Dropdown(
439
- choices=["photorealistic", "artistic", "abstract", "minimalist", "corporate", "nature"],
440
- value="photorealistic",
441
- label="Style"
442
- )
443
- with gr.Row():
444
- generate_ai_btn = gr.Button("Generate Background", variant="secondary")
445
- ai_generated_image = gr.Image(label="Generated Background", type="filepath", visible=False)
446
-
447
- def switch_background_method(method):
448
- return (
449
- gr.update(visible=(method == "upload")),
450
- gr.update(visible=(method == "professional")),
451
- gr.update(visible=(method == "colors")),
452
- gr.update(visible=(method == "ai"))
453
- )
454
- background_method.change(
455
- fn=switch_background_method,
456
- inputs=background_method,
457
- outputs=[upload_group, professional_group, colors_group, ai_group]
458
- )
459
-
460
- gr.Markdown("### Processing Controls")
461
- gr.Markdown("*First load the AI models, then process your video*")
462
- with gr.Row():
463
- load_models_btn = gr.Button("Step 1: Load AI Models", variant="secondary")
464
- process_btn = gr.Button("Step 2: Process Video", variant="primary")
465
-
466
- status_text = gr.Textbox(label="System Status", value=get_model_status(), interactive=False, lines=3)
467
-
468
- with gr.Column(scale=1):
469
- gr.Markdown("### Your Results")
470
- gr.Markdown("*Processed video will appear here after Step 2*")
471
- video_output = gr.Video(label="Your Processed Video", height=400)
472
- result_text = gr.Textbox(
473
- label="Processing Results",
474
- interactive=False,
475
- lines=6,
476
- placeholder="Processing status and results will appear here..."
477
- )
478
-
479
- gr.Markdown("### Professional Backgrounds Available")
480
- bg_preview_html = """
481
- <div style='display: grid; grid-template-columns: repeat(3, 1fr); gap: 8px; padding: 10px; max-height: 400px; overflow-y: auto; border: 1px solid #ddd; border-radius: 8px;'>
482
- """
483
- for key, config in PROFESSIONAL_BACKGROUNDS.items():
484
- colors = config["colors"]
485
- gradient = f"linear-gradient(45deg, {colors[0]}, {colors[-1]})" if len(colors) >= 2 else colors[0]
486
- bg_preview_html += f"""
487
- <div style='padding: 12px 8px; border: 1px solid #ddd; border-radius: 6px; text-align: center; background: {gradient};
488
- min-height: 60px; display: flex; align-items: center; justify-content: center;'>
489
- <div>
490
- <strong style='color: white; text-shadow: 1px 1px 2px rgba(0,0,0,0.8); font-size: 12px; display: block;'>{config["name"]}</strong>
491
- <small style='color: rgba(255,255,255,0.9); text-shadow: 1px 1px 1px rgba(0,0,0,0.6); font-size: 10px;'>{config.get("description", "")[:30]}...</small>
492
- </div>
493
- </div>
494
- """
495
- bg_preview_html += "</div>"
496
- gr.HTML(bg_preview_html)
497
-
498
- def generate_ai_background(prompt, style):
499
- if not prompt or not prompt.strip():
500
- return None, "Please enter a prompt"
501
- try:
502
- bg_image = create_procedural_background(prompt, style, 1920, 1080)
503
- if bg_image is not None:
504
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
505
- cv2.imwrite(tmp.name, bg_image)
506
- return tmp.name, f"Background generated: {prompt[:50]}..."
507
- return None, "Generation failed, try different prompt"
508
- except Exception as e:
509
- logger.error(f"AI generation error: {e}")
510
- return None, f"Generation error: {str(e)}"
511
-
512
- def process_video_enhanced(
513
- video_path, bg_method, custom_img, prof_choice, grad_type,
514
- color1, color2, color3, use_third, ai_prompt, ai_style, ai_img,
515
- progress: Optional[gr.Progress] = None
516
- ):
517
- if not models_loaded:
518
- return None, "Models not loaded. Click 'Load Models' first."
519
- if not video_path:
520
- return None, "No video file provided."
521
- try:
522
- if bg_method == "upload":
523
- if custom_img and os.path.exists(custom_img):
524
- return process_video_hq(video_path, "custom", custom_img, progress)
525
- return None, "No image uploaded. Please upload a background image."
526
- elif bg_method == "professional":
527
- if prof_choice and prof_choice in PROFESSIONAL_BACKGROUNDS:
528
- return process_video_hq(video_path, prof_choice, None, progress)
529
- return None, f"Invalid professional background: {prof_choice}"
530
- elif bg_method == "colors":
531
- try:
532
- colors = [color1 or "#3498db", color2 or "#2ecc71"]
533
- if use_third and color3:
534
- colors.append(color3)
535
- bg_config = {
536
- "type": "gradient" if grad_type != "solid" else "color",
537
- "colors": colors if grad_type != "solid" else [colors[0]],
538
- "direction": grad_type if grad_type != "solid" else "vertical"
539
- }
540
- gradient_bg = create_professional_background(bg_config, 1920, 1080)
541
- temp_path = f"/tmp/gradient_{int(time.time())}.png"
542
- cv2.imwrite(temp_path, gradient_bg)
543
- return process_video_hq(video_path, "custom", temp_path, progress)
544
- except Exception as e:
545
- return None, f"Error creating gradient: {str(e)}"
546
- elif bg_method == "ai":
547
- if ai_img and os.path.exists(ai_img):
548
- return process_video_hq(video_path, "custom", ai_img, progress)
549
- return None, "No AI background generated. Click 'Generate Background' first."
550
- else:
551
- return None, f"Unknown background method: {bg_method}"
552
- except Exception as e:
553
- logger.error(f"Enhanced processing error: {e}")
554
- return None, f"Processing error: {str(e)}"
555
-
556
- load_models_btn.click(fn=download_and_setup_models, outputs=status_text)
557
- generate_ai_btn.click(fn=generate_ai_background, inputs=[ai_prompt, ai_style], outputs=[ai_generated_image, status_text])
558
- process_btn.click(
559
- fn=process_video_enhanced,
560
- inputs=[video_input, background_method, custom_background, professional_choice,
561
- gradient_type, color1, color2, color3, use_third_color,
562
- ai_prompt, ai_style, ai_generated_image],
563
- outputs=[video_output, result_text]
564
- )
565
-
566
- with gr.Accordion("ENHANCED Quality & Features", open=False):
567
- gr.Markdown("""
568
- ### TWO-STAGE Cinema-Quality Features:
569
- **Stage 1**: Original → Green Screen (SAM2 + MatAnyone)
570
- **Stage 2**: Green Screen → Final Background (professional chroma key)
571
- **Quality**: Edge feathering, gamma correction, mask cleanup, H.264 CRF 18, AAC 192kbps.
572
- """)
573
-
574
- gr.Markdown("---")
575
- gr.Markdown("*Cinema-Quality Video Background Replacement — TWO-STAGE pipeline*")
576
-
577
- return demo
578
-
579
- # ============================================================================ #
580
- # MAIN
581
  # ============================================================================ #
582
  def main():
583
  try:
584
- print(f"===== Application Startup at {time.strftime('%Y-%m-%d %H:%M:%S')} =====\n")
585
- print("Cinema-Quality Video Background Replacement")
586
- print("=" * 50)
587
  os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
588
- os.makedirs(os.path.expanduser("~/.cache/sam2"), exist_ok=True)
589
-
590
- print("Features:")
591
- print(" • SAM2 + MatAnyone AI models")
592
- print(" • TWO-STAGE processing (Original → Green Screen → Final)")
593
- print(" • 4 background methods (Upload/Professional/Colors/AI)")
594
- print(" • Multi-fallback loading system")
595
- print(" • Cinema-quality processing")
596
- print(" • Enhanced stability & error handling")
597
- print("=" * 50)
598
-
599
- logger.info("Creating Gradio interface...")
600
  demo = create_interface()
601
 
602
- logger.info("Launching application...")
603
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
604
 
605
- except KeyboardInterrupt:
606
- logger.info("Application stopped by user")
607
- print("\nApplication stopped by user")
608
  except Exception as e:
609
- logger.error(f"Application failed to start: {e}")
610
- logger.error(f"Full traceback: {traceback.format_exc()}")
611
- print(f"Application failed to start: {e}")
612
- print("Check logs for detailed error information.")
613
 
614
  if __name__ == "__main__":
615
  main()
 
1
  #!/usr/bin/env python3
2
  # ========================= PRE-IMPORT ENV GUARDS =========================
3
  import os
4
+ os.environ.pop("OMP_NUM_THREADS", None)
 
5
  os.environ.setdefault("MKL_NUM_THREADS", "1")
6
  os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
7
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
 
8
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:1024")
9
  os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
10
  # ========================================================================
11
 
12
  """
13
+ CORE VIDEO PROCESSING - Fast startup with UI separation
14
+ SAM2 + MatAnyone processing core with persistent model caching
 
 
15
  """
16
 
17
  import sys
 
18
  import cv2
19
  import numpy as np
20
  from pathlib import Path
 
21
  import torch
22
  import traceback
23
  import time
24
  import shutil
25
  import gc
26
  import threading
27
+ import pickle
28
+ from typing import Optional
29
  import logging
30
+ from huggingface_hub import hf_hub_download
31
 
32
+ # Import utilities
33
+ from utilities import *
34
 
 
35
  logging.basicConfig(level=logging.INFO)
36
  logger = logging.getLogger(__name__)
37
 
38
  # ============================================================================ #
39
+ # FAST RESTART MODEL CACHING SYSTEM
40
  # ============================================================================ #
41
+ CACHE_DIR = Path("/tmp/persistent_models")
42
+ CACHE_DIR.mkdir(exist_ok=True, parents=True)
43
 
44
+ def get_cache_path(model_name: str) -> Path:
45
+ return CACHE_DIR / f"{model_name}_cached.pkl"
46
 
47
+ def save_model_to_cache(model, model_name: str):
48
+ try:
49
+ cache_path = get_cache_path(model_name)
50
+ if hasattr(model, 'model') and hasattr(model.model, 'to'):
51
+ model.model.to('cpu')
52
+ elif hasattr(model, 'to'):
53
+ model.to('cpu')
54
+
55
+ with open(cache_path, 'wb') as f:
56
+ pickle.dump(model, f)
57
+ logger.info(f"Model {model_name} cached successfully")
58
+ return True
59
+ except Exception as e:
60
+ logger.warning(f"Failed to cache {model_name}: {e}")
61
+ return False
62
 
63
+ def load_model_from_cache(model_name: str, device: str):
64
+ try:
65
+ cache_path = get_cache_path(model_name)
66
+ if not cache_path.exists():
67
  return None
68
+
69
+ with open(cache_path, 'rb') as f:
70
+ model = pickle.load(f)
71
+
72
+ if hasattr(model, 'model') and hasattr(model.model, 'to'):
73
+ model.model.to(device)
74
+ elif hasattr(model, 'to'):
75
+ model.to(device)
76
+
77
+ logger.info(f"Model {model_name} loaded from cache")
78
+ return model
79
+ except Exception as e:
80
+ logger.warning(f"Failed to load {model_name} from cache: {e}")
81
+ return None
82
 
83
  # ============================================================================ #
84
+ # FAST SAM2 LOADER
85
  # ============================================================================ #
86
+ def load_sam2_predictor_fast(device: str = "cuda", progress_callback=None):
87
+ def _prog(pct: float, desc: str):
88
+ if progress_callback:
89
+ progress_callback(pct, desc)
90
+
91
+ # Try cache first
92
+ _prog(0.1, "Checking SAM2 cache...")
93
+ cached_predictor = load_model_from_cache("sam2_predictor", device)
94
+ if cached_predictor is not None:
95
+ _prog(1.0, "SAM2 loaded from cache!")
96
+ return cached_predictor
97
+
98
+ # Load fresh
99
+ _prog(0.2, "Loading SAM2 fresh...")
100
+ try:
101
+ checkpoint_path = hf_hub_download(
102
+ repo_id="facebook/sam2-hiera-large",
103
+ filename="sam2_hiera_large.pt",
104
+ cache_dir=str(CACHE_DIR / "sam2_checkpoint")
105
+ )
106
+ _prog(0.6, "Building SAM2...")
107
+
108
+ from sam2.build_sam import build_sam2
109
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
110
+
111
+ sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
112
+ sam2_model.to(device)
113
+ predictor = SAM2ImagePredictor(sam2_model)
114
+
115
+ _prog(0.9, "Caching SAM2...")
116
+ save_model_to_cache(predictor, "sam2_predictor")
117
+ predictor.model.to(device)
118
+
119
+ _prog(1.0, "SAM2 ready!")
120
+ return predictor
121
+
122
+ except Exception as e:
123
+ logger.error(f"SAM2 loading failed: {e}")
124
+ raise
125
 
126
+ # ============================================================================ #
127
+ # FAST MATANYONE LOADER
128
+ # ============================================================================ #
129
+ def load_matanyone_fast(progress_callback=None):
130
+ def _prog(pct: float, desc: str):
131
+ if progress_callback:
132
+ progress_callback(pct, desc)
133
+
134
+ # Try cache first
135
+ _prog(0.1, "Checking MatAnyone cache...")
136
+ cached_processor = load_model_from_cache("matanyone", "cpu")
137
+ if cached_processor is not None:
138
+ _prog(1.0, "MatAnyone loaded from cache!")
139
+ return cached_processor
140
+
141
+ # Load fresh
142
+ _prog(0.3, "Loading MatAnyone fresh...")
143
  try:
 
144
  from matanyone import InferenceCore
145
  processor = InferenceCore("PeiqingYang/MatAnyone")
146
+
147
+ _prog(0.8, "Caching MatAnyone...")
148
+ save_model_to_cache(processor, "matanyone")
149
+
150
+ _prog(1.0, "MatAnyone ready!")
151
  return processor
152
 
153
  except Exception as e:
154
+ logger.error(f"MatAnyone loading failed: {e}")
155
+ raise
156
 
157
  # ============================================================================ #
158
+ # GLOBAL MODEL STATE
159
  # ============================================================================ #
160
  sam2_predictor = None
161
  matanyone_model = None
162
  models_loaded = False
163
  loading_lock = threading.Lock()
164
 
165
+ def load_models_fast(progress_callback=None):
166
+ """Fast model loading with caching"""
167
  global sam2_predictor, matanyone_model, models_loaded
168
 
169
  with loading_lock:
170
  if models_loaded:
171
+ return "Models already loaded"
172
 
173
  try:
174
+ start_time = time.time()
175
  device = "cuda" if torch.cuda.is_available() else "cpu"
176
+
177
+ sam2_predictor = load_sam2_predictor_fast(device=device, progress_callback=progress_callback)
178
+ matanyone_model = load_matanyone_fast(progress_callback=progress_callback)
179
 
180
  models_loaded = True
181
+ load_time = time.time() - start_time
182
+
183
+ message = f"SAM2 + MatAnyone loaded in {load_time:.1f}s!"
184
+ logger.info(message)
185
+ return message
186
+
187
  except Exception as e:
188
+ logger.error(f"Model loading failed: {str(e)}")
189
+ return f"Model loading failed: {str(e)}"
 
190
 
191
  # ============================================================================ #
192
+ # CORE VIDEO PROCESSING
193
  # ============================================================================ #
194
+ def process_video_core(video_path, background_choice, custom_background_path, progress_callback=None):
195
+ """Core video processing function"""
196
  if not models_loaded:
197
+ return None, "Models not loaded. Call load_models_fast() first."
198
  if not video_path:
199
  return None, "No video file provided."
200
 
201
  def _prog(pct: float, desc: str):
202
+ if progress_callback:
203
+ progress_callback(pct, desc)
 
204
 
205
  try:
206
+ _prog(0.0, "Starting processing...")
207
 
208
  if not os.path.exists(video_path):
209
  return None, f"Video file not found: {video_path}"
210
 
211
  cap = cv2.VideoCapture(video_path)
212
  if not cap.isOpened():
213
+ return None, "Could not open video file."
214
 
215
  fps = cap.get(cv2.CAP_PROP_FPS)
216
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
217
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
218
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
219
 
220
  if total_frames == 0:
221
+ return None, "Video appears to be empty."
222
 
223
+ # Prepare background
224
  background = None
225
  background_name = ""
226
 
227
  if background_choice == "custom" and custom_background_path:
228
  background = cv2.imread(custom_background_path)
229
  if background is None:
230
+ return None, "Could not read custom background image."
231
  background_name = "Custom Image"
 
232
  else:
233
  if background_choice in PROFESSIONAL_BACKGROUNDS:
234
  bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
235
  background = create_professional_background(bg_config, frame_width, frame_height)
236
  background_name = bg_config["name"]
 
237
  else:
238
  return None, f"Invalid background selection: {background_choice}"
239
 
 
243
  timestamp = int(time.time())
244
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
245
 
246
+ _prog(0.1, f"Processing with {background_name}...")
247
+ final_path = f"/tmp/output_{timestamp}.mp4"
 
248
  final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
249
+
250
  if not final_writer.isOpened():
251
  return None, "Could not create output video file."
252
 
253
  frame_count = 0
254
+ keyframe_interval = 3 # MatAnyone every 3rd frame
255
  last_refined_mask = None
256
 
257
  while True:
258
  ret, frame = cap.read()
259
  if not ret:
260
  break
261
+
262
  try:
263
+ _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
264
+ f"Frame {frame_count + 1}/{total_frames}")
265
 
266
+ # SAM2 segmentation
267
  mask = segment_person_hq(frame, sam2_predictor)
268
 
269
+ # MatAnyone refinement on keyframes
270
  if (frame_count % keyframe_interval == 0) or (last_refined_mask is None):
271
  refined_mask = refine_mask_hq(frame, mask, matanyone_model)
272
  last_refined_mask = refined_mask.copy()
 
273
  else:
 
274
  refined_mask = mask
275
 
276
+ # Background replacement
277
  result_frame = replace_background_hq(frame, refined_mask, background)
278
  final_writer.write(result_frame)
279
 
280
  except Exception as e:
281
  logger.warning(f"Error processing frame {frame_count}: {e}")
282
  final_writer.write(frame)
283
+
284
  frame_count += 1
285
  if frame_count % 100 == 0:
286
  gc.collect()
 
291
  cap.release()
292
 
293
  if frame_count == 0:
294
+ return None, "No frames were processed."
295
 
296
+ _prog(0.9, "Adding audio...")
297
+ final_output = f"/tmp/final_{timestamp}.mp4"
298
+
299
  try:
300
  audio_cmd = (
301
  f'ffmpeg -y -i "{final_path}" -i "{video_path}" '
302
+ f'-c:v libx264 -crf 18 -preset medium '
303
  f'-c:a aac -b:a 192k -ac 2 -ar 48000 '
304
  f'-map 0:v:0 -map 1:a:0? -shortest "{final_output}"'
305
  )
306
  result = os.system(audio_cmd)
307
  if result != 0 or not os.path.exists(final_output):
 
308
  shutil.copy2(final_path, final_output)
309
  except Exception as e:
310
+ logger.warning(f"Audio processing error: {e}")
311
+ shutil.copy2(final_path, final_output)
312
 
313
+ # Save to MyAvatar directory
314
  try:
315
  myavatar_path = "/tmp/MyAvatar/My_Videos/"
316
  os.makedirs(myavatar_path, exist_ok=True)
317
+ saved_filename = f"bg_replaced_{timestamp}.mp4"
318
  saved_path = os.path.join(myavatar_path, saved_filename)
319
  shutil.copy2(final_output, saved_path)
 
320
  except Exception as e:
321
+ logger.warning(f"Could not save to MyAvatar: {e}")
322
  saved_filename = os.path.basename(final_output)
323
 
324
+ # Cleanup
325
  try:
326
  if os.path.exists(final_path):
327
  os.remove(final_path)
328
+ except:
329
  pass
330
 
331
+ _prog(1.0, "Processing complete!")
332
+
333
  success_message = (
334
+ f"Success!\n"
335
+ f"Background: {background_name}\n"
336
  f"Processed: {frame_count} frames\n"
337
+ f"Saved: {saved_filename}\n"
338
+ f"Quality: SAM2 + MatAnyone"
 
339
  )
340
+
341
  return final_output, success_message
342
 
343
  except Exception as e:
344
+ logger.error(f"Processing error: {traceback.format_exc()}")
345
+ return None, f"Processing Error: {str(e)}"
346
+
347
+ def get_cache_status():
348
+ """Get current cache status"""
349
+ sam2_cached = get_cache_path("sam2_predictor").exists()
350
+ matanyone_cached = get_cache_path("matanyone").exists()
351
+ return {
352
+ "sam2_cached": sam2_cached,
353
+ "matanyone_cached": matanyone_cached,
354
+ "cache_dir": str(CACHE_DIR)
355
+ }
356
 
357
  # ============================================================================ #
358
+ # MAIN - IMPORT UI COMPONENTS ONLY WHEN NEEDED
359
  # ============================================================================ #
360
  def main():
361
  try:
362
+ print("===== FAST STARTUP CORE =====")
363
+ print("Loading UI components...")
364
+
365
+ # Import UI components only when needed
366
+ from ui_components import create_interface
367
+
368
  os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
369
+ CACHE_DIR.mkdir(exist_ok=True, parents=True)
370
+
371
+ print("Creating interface...")
372
  demo = create_interface()
373
 
374
+ print("Launching...")
375
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
376
 
377
  except Exception as e:
378
+ logger.error(f"Startup failed: {e}")
379
+ print(f"Startup failed: {e}")
 
 
380
 
381
  if __name__ == "__main__":
382
  main()
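
After this commit the processing core is callable without the Gradio layer: load_models_fast(), process_video_core(), and get_cache_status() are plain module-level functions. Below is a minimal driver sketch, assuming the refactored file is importable as app and that a test clip exists at /tmp/test.mp4 (both the module name and the path are illustrative assumptions, not part of the commit):

# Usage sketch (not part of the commit): drive the new core directly.
# Assumes the refactored file above is importable as `app` and a clip exists at /tmp/test.mp4.
import app

print(app.get_cache_status())  # reports whether SAM2/MatAnyone pickles already exist
print(app.load_models_fast(progress_callback=lambda pct, desc: print(f"{pct:.0%} {desc}")))

out_path, message = app.process_video_core(
    video_path="/tmp/test.mp4",          # hypothetical input clip
    background_choice="office_modern",   # a key from PROFESSIONAL_BACKGROUNDS
    custom_background_path=None,
)
print(message)

The pickle cache under /tmp/persistent_models only lasts as long as the container's /tmp does; on a cache miss the loaders simply fall back to the fresh-download path shown in the diff above.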