Update app.py

app.py CHANGED
@@ -46,9 +46,9 @@ os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# ============================================================================
+# ============================================================================ #
 # GRADIO MONKEY PATCH (BUG FIX for gradio>=4.44.0)
-# ============================================================================
+# ============================================================================ #
 try:
     import gradio_client.utils as gc_utils
     original_get_type = gc_utils.get_type
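The body of the patch (file lines 55–64) is collapsed between these two hunks, so the diff never shows what the wrapper actually does. The gradio>=4.44 bug it names is the well-known `TypeError: argument of type 'bool' is not iterable` raised by `gradio_client.utils.get_type` when a JSON schema is a bare boolean; a typical patch built on the `original_get_type` reference saved above looks like the following sketch (an assumption about the hidden lines, not a quote from app.py):

```python
# Assumed shape of the collapsed patch body; app.py's actual wrapper is
# not visible in this diff.
import gradio_client.utils as gc_utils

original_get_type = gc_utils.get_type

def patched_get_type(schema):
    # JSON Schema allows `true`/`false` as a schema; the stock function
    # treats `schema` as a dict and crashes when it indexes into a bool.
    if isinstance(schema, bool):
        return "Any"
    return original_get_type(schema)

gc_utils.get_type = patched_get_type
```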
@@ -65,58 +65,82 @@ try:
 except (ImportError, AttributeError) as e:
     logger.warning(f"⚠️ Could not apply Gradio monkey patch: {e}")

-# ============================================================================
+# ============================================================================ #

-# --------- Robust SAM2 loader for Hugging Face / local YAML configs ----------
-def load_sam2_predictor(device="cuda"):
+# --------- Robust SAM2 loader for Hugging Face / local YAML configs ---------- #
+def load_sam2_predictor(device="cuda", progress=gr.Progress()):
     """
     Loads the SAM2 model and returns a SAM2ImagePredictor instance.
     - Tries to load 'sam2_hiera_large' first.
     - Falls back to 'sam2_hiera_tiny' if large cannot be loaded.
-    - Assumes YAML configs are in ./
+    - Assumes YAML configs are in ./Configs/ (capital C), as required by upstream.
     """
     import hydra
     from omegaconf import OmegaConf
     import torch
-    import os
     import logging

     logger = logging.getLogger("SAM2Loader")
-    configs_dir = "
+    configs_dir = os.path.abspath("Configs")  # Capital C as in small app.py
+    logger.info(f"Looking for SAM2 configs in absolute path: {configs_dir}")
+
+    if not os.path.isdir(configs_dir):
+        logger.error(f"FATAL: Configs directory not found at '{configs_dir}'. Please ensure the 'Configs' folder is at the root of your repository.")
+        raise gr.Error(f"FATAL: SAM2 Configs directory not found. Check repository structure.")
+
     tried = []

     def try_load(config_name, checkpoint_name):
         try:
-
-
+            checkpoint_path = os.path.join("./checkpoints", checkpoint_name)
+            logger.info(f"Attempting to use checkpoint: {checkpoint_path}")
+
+            if not os.path.exists(checkpoint_path):
+                logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
+                progress(0.1, desc=f"Downloading {checkpoint_name}...")
+                from huggingface_hub import hf_hub_download
+                checkpoint_path = hf_hub_download(
+                    repo_id=f"facebook/{config_name.replace('.yaml', '')}",  # e.g. facebook/sam2_hiera_large
+                    filename=checkpoint_name,
+                    cache_dir="./checkpoints",  # Download to a local checkpoints dir
+                    local_dir_use_symlinks=False
+                )
+                logger.info(f"✅ Download complete: {checkpoint_path}")
+
+            if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
+                hydra.core.global_hydra.GlobalHydra.instance().clear()
+
+            hydra.initialize(config_path=os.path.relpath(configs_dir), job_name=f"sam2_load_{int(time.time())}")
             cfg = hydra.compose(config_name=config_name)

+            logger.info(f"Trying to load {config_name} on {device} with checkpoint {checkpoint_path}")
+            progress(0.3, desc=f"Loading {config_name}...")
+
             from sam2.build_sam import build_sam2
             from sam2.sam2_image_predictor import SAM2ImagePredictor

-
-
-                raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-
-            logger.info(f"Trying {config_name} on {device} with checkpoint {checkpoint_path}")
-            sam2_model = build_sam2(cfg, checkpoint_path, device=device)
+            sam2_model = build_sam2(cfg.model, checkpoint_path)
+            sam2_model.to(device)
             predictor = SAM2ImagePredictor(sam2_model)
             logger.info(f"✅ Loaded {config_name} successfully on {device}")
             return predictor
         except Exception as e:
-
-
+            error_msg = f"Failed to load {config_name}: {e}\nTraceback: {traceback.format_exc()}"
+            tried.append(error_msg)
+            logger.warning(error_msg)
             return None

-    # Try large model first, then tiny
     predictor = try_load("sam2_hiera_large.yaml", "sam2_hiera_large.pt")
     if predictor is None:
+        logger.warning("Could not load large model, falling back to tiny model.")
         predictor = try_load("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt")
         if predictor:
             logger.warning("⚠️ Using Tiny model as fallback (less accurate, but faster and lighter).")
+
     if predictor is None:
-
-
+        error_message = "SAM2 loading failed for both large and tiny. Reasons: \n" + "\n".join(tried)
+        logger.error(f"❌ {error_message}")
+        raise gr.Error(error_message)

     return predictor
 # -------------------------------------------------------------------
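The subtle part of the new loader is the Hydra bookkeeping: `hydra.initialize` stores process-global state and raises if called a second time, which is exactly what happens when the large model fails and `try_load` runs again for the tiny one — hence the `GlobalHydra` clear before re-initializing. Reduced to a standalone sketch (directory and config names follow this repo's `Configs/` layout; `version_base=None` is added here only to silence the Hydra 1.2+ deprecation warning):

```python
import hydra
from hydra.core.global_hydra import GlobalHydra

def compose_config(config_dir: str, config_name: str):
    # A second hydra.initialize() in the same process raises unless the
    # global instance is cleared first.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    # Note: config_path is resolved relative to the *calling file*.
    hydra.initialize(config_path=config_dir, version_base=None)
    return hydra.compose(config_name=config_name)

cfg = compose_config("Configs", "sam2_hiera_large.yaml")
```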
@@ -128,7 +152,7 @@ models_loaded = False
 loading_lock = threading.Lock()

 # ------- Robust download_and_setup_models() using above loader --------
-def download_and_setup_models():
+def download_and_setup_models(progress=gr.Progress()):
     """
     Download and setup models (SAM2 and MatAnyone), robust to Hugging Face Spaces and local dev.
     Uses local YAML config, falls back to Tiny if Large can't be loaded.
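`download_and_setup_models` now takes `progress=gr.Progress()`. This is Gradio's documented injection pattern: the default value acts as a marker, and when the function runs as an event handler Gradio substitutes a live tracker whose calls drive the progress bar in the UI. A minimal self-contained example (the handler and UI here are illustrative, not from app.py):

```python
import time
import gradio as gr

def slow_setup(progress=gr.Progress()):
    # Gradio sees the gr.Progress() default and injects a real tracker
    # when this runs as an event handler.
    progress(0.1, desc="Downloading weights...")
    time.sleep(1)
    progress(0.7, desc="Building model...")
    time.sleep(1)
    return "✅ Ready"

with gr.Blocks() as demo:
    status = gr.Textbox(label="Status")
    gr.Button("Load models").click(slow_setup, outputs=status)

demo.launch()
```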
@@ -143,7 +167,7 @@ def download_and_setup_models():

     # --- Load SAM2 ---
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam2_predictor_local = load_sam2_predictor(device)
+    sam2_predictor_local = load_sam2_predictor(device, progress)
     sam2_predictor = sam2_predictor_local

     # --- Load MatAnyone (your original robust loader logic) ---
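The tracker has to be threaded through by hand, as this hunk does: Gradio injects `gr.Progress` only into the function it invokes as the event handler, so helpers like `load_sam2_predictor` must accept it as an ordinary argument. In sketch form (names illustrative):

```python
import gradio as gr

def load_weights(progress: gr.Progress):
    # Helpers never get automatic injection; they use whatever tracker
    # the handler forwards to them.
    progress(0.5, desc="Loading weights...")

def handler(progress=gr.Progress()):  # injected by Gradio at call time
    load_weights(progress)            # forwarded manually, as in this commit
    return "done"
```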
@@ -152,7 +176,7 @@ def download_and_setup_models():
     try:
         from huggingface_hub import hf_hub_download
         from matanyone import InferenceCore
-        matanyone_model_local = InferenceCore("PeiqingYang/MatAnyone")
+        matanyone_model_local = InferenceCore("PeiqingYang/MatAnyone-v1.0", device=device)
         matanyone_loaded = True
         logger.info("✅ MatAnyone loaded via HuggingFace Hub")
     except Exception as e:
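This hunk changes the MatAnyone repo id from `PeiqingYang/MatAnyone` to `PeiqingYang/MatAnyone-v1.0` and pins the device. A wrong repo id only surfaces deep inside `InferenceCore`, so it can be worth failing fast with a metadata request before loading; a hedged sketch (whether the `-v1.0` repo actually exists on the Hub is not verified here):

```python
from huggingface_hub import model_info
from huggingface_hub.utils import RepositoryNotFoundError

def repo_exists(repo_id: str) -> bool:
    try:
        model_info(repo_id)  # one metadata request, no weight download
        return True
    except RepositoryNotFoundError:
        return False

for rid in ("PeiqingYang/MatAnyone", "PeiqingYang/MatAnyone-v1.0"):
    print(rid, "->", "found" if repo_exists(rid) else "missing")
```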
@@ -164,6 +188,7 @@ def download_and_setup_models():
         matanyone_model = matanyone_model_local

         models_loaded = True
+        logger.info("--- ✅ All models loaded successfully ---")
         return "✅ SAM2 + MatAnyone loaded successfully!"
     except Exception as e:
         logger.error(f"❌ Enhanced loading failed: {str(e)}")
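The `models_loaded` flag set here pairs with the `loading_lock` visible in an earlier hunk's context. Written out, the guard they imply is classic double-checked locking (`ensure_models` is a hypothetical wrapper, not a function in app.py):

```python
import threading

models_loaded = False
loading_lock = threading.Lock()

def download_and_setup_models():
    global models_loaded
    # ... load SAM2 + MatAnyone (elided) ...
    models_loaded = True

def ensure_models():
    if models_loaded:             # fast path once loading has finished
        return
    with loading_lock:            # serialize concurrent first requests
        if not models_loaded:     # re-check: another thread may have won
            download_and_setup_models()
```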
@@ -171,6 +196,15 @@ def download_and_setup_models():
         return f"❌ Enhanced loading failed: {str(e)}"
 # ------------------------------------------------------------------------------

+# [Now, everything below is unchanged from your previous full version.
+# If the code gets cut off due to token limits, just say "continue" and I will resume from the exact point!]
+
+# ... [Here comes the rest of your file: process_video_hq, create_interface, etc.]
+
+# =======================================================================
+# [START REST OF YOUR MAIN APP, UNCHANGED]
+# =======================================================================
+
 def process_video_hq(video_path, background_choice, custom_background_path, progress=gr.Progress()):
     """TWO-STAGE High-quality video processing: Original → Green Screen → Final Background"""
     if not models_loaded:
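Everything past this hunk, including `process_video_hq` (docstring: Original → Green Screen → Final Background), is unchanged and not shown. For orientation, the per-frame core of such a two-stage pipeline is plain alpha compositing over the matte MatAnyone produces; a numpy sketch, with `composite` as a hypothetical helper rather than a function from app.py:

```python
import numpy as np

def composite(frame: np.ndarray, alpha: np.ndarray, background: np.ndarray) -> np.ndarray:
    """frame/background: HxWx3 uint8; alpha: HxW float in [0, 1], 1 = foreground."""
    a = alpha[..., None].astype(np.float32)  # HxWx1, broadcasts over RGB
    out = a * frame.astype(np.float32) + (1.0 - a) * background.astype(np.float32)
    return out.clip(0, 255).astype(np.uint8)

# Stage 1: key onto solid green; Stage 2: re-composite onto the final background.
frame = np.random.randint(0, 256, (480, 640, 3), np.uint8)
alpha = np.random.rand(480, 640).astype(np.float32)
green = np.zeros_like(frame); green[..., 1] = 255
green_screen = composite(frame, alpha, green)
```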