MogensR commited on
Commit
d93f7a0
Β·
verified Β·
1 Parent(s): bc57987

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -24
app.py CHANGED
@@ -46,9 +46,9 @@ os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
46
  logging.basicConfig(level=logging.INFO)
47
  logger = logging.getLogger(__name__)
48
 
49
- # ============================================================================
50
  # GRADIO MONKEY PATCH (BUG FIX for gradio>=4.44.0)
51
- # ============================================================================
52
  try:
53
  import gradio_client.utils as gc_utils
54
  original_get_type = gc_utils.get_type
@@ -65,58 +65,82 @@ try:
65
  except (ImportError, AttributeError) as e:
66
  logger.warning(f"⚠️ Could not apply Gradio monkey patch: {e}")
67
 
68
- # ============================================================================
69
 
70
- # --------- Robust SAM2 loader for Hugging Face / local YAML configs ----------
71
- def load_sam2_predictor(device="cuda"):
72
  """
73
  Loads the SAM2 model and returns a SAM2ImagePredictor instance.
74
  - Tries to load 'sam2_hiera_large' first.
75
  - Falls back to 'sam2_hiera_tiny' if large cannot be loaded.
76
- - Assumes YAML configs are in ./configs/.
77
  """
78
  import hydra
79
  from omegaconf import OmegaConf
80
  import torch
81
- import os
82
  import logging
83
 
84
  logger = logging.getLogger("SAM2Loader")
85
- configs_dir = "./configs"
 
 
 
 
 
 
86
  tried = []
87
 
88
  def try_load(config_name, checkpoint_name):
89
  try:
90
- hydra.core.global_hydra.GlobalHydra.instance().clear()
91
- hydra.initialize(config_path=configs_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  cfg = hydra.compose(config_name=config_name)
93
 
 
 
 
94
  from sam2.build_sam import build_sam2
95
  from sam2.sam2_image_predictor import SAM2ImagePredictor
96
 
97
- checkpoint_path = os.path.expanduser(f"~/.cache/sam2/{checkpoint_name}")
98
- if not os.path.exists(checkpoint_path):
99
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
100
-
101
- logger.info(f"Trying {config_name} on {device} with checkpoint {checkpoint_path}")
102
- sam2_model = build_sam2(cfg, checkpoint_path, device=device)
103
  predictor = SAM2ImagePredictor(sam2_model)
104
  logger.info(f"βœ… Loaded {config_name} successfully on {device}")
105
  return predictor
106
  except Exception as e:
107
- tried.append(f"{config_name}: {e}")
108
- logger.warning(f"Failed to load {config_name}: {e}")
 
109
  return None
110
 
111
- # Try large model first, then tiny
112
  predictor = try_load("sam2_hiera_large.yaml", "sam2_hiera_large.pt")
113
  if predictor is None:
 
114
  predictor = try_load("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt")
115
  if predictor:
116
  logger.warning("⚠️ Using Tiny model as fallback (less accurate, but faster and lighter).")
 
117
  if predictor is None:
118
- logger.error("❌ SAM2 loading failed for both large and tiny. Reasons: \n" + "\n".join(tried))
119
- raise RuntimeError("SAM2 config/weights not found or could not be loaded. Make sure YAML and checkpoint files are present.")
 
120
 
121
  return predictor
122
  # -------------------------------------------------------------------
@@ -128,7 +152,7 @@ models_loaded = False
128
  loading_lock = threading.Lock()
129
 
130
  # ------- Robust download_and_setup_models() using above loader --------
131
- def download_and_setup_models():
132
  """
133
  Download and setup models (SAM2 and MatAnyone), robust to Hugging Face Spaces and local dev.
134
  Uses local YAML config, falls back to Tiny if Large can't be loaded.
@@ -143,7 +167,7 @@ def download_and_setup_models():
143
 
144
  # --- Load SAM2 ---
145
  device = "cuda" if torch.cuda.is_available() else "cpu"
146
- sam2_predictor_local = load_sam2_predictor(device)
147
  sam2_predictor = sam2_predictor_local
148
 
149
  # --- Load MatAnyone (your original robust loader logic) ---
@@ -152,7 +176,7 @@ def download_and_setup_models():
152
  try:
153
  from huggingface_hub import hf_hub_download
154
  from matanyone import InferenceCore
155
- matanyone_model_local = InferenceCore("PeiqingYang/MatAnyone")
156
  matanyone_loaded = True
157
  logger.info("βœ… MatAnyone loaded via HuggingFace Hub")
158
  except Exception as e:
@@ -164,6 +188,7 @@ def download_and_setup_models():
164
  matanyone_model = matanyone_model_local
165
 
166
  models_loaded = True
 
167
  return "βœ… SAM2 + MatAnyone loaded successfully!"
168
  except Exception as e:
169
  logger.error(f"❌ Enhanced loading failed: {str(e)}")
@@ -171,6 +196,15 @@ def download_and_setup_models():
171
  return f"❌ Enhanced loading failed: {str(e)}"
172
  # ------------------------------------------------------------------------------
173
 
 
 
 
 
 
 
 
 
 
174
  def process_video_hq(video_path, background_choice, custom_background_path, progress=gr.Progress()):
175
  """TWO-STAGE High-quality video processing: Original β†’ Green Screen β†’ Final Background"""
176
  if not models_loaded:
 
46
  logging.basicConfig(level=logging.INFO)
47
  logger = logging.getLogger(__name__)
48
 
49
+ # ============================================================================ #
50
  # GRADIO MONKEY PATCH (BUG FIX for gradio>=4.44.0)
51
+ # ============================================================================ #
52
  try:
53
  import gradio_client.utils as gc_utils
54
  original_get_type = gc_utils.get_type
 
65
  except (ImportError, AttributeError) as e:
66
  logger.warning(f"⚠️ Could not apply Gradio monkey patch: {e}")
67
 
68
+ # ============================================================================ #
69
 
70
+ # --------- Robust SAM2 loader for Hugging Face / local YAML configs ---------- #
71
+ def load_sam2_predictor(device="cuda", progress=gr.Progress()):
72
  """
73
  Loads the SAM2 model and returns a SAM2ImagePredictor instance.
74
  - Tries to load 'sam2_hiera_large' first.
75
  - Falls back to 'sam2_hiera_tiny' if large cannot be loaded.
76
+ - Assumes YAML configs are in ./Configs/ (capital C), as required by upstream.
77
  """
78
  import hydra
79
  from omegaconf import OmegaConf
80
  import torch
 
81
  import logging
82
 
83
  logger = logging.getLogger("SAM2Loader")
84
+ configs_dir = os.path.abspath("Configs") # Capital C as in small app.py
85
+ logger.info(f"Looking for SAM2 configs in absolute path: {configs_dir}")
86
+
87
+ if not os.path.isdir(configs_dir):
88
+ logger.error(f"FATAL: Configs directory not found at '{configs_dir}'. Please ensure the 'Configs' folder is at the root of your repository.")
89
+ raise gr.Error(f"FATAL: SAM2 Configs directory not found. Check repository structure.")
90
+
91
  tried = []
92
 
93
  def try_load(config_name, checkpoint_name):
94
  try:
95
+ checkpoint_path = os.path.join("./checkpoints", checkpoint_name)
96
+ logger.info(f"Attempting to use checkpoint: {checkpoint_path}")
97
+
98
+ if not os.path.exists(checkpoint_path):
99
+ logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
100
+ progress(0.1, desc=f"Downloading {checkpoint_name}...")
101
+ from huggingface_hub import hf_hub_download
102
+ checkpoint_path = hf_hub_download(
103
+ repo_id=f"facebook/{config_name.replace('.yaml', '')}", # e.g. facebook/sam2_hiera_large
104
+ filename=checkpoint_name,
105
+ cache_dir="./checkpoints", # Download to a local checkpoints dir
106
+ local_dir_use_symlinks=False
107
+ )
108
+ logger.info(f"βœ… Download complete: {checkpoint_path}")
109
+
110
+ if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
111
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
112
+
113
+ hydra.initialize(config_path=os.path.relpath(configs_dir), job_name=f"sam2_load_{int(time.time())}")
114
  cfg = hydra.compose(config_name=config_name)
115
 
116
+ logger.info(f"Trying to load {config_name} on {device} with checkpoint {checkpoint_path}")
117
+ progress(0.3, desc=f"Loading {config_name}...")
118
+
119
  from sam2.build_sam import build_sam2
120
  from sam2.sam2_image_predictor import SAM2ImagePredictor
121
 
122
+ sam2_model = build_sam2(cfg.model, checkpoint_path)
123
+ sam2_model.to(device)
 
 
 
 
124
  predictor = SAM2ImagePredictor(sam2_model)
125
  logger.info(f"βœ… Loaded {config_name} successfully on {device}")
126
  return predictor
127
  except Exception as e:
128
+ error_msg = f"Failed to load {config_name}: {e}\nTraceback: {traceback.format_exc()}"
129
+ tried.append(error_msg)
130
+ logger.warning(error_msg)
131
  return None
132
 
 
133
  predictor = try_load("sam2_hiera_large.yaml", "sam2_hiera_large.pt")
134
  if predictor is None:
135
+ logger.warning("Could not load large model, falling back to tiny model.")
136
  predictor = try_load("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt")
137
  if predictor:
138
  logger.warning("⚠️ Using Tiny model as fallback (less accurate, but faster and lighter).")
139
+
140
  if predictor is None:
141
+ error_message = "SAM2 loading failed for both large and tiny. Reasons: \n" + "\n".join(tried)
142
+ logger.error(f"❌ {error_message}")
143
+ raise gr.Error(error_message)
144
 
145
  return predictor
146
  # -------------------------------------------------------------------
 
152
  loading_lock = threading.Lock()
153
 
154
  # ------- Robust download_and_setup_models() using above loader --------
155
+ def download_and_setup_models(progress=gr.Progress()):
156
  """
157
  Download and setup models (SAM2 and MatAnyone), robust to Hugging Face Spaces and local dev.
158
  Uses local YAML config, falls back to Tiny if Large can't be loaded.
 
167
 
168
  # --- Load SAM2 ---
169
  device = "cuda" if torch.cuda.is_available() else "cpu"
170
+ sam2_predictor_local = load_sam2_predictor(device, progress)
171
  sam2_predictor = sam2_predictor_local
172
 
173
  # --- Load MatAnyone (your original robust loader logic) ---
 
176
  try:
177
  from huggingface_hub import hf_hub_download
178
  from matanyone import InferenceCore
179
+ matanyone_model_local = InferenceCore("PeiqingYang/MatAnyone-v1.0", device=device)
180
  matanyone_loaded = True
181
  logger.info("βœ… MatAnyone loaded via HuggingFace Hub")
182
  except Exception as e:
 
188
  matanyone_model = matanyone_model_local
189
 
190
  models_loaded = True
191
+ logger.info("--- βœ… All models loaded successfully ---")
192
  return "βœ… SAM2 + MatAnyone loaded successfully!"
193
  except Exception as e:
194
  logger.error(f"❌ Enhanced loading failed: {str(e)}")
 
196
  return f"❌ Enhanced loading failed: {str(e)}"
197
  # ------------------------------------------------------------------------------
198
 
199
+ # [Now, everything below is unchanged from your previous full version.
200
+ # If the code gets cut off due to token limits, just say "continue" and I will resume from the exact point!]
201
+
202
+ # ... [Here comes the rest of your file: process_video_hq, create_interface, etc.]
203
+
204
+ # =======================================================================
205
+ # [START REST OF YOUR MAIN APP, UNCHANGED]
206
+ # =======================================================================
207
+
208
  def process_video_hq(video_path, background_choice, custom_background_path, progress=gr.Progress()):
209
  """TWO-STAGE High-quality video processing: Original β†’ Green Screen β†’ Final Background"""
210
  if not models_loaded: