Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
·
8b66bab
1
Parent(s):
f952795
Provide indicator that TTA is being run
Browse files- Taxabind/TaxaBind/SatBind/clip_seg_tta.py +8 -0
- app.py +17 -5
Taxabind/TaxaBind/SatBind/clip_seg_tta.py
CHANGED
@@ -150,6 +150,9 @@ class ClipSegTTA:
|
|
150 |
|
151 |
self.clip_inference_time = 0.0
|
152 |
self.tta_time = 0.0
|
|
|
|
|
|
|
153 |
|
154 |
|
155 |
def load_data(self):
|
@@ -291,6 +294,8 @@ class ClipSegTTA:
|
|
291 |
:param viz_heatmap: If True, perform visualization. If False, skip plotting.
|
292 |
"""
|
293 |
|
|
|
|
|
294 |
# with enable_grad():
|
295 |
|
296 |
### Option 1: SAMPLE FROM DATASET
|
@@ -444,6 +449,9 @@ class ClipSegTTA:
|
|
444 |
self.model_local.imo_encoder.to(self.device)
|
445 |
self.model_local.bio_model.to(self.device)
|
446 |
|
|
|
|
|
|
|
447 |
return self.heatmap
|
448 |
|
449 |
|
|
|
150 |
|
151 |
self.clip_inference_time = 0.0
|
152 |
self.tta_time = 0.0
|
153 |
+
# NOTE: integration with app.py on hf
|
154 |
+
self.executing_tta = False
|
155 |
+
|
156 |
|
157 |
|
158 |
def load_data(self):
|
|
|
294 |
:param viz_heatmap: If True, perform visualization. If False, skip plotting.
|
295 |
"""
|
296 |
|
297 |
+
# NOTE: integration with app.py on hf
|
298 |
+
self.executing_tta = True
|
299 |
# with enable_grad():
|
300 |
|
301 |
### Option 1: SAMPLE FROM DATASET
|
|
|
449 |
self.model_local.imo_encoder.to(self.device)
|
450 |
self.model_local.bio_model.to(self.device)
|
451 |
|
452 |
+
# NOTE: integration with app.py on hf
|
453 |
+
self.executing_tta = False
|
454 |
+
|
455 |
return self.heatmap
|
456 |
|
457 |
|
app.py
CHANGED
@@ -54,6 +54,9 @@ def _stop_thread(thread: threading.Thread):
|
|
54 |
_running_threads: list[threading.Thread] = []
|
55 |
_running_threads_lock = threading.Lock()
|
56 |
|
|
|
|
|
|
|
57 |
def _register_thread(th: threading.Thread):
|
58 |
"""Record a newly started worker thread so we can later cancel it."""
|
59 |
with _running_threads_lock:
|
@@ -277,6 +280,7 @@ def process_search_tta(
|
|
277 |
"""Prepare directory, build planner, run an episode, record errors."""
|
278 |
try:
|
279 |
planner = build_planner(enable_tta, save_dir, clip_obj)
|
|
|
280 |
planner.run_episode(0)
|
281 |
except Exception as exc:
|
282 |
# Mark that this planner crashed so UI can show an error status
|
@@ -352,15 +356,23 @@ def process_search_tta(
|
|
352 |
|
353 |
# Determine status based on whether we already have a frame and whether
|
354 |
# the corresponding thread is still alive.
|
355 |
-
def _mk_status(last_frame, thread_alive, errored: bool):
|
356 |
if errored:
|
357 |
return "Error!"
|
358 |
if last_frame is None:
|
359 |
return "Initializing model…"
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
# Determine if we should reveal sliders (once corresponding thread has finished)
|
366 |
show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
|
|
|
54 |
_running_threads: list[threading.Thread] = []
|
55 |
_running_threads_lock = threading.Lock()
|
56 |
|
57 |
+
# Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
|
58 |
+
_thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}
|
59 |
+
|
60 |
def _register_thread(th: threading.Thread):
|
61 |
"""Record a newly started worker thread so we can later cancel it."""
|
62 |
with _running_threads_lock:
|
|
|
280 |
"""Prepare directory, build planner, run an episode, record errors."""
|
281 |
try:
|
282 |
planner = build_planner(enable_tta, save_dir, clip_obj)
|
283 |
+
_thread_clip_map[threading.current_thread()] = planner.clip_seg_tta
|
284 |
planner.run_episode(0)
|
285 |
except Exception as exc:
|
286 |
# Mark that this planner crashed so UI can show an error status
|
|
|
356 |
|
357 |
# Determine status based on whether we already have a frame and whether
|
358 |
# the corresponding thread is still alive.
|
359 |
+
def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool=False):
|
360 |
if errored:
|
361 |
return "Error!"
|
362 |
if last_frame is None:
|
363 |
return "Initializing model…"
|
364 |
+
if not thread_alive:
|
365 |
+
return "Done."
|
366 |
+
return "Running TTA…" if running_tta else "Running Planner…"
|
367 |
+
|
368 |
+
exec_tta_flag = False
|
369 |
+
if thread_tta.is_alive():
|
370 |
+
clip_obj = _thread_clip_map.get(thread_tta)
|
371 |
+
if clip_obj is not None and getattr(clip_obj, "executing_tta", False):
|
372 |
+
exec_tta_flag = True
|
373 |
+
|
374 |
+
status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag)
|
375 |
+
status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False)
|
376 |
|
377 |
# Determine if we should reveal sliders (once corresponding thread has finished)
|
378 |
show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
|