#!/usr/bin/env python3
# ========================= PRE-IMPORT ENV GUARDS =========================
import os
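# BLAS/OpenMP thread settings are read once at import time, so they must be
# set before numpy/torch are imported below.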
os.environ.pop("OMP_NUM_THREADS", None)
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:1024")
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
# ========================================================================
"""
CORE VIDEO PROCESSING - Fast startup with UI separation
SAM2 + MatAnyone processing core with persistent model caching
"""
import cv2
import numpy as np
from pathlib import Path
import torch
import traceback
import time
import shutil
import subprocess
import gc
import threading
import pickle
import logging
from huggingface_hub import hf_hub_download
# Import processing helpers explicitly so this module's dependencies are visible
from utilities import (
    PROFESSIONAL_BACKGROUNDS,
    create_professional_background,
    segment_person_hq,
    refine_mask_hq,
    replace_background_hq,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================ #
# FAST RESTART MODEL CACHING SYSTEM
# ============================================================================ #
CACHE_DIR = Path("/tmp/persistent_models")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
def get_cache_path(model_name: str) -> Path:
return CACHE_DIR / f"{model_name}_cached.pkl"
def save_model_to_cache(model, model_name: str):
try:
cache_path = get_cache_path(model_name)
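        # Move weights to CPU before pickling so the cache file is device-agnostic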
if hasattr(model, 'model') and hasattr(model.model, 'to'):
model.model.to('cpu')
elif hasattr(model, 'to'):
model.to('cpu')
with open(cache_path, 'wb') as f:
pickle.dump(model, f)
logger.info(f"Model {model_name} cached successfully")
return True
except Exception as e:
logger.warning(f"Failed to cache {model_name}: {e}")
return False
def load_model_from_cache(model_name: str, device: str):
try:
cache_path = get_cache_path(model_name)
if not cache_path.exists():
return None
with open(cache_path, 'rb') as f:
model = pickle.load(f)
if hasattr(model, 'model') and hasattr(model.model, 'to'):
model.model.to(device)
elif hasattr(model, 'to'):
model.to(device)
logger.info(f"Model {model_name} loaded from cache")
return model
except Exception as e:
logger.warning(f"Failed to load {model_name} from cache: {e}")
return None
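# NOTE: pickled model caches are only valid with the same Python/torch/library
# versions that wrote them; a mismatch surfaces as a load failure above and
# safely falls back to a fresh load.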
# ============================================================================ #
# FAST SAM2 LOADER
# ============================================================================ #
def load_sam2_predictor_fast(device: str = "cuda", progress_callback=None):
def _prog(pct: float, desc: str):
if progress_callback:
progress_callback(pct, desc)
# Try cache first
_prog(0.1, "Checking SAM2 cache...")
cached_predictor = load_model_from_cache("sam2_predictor", device)
if cached_predictor is not None:
_prog(1.0, "SAM2 loaded from cache!")
return cached_predictor
# Load fresh
_prog(0.2, "Loading SAM2 fresh...")
try:
checkpoint_path = hf_hub_download(
repo_id="facebook/sam2-hiera-large",
filename="sam2_hiera_large.pt",
cache_dir=str(CACHE_DIR / "sam2_checkpoint")
)
_prog(0.6, "Building SAM2...")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
sam2_model.to(device)
predictor = SAM2ImagePredictor(sam2_model)
_prog(0.9, "Caching SAM2...")
save_model_to_cache(predictor, "sam2_predictor")
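        # save_model_to_cache() moved the weights to CPU for pickling; move them back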
predictor.model.to(device)
_prog(1.0, "SAM2 ready!")
return predictor
except Exception as e:
logger.error(f"SAM2 loading failed: {e}")
raise
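# Usage sketch (hypothetical frame/points; SAM2ImagePredictor expects RGB input):
#   predictor = load_sam2_predictor_fast(device="cuda")
#   predictor.set_image(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
#   masks, scores, logits = predictor.predict(point_coords=points, point_labels=labels)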
# ============================================================================ #
# FAST MATANYONE LOADER
# ============================================================================ #
def load_matanyone_fast(progress_callback=None):
def _prog(pct: float, desc: str):
if progress_callback:
progress_callback(pct, desc)
# Try cache first
_prog(0.1, "Checking MatAnyone cache...")
cached_processor = load_model_from_cache("matanyone", "cpu")
if cached_processor is not None:
_prog(1.0, "MatAnyone loaded from cache!")
return cached_processor
# Load fresh
_prog(0.3, "Loading MatAnyone fresh...")
try:
from matanyone import InferenceCore
processor = InferenceCore("PeiqingYang/MatAnyone")
_prog(0.8, "Caching MatAnyone...")
save_model_to_cache(processor, "matanyone")
_prog(1.0, "MatAnyone ready!")
return processor
except Exception as e:
logger.error(f"MatAnyone loading failed: {e}")
raise
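# Note: InferenceCore is given the Hugging Face repo id "PeiqingYang/MatAnyone";
# the weights are downloaded on first construction, and the pickle cache above
# skips that download on subsequent restarts.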
# ============================================================================ #
# GLOBAL MODEL STATE
# ============================================================================ #
sam2_predictor = None
matanyone_model = None
models_loaded = False
loading_lock = threading.Lock()
def load_models_fast(progress_callback=None):
"""Fast model loading with caching"""
global sam2_predictor, matanyone_model, models_loaded
with loading_lock:
if models_loaded:
return "Models already loaded"
try:
start_time = time.time()
device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_predictor = load_sam2_predictor_fast(device=device, progress_callback=progress_callback)
matanyone_model = load_matanyone_fast(progress_callback=progress_callback)
models_loaded = True
load_time = time.time() - start_time
message = f"SAM2 + MatAnyone loaded in {load_time:.1f}s!"
logger.info(message)
return message
except Exception as e:
logger.error(f"Model loading failed: {str(e)}")
return f"Model loading failed: {str(e)}"
# ============================================================================ #
# CORE VIDEO PROCESSING
# ============================================================================ #
def process_video_core(video_path, background_choice, custom_background_path, progress_callback=None):
"""Core video processing function"""
if not models_loaded:
return None, "Models not loaded. Call load_models_fast() first."
if not video_path:
return None, "No video file provided."
def _prog(pct: float, desc: str):
if progress_callback:
progress_callback(pct, desc)
try:
_prog(0.0, "Starting processing...")
if not os.path.exists(video_path):
return None, f"Video file not found: {video_path}"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, "Could not open video file."
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0 fps; fall back to a sane default
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if total_frames == 0:
            cap.release()
            return None, "Video appears to be empty."
# Prepare background
background = None
background_name = ""
        if background_choice == "custom" and custom_background_path:
            background = cv2.imread(custom_background_path)
            if background is None:
                cap.release()
                return None, "Could not read custom background image."
            background_name = "Custom Image"
        elif background_choice in PROFESSIONAL_BACKGROUNDS:
            bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
            background = create_professional_background(bg_config, frame_width, frame_height)
            background_name = bg_config["name"]
        else:
            cap.release()
            return None, f"Invalid background selection: {background_choice}"
        if background is None:
            cap.release()
            return None, "Failed to create background."
timestamp = int(time.time())
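        # mp4v is just the intermediate codec; ffmpeg re-encodes to H.264 with audio below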
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
_prog(0.1, f"Processing with {background_name}...")
final_path = f"/tmp/output_{timestamp}.mp4"
final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
        if not final_writer.isOpened():
            cap.release()
            return None, "Could not create output video file."
frame_count = 0
keyframe_interval = 3 # MatAnyone every 3rd frame
last_refined_mask = None
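        # Between keyframes the raw SAM2 mask is used directly; MatAnyone
        # refinement runs only every keyframe_interval frames to cut per-frame cost.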
while True:
ret, frame = cap.read()
if not ret:
break
try:
_prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
f"Frame {frame_count + 1}/{total_frames}")
# SAM2 segmentation
mask = segment_person_hq(frame, sam2_predictor)
# MatAnyone refinement on keyframes
if (frame_count % keyframe_interval == 0) or (last_refined_mask is None):
refined_mask = refine_mask_hq(frame, mask, matanyone_model)
last_refined_mask = refined_mask.copy()
else:
refined_mask = mask
# Background replacement
result_frame = replace_background_hq(frame, refined_mask, background)
final_writer.write(result_frame)
except Exception as e:
logger.warning(f"Error processing frame {frame_count}: {e}")
final_writer.write(frame)
frame_count += 1
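            # Every 100 frames, release Python garbage and the CUDA allocator cache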
if frame_count % 100 == 0:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
final_writer.release()
cap.release()
if frame_count == 0:
return None, "No frames were processed."
_prog(0.9, "Adding audio...")
final_output = f"/tmp/final_{timestamp}.mp4"
        try:
            # Mux the original audio back in while re-encoding to H.264.
            # Argument-list form avoids shell quoting issues with arbitrary paths.
            audio_cmd = [
                "ffmpeg", "-y", "-i", final_path, "-i", video_path,
                "-c:v", "libx264", "-crf", "18", "-preset", "medium",
                "-c:a", "aac", "-b:a", "192k", "-ac", "2", "-ar", "48000",
                "-map", "0:v:0", "-map", "1:a:0?", "-shortest", final_output,
            ]
            result = subprocess.run(audio_cmd, capture_output=True, text=True)
            if result.returncode != 0 or not os.path.exists(final_output):
                shutil.copy2(final_path, final_output)
        except Exception as e:
            logger.warning(f"Audio processing error: {e}")
            shutil.copy2(final_path, final_output)
# Save to MyAvatar directory
try:
myavatar_path = "/tmp/MyAvatar/My_Videos/"
os.makedirs(myavatar_path, exist_ok=True)
saved_filename = f"bg_replaced_{timestamp}.mp4"
saved_path = os.path.join(myavatar_path, saved_filename)
shutil.copy2(final_output, saved_path)
except Exception as e:
logger.warning(f"Could not save to MyAvatar: {e}")
saved_filename = os.path.basename(final_output)
# Cleanup
        try:
            if os.path.exists(final_path):
                os.remove(final_path)
        except OSError:
            pass
_prog(1.0, "Processing complete!")
success_message = (
f"Success!\n"
f"Background: {background_name}\n"
f"Processed: {frame_count} frames\n"
f"Saved: {saved_filename}\n"
f"Quality: SAM2 + MatAnyone"
)
return final_output, success_message
except Exception as e:
logger.error(f"Processing error: {traceback.format_exc()}")
return None, f"Processing Error: {str(e)}"
def get_cache_status():
"""Get current cache status"""
sam2_cached = get_cache_path("sam2_predictor").exists()
matanyone_cached = get_cache_path("matanyone").exists()
return {
"sam2_cached": sam2_cached,
"matanyone_cached": matanyone_cached,
"cache_dir": str(CACHE_DIR)
}
# ============================================================================ #
# MAIN - IMPORT UI COMPONENTS ONLY WHEN NEEDED
# ============================================================================ #
def main():
try:
print("===== FAST STARTUP CORE =====")
print("Loading UI components...")
# Import UI components only when needed
from ui_components import create_interface
os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True, parents=True)
print("Creating interface...")
demo = create_interface()
print("Launching...")
demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
except Exception as e:
logger.error(f"Startup failed: {e}")
print(f"Startup failed: {e}")
if __name__ == "__main__":
main()