#!/usr/bin/env python3
# ========================= PRE-IMPORT ENV GUARDS =========================
import os
os.environ.pop("OMP_NUM_THREADS", None)
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:1024")
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
# ========================================================================
""" | |
CORE VIDEO PROCESSING - Fast startup with UI separation | |
SAM2 + MatAnyone processing core with persistent model caching | |
""" | |
import sys | |
import cv2 | |
import numpy as np | |
from pathlib import Path | |
import torch | |
import traceback | |
import time | |
import shutil | |
import gc | |
import threading | |
import pickle | |
from typing import Optional | |
import logging | |
from huggingface_hub import hf_hub_download | |
# Import utilities | |
from utilities import * | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# ============================================================================ #
# FAST RESTART MODEL CACHING SYSTEM
# ============================================================================ #
# Pickle loaded models under /tmp so a fast restart can reuse them instead of
# re-downloading and rebuilding from scratch.
CACHE_DIR = Path("/tmp/persistent_models")
CACHE_DIR.mkdir(exist_ok=True, parents=True)

def get_cache_path(model_name: str) -> Path:
    return CACHE_DIR / f"{model_name}_cached.pkl"
def save_model_to_cache(model, model_name: str):
    """Pickle a model to disk, moving weights to CPU first so the pickle is device-agnostic."""
    try:
        cache_path = get_cache_path(model_name)
        # Move weights to CPU before pickling: pickled CUDA tensors would fail
        # to unpickle in a CPU-only (or differently configured) process.
        if hasattr(model, 'model') and hasattr(model.model, 'to'):
            model.model.to('cpu')
        elif hasattr(model, 'to'):
            model.to('cpu')
        with open(cache_path, 'wb') as f:
            pickle.dump(model, f)
        logger.info(f"Model {model_name} cached successfully")
        return True
    except Exception as e:
        logger.warning(f"Failed to cache {model_name}: {e}")
        return False
def load_model_from_cache(model_name: str, device: str):
    """Unpickle a cached model and move its weights onto the requested device."""
    try:
        cache_path = get_cache_path(model_name)
        if not cache_path.exists():
            return None
        with open(cache_path, 'rb') as f:
            model = pickle.load(f)
        if hasattr(model, 'model') and hasattr(model.model, 'to'):
            model.model.to(device)
        elif hasattr(model, 'to'):
            model.to(device)
        logger.info(f"Model {model_name} loaded from cache")
        return model
    except Exception as e:
        logger.warning(f"Failed to load {model_name} from cache: {e}")
        return None
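
# Hypothetical convenience helper (not called anywhere in this file): clears the
# pickle cache so the next startup is forced to rebuild the models. Useful when
# a library upgrade makes previously pickled objects unloadable.
def clear_model_cache() -> int:
    removed = 0
    for cache_file in CACHE_DIR.glob("*_cached.pkl"):
        try:
            cache_file.unlink()
            removed += 1
        except OSError as e:
            logger.warning(f"Could not remove {cache_file}: {e}")
    return removed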
# ============================================================================ #
# FAST SAM2 LOADER
# ============================================================================ #
def load_sam2_predictor_fast(device: str = "cuda", progress_callback=None):
    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    # Try cache first
    _prog(0.1, "Checking SAM2 cache...")
    cached_predictor = load_model_from_cache("sam2_predictor", device)
    if cached_predictor is not None:
        _prog(1.0, "SAM2 loaded from cache!")
        return cached_predictor

    # Load fresh
    _prog(0.2, "Loading SAM2 fresh...")
    try:
        checkpoint_path = hf_hub_download(
            repo_id="facebook/sam2-hiera-large",
            filename="sam2_hiera_large.pt",
            cache_dir=str(CACHE_DIR / "sam2_checkpoint")
        )
        _prog(0.6, "Building SAM2...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
        sam2_model.to(device)
        predictor = SAM2ImagePredictor(sam2_model)
        _prog(0.9, "Caching SAM2...")
        # Caching moves the weights to CPU, so move them back to the target device.
        save_model_to_cache(predictor, "sam2_predictor")
        predictor.model.to(device)
        _prog(1.0, "SAM2 ready!")
        return predictor
    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}")
        raise
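
# Usage sketch: the actual segmentation lives in utilities.segment_person_hq;
# the standard SAM2ImagePredictor pattern it presumably follows looks like:
#     predictor.set_image(frame_rgb)               # HxWx3 RGB uint8
#     masks, scores, _ = predictor.predict(
#         point_coords=np.array([[x, y]]),         # prompt point(s)
#         point_labels=np.array([1]),              # 1 = foreground
#         multimask_output=False,
#     )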
# ============================================================================ #
# FAST MATANYONE LOADER
# ============================================================================ #
def load_matanyone_fast(progress_callback=None):
    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    # Try cache first
    _prog(0.1, "Checking MatAnyone cache...")
    cached_processor = load_model_from_cache("matanyone", "cpu")
    if cached_processor is not None:
        _prog(1.0, "MatAnyone loaded from cache!")
        return cached_processor

    # Load fresh
    _prog(0.3, "Loading MatAnyone fresh...")
    try:
        from matanyone import InferenceCore
        processor = InferenceCore("PeiqingYang/MatAnyone")
        _prog(0.8, "Caching MatAnyone...")
        # Note: save_model_to_cache moves any movable weights to CPU; unlike
        # SAM2 above, the processor is not moved back to GPU here.
        save_model_to_cache(processor, "matanyone")
        _prog(1.0, "MatAnyone ready!")
        return processor
    except Exception as e:
        logger.error(f"MatAnyone loading failed: {e}")
        raise
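
# Usage sketch: per-frame refinement lives in utilities.refine_mask_hq. The hub
# integration also exposes whole-video processing roughly like the following
# (verify the exact signature against the installed matanyone version):
#     fg_path, alpha_path = processor.process_video(
#         input_path="input.mp4",
#         mask_path="first_frame_mask.png",
#         output_path="outputs",
#     )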
# ============================================================================ #
# GLOBAL MODEL STATE
# ============================================================================ #
sam2_predictor = None
matanyone_model = None
models_loaded = False
loading_lock = threading.Lock()

def load_models_fast(progress_callback=None):
    """Fast model loading with caching"""
    global sam2_predictor, matanyone_model, models_loaded
    with loading_lock:
        if models_loaded:
            return "Models already loaded"
        try:
            start_time = time.time()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            sam2_predictor = load_sam2_predictor_fast(device=device, progress_callback=progress_callback)
            matanyone_model = load_matanyone_fast(progress_callback=progress_callback)
            models_loaded = True
            load_time = time.time() - start_time
            message = f"SAM2 + MatAnyone loaded in {load_time:.1f}s!"
            logger.info(message)
            return message
        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            return f"Model loading failed: {str(e)}"
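
# Example wiring (hypothetical caller): the callback contract is (fraction, text),
# so it can be forwarded to a Gradio progress bar or a plain logger alike.
#     def _on_progress(pct: float, desc: str) -> None:
#         print(f"[{pct * 100:5.1f}%] {desc}")
#
#     print(load_models_fast(progress_callback=_on_progress))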
# ============================================================================ #
# CORE VIDEO PROCESSING
# ============================================================================ #
def process_video_core(video_path, background_choice, custom_background_path, progress_callback=None):
    """Core video processing function"""
    if not models_loaded:
        return None, "Models not loaded. Call load_models_fast() first."
    if not video_path:
        return None, "No video file provided."

    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    try:
        _prog(0.0, "Starting processing...")
        if not os.path.exists(video_path):
            return None, f"Video file not found: {video_path}"
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, "Could not open video file."
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if total_frames == 0:
            return None, "Video appears to be empty."

        # Prepare background
        background = None
        background_name = ""
        if background_choice == "custom" and custom_background_path:
            background = cv2.imread(custom_background_path)
            if background is None:
                return None, "Could not read custom background image."
            background_name = "Custom Image"
        else:
            if background_choice in PROFESSIONAL_BACKGROUNDS:
                bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
                background = create_professional_background(bg_config, frame_width, frame_height)
                background_name = bg_config["name"]
            else:
                return None, f"Invalid background selection: {background_choice}"
        if background is None:
            return None, "Failed to create background."
        timestamp = int(time.time())
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        _prog(0.1, f"Processing with {background_name}...")
        final_path = f"/tmp/output_{timestamp}.mp4"
        final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
        if not final_writer.isOpened():
            return None, "Could not create output video file."

        frame_count = 0
        keyframe_interval = 3  # Run MatAnyone refinement on every 3rd frame; reuse the raw SAM2 mask in between
        last_refined_mask = None
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            try:
                _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
                      f"Frame {frame_count + 1}/{total_frames}")
                # SAM2 segmentation
                mask = segment_person_hq(frame, sam2_predictor)
                # MatAnyone refinement on keyframes
                if (frame_count % keyframe_interval == 0) or (last_refined_mask is None):
                    refined_mask = refine_mask_hq(frame, mask, matanyone_model)
                    last_refined_mask = refined_mask.copy()
                else:
                    refined_mask = mask
                # Background replacement
                result_frame = replace_background_hq(frame, refined_mask, background)
                final_writer.write(result_frame)
            except Exception as e:
                logger.warning(f"Error processing frame {frame_count}: {e}")
                final_writer.write(frame)  # Fall back to the original frame so output stays in sync
            frame_count += 1
            if frame_count % 100 == 0:
                # Periodically release Python and CUDA memory to avoid creep on long videos
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        final_writer.release()
        cap.release()
        if frame_count == 0:
            return None, "No frames were processed."
        _prog(0.9, "Adding audio...")
        final_output = f"/tmp/final_{timestamp}.mp4"
        try:
            # Re-encode the processed video and mux in the original audio track, if any
            audio_cmd = (
                f'ffmpeg -y -i "{final_path}" -i "{video_path}" '
                f'-c:v libx264 -crf 18 -preset medium '
                f'-c:a aac -b:a 192k -ac 2 -ar 48000 '
                f'-map 0:v:0 -map 1:a:0? -shortest "{final_output}"'
            )
            result = os.system(audio_cmd)
            if result != 0 or not os.path.exists(final_output):
                shutil.copy2(final_path, final_output)
        except Exception as e:
            logger.warning(f"Audio processing error: {e}")
            shutil.copy2(final_path, final_output)
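
        # A subprocess-based alternative (sketch) would avoid shell-quoting pitfalls
        # with unusual file names; it should behave like the os.system() call above:
        #     import subprocess
        #     subprocess.run(
        #         ["ffmpeg", "-y", "-i", final_path, "-i", video_path,
        #          "-c:v", "libx264", "-crf", "18", "-preset", "medium",
        #          "-c:a", "aac", "-b:a", "192k", "-ac", "2", "-ar", "48000",
        #          "-map", "0:v:0", "-map", "1:a:0?", "-shortest", final_output],
        #         check=True,
        #     )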
        # Save a copy to the MyAvatar directory
        try:
            myavatar_path = "/tmp/MyAvatar/My_Videos/"
            os.makedirs(myavatar_path, exist_ok=True)
            saved_filename = f"bg_replaced_{timestamp}.mp4"
            saved_path = os.path.join(myavatar_path, saved_filename)
            shutil.copy2(final_output, saved_path)
        except Exception as e:
            logger.warning(f"Could not save to MyAvatar: {e}")
            saved_filename = os.path.basename(final_output)

        # Clean up the intermediate (video-only) file
        try:
            if os.path.exists(final_path):
                os.remove(final_path)
        except OSError:
            pass

        _prog(1.0, "Processing complete!")
        success_message = (
            f"Success!\n"
            f"Background: {background_name}\n"
            f"Processed: {frame_count} frames\n"
            f"Saved: {saved_filename}\n"
            f"Quality: SAM2 + MatAnyone"
        )
        return final_output, success_message
    except Exception as e:
        logger.error(f"Processing error: {traceback.format_exc()}")
        return None, f"Processing Error: {str(e)}"
def get_cache_status():
    """Get current cache status"""
    sam2_cached = get_cache_path("sam2_predictor").exists()
    matanyone_cached = get_cache_path("matanyone").exists()
    return {
        "sam2_cached": sam2_cached,
        "matanyone_cached": matanyone_cached,
        "cache_dir": str(CACHE_DIR)
    }
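
# Usage sketch: handy for a startup log or a debug endpoint.
#     status = get_cache_status()
#     logger.info(f"SAM2 cached: {status['sam2_cached']}, "
#                 f"MatAnyone cached: {status['matanyone_cached']}")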
# ============================================================================ #
# MAIN - IMPORT UI COMPONENTS ONLY WHEN NEEDED
# ============================================================================ #
def main():
    try:
        print("===== FAST STARTUP CORE =====")
        print("Loading UI components...")
        # Import UI components only when needed so core imports stay fast
        from ui_components import create_interface
        os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
        CACHE_DIR.mkdir(exist_ok=True, parents=True)
        print("Creating interface...")
        demo = create_interface()
        print("Launching...")
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        print(f"Startup failed: {e}")

if __name__ == "__main__":
    main()