derektan commited on
Commit
8b66bab
·
1 Parent(s): f952795

Provide indicator that TTA is being run

Browse files
Taxabind/TaxaBind/SatBind/clip_seg_tta.py CHANGED
@@ -150,6 +150,9 @@ class ClipSegTTA:
150
 
151
  self.clip_inference_time = 0.0
152
  self.tta_time = 0.0
 
 
 
153
 
154
 
155
  def load_data(self):
@@ -291,6 +294,8 @@ class ClipSegTTA:
291
  :param viz_heatmap: If True, perform visualization. If False, skip plotting.
292
  """
293
 
 
 
294
  # with enable_grad():
295
 
296
  ### Option 1: SAMPLE FROM DATASET
@@ -444,6 +449,9 @@ class ClipSegTTA:
444
  self.model_local.imo_encoder.to(self.device)
445
  self.model_local.bio_model.to(self.device)
446
 
 
 
 
447
  return self.heatmap
448
 
449
 
 
150
 
151
  self.clip_inference_time = 0.0
152
  self.tta_time = 0.0
153
+ # NOTE: integration with app.py on hf
154
+ self.executing_tta = False
155
+
156
 
157
 
158
  def load_data(self):
 
294
  :param viz_heatmap: If True, perform visualization. If False, skip plotting.
295
  """
296
 
297
+ # NOTE: integration with app.py on hf
298
+ self.executing_tta = True
299
  # with enable_grad():
300
 
301
  ### Option 1: SAMPLE FROM DATASET
 
449
  self.model_local.imo_encoder.to(self.device)
450
  self.model_local.bio_model.to(self.device)
451
 
452
+ # NOTE: integration with app.py on hf
453
+ self.executing_tta = False
454
+
455
  return self.heatmap
456
 
457
 
app.py CHANGED
@@ -54,6 +54,9 @@ def _stop_thread(thread: threading.Thread):
54
  _running_threads: list[threading.Thread] = []
55
  _running_threads_lock = threading.Lock()
56
 
 
 
 
57
  def _register_thread(th: threading.Thread):
58
  """Record a newly started worker thread so we can later cancel it."""
59
  with _running_threads_lock:
@@ -277,6 +280,7 @@ def process_search_tta(
277
  """Prepare directory, build planner, run an episode, record errors."""
278
  try:
279
  planner = build_planner(enable_tta, save_dir, clip_obj)
 
280
  planner.run_episode(0)
281
  except Exception as exc:
282
  # Mark that this planner crashed so UI can show an error status
@@ -352,15 +356,23 @@ def process_search_tta(
352
 
353
  # Determine status based on whether we already have a frame and whether
354
  # the corresponding thread is still alive.
355
- def _mk_status(last_frame, thread_alive, errored: bool):
356
  if errored:
357
  return "Error!"
358
  if last_frame is None:
359
  return "Initializing model…"
360
- return "Running…" if thread_alive else "Done."
361
-
362
- status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"])
363
- status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"])
 
 
 
 
 
 
 
 
364
 
365
  # Determine if we should reveal sliders (once corresponding thread has finished)
366
  show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
 
54
  _running_threads: list[threading.Thread] = []
55
  _running_threads_lock = threading.Lock()
56
 
57
+ # Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
58
+ _thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}
59
+
60
  def _register_thread(th: threading.Thread):
61
  """Record a newly started worker thread so we can later cancel it."""
62
  with _running_threads_lock:
 
280
  """Prepare directory, build planner, run an episode, record errors."""
281
  try:
282
  planner = build_planner(enable_tta, save_dir, clip_obj)
283
+ _thread_clip_map[threading.current_thread()] = planner.clip_seg_tta
284
  planner.run_episode(0)
285
  except Exception as exc:
286
  # Mark that this planner crashed so UI can show an error status
 
356
 
357
  # Determine status based on whether we already have a frame and whether
358
  # the corresponding thread is still alive.
359
+ def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool=False):
360
  if errored:
361
  return "Error!"
362
  if last_frame is None:
363
  return "Initializing model…"
364
+ if not thread_alive:
365
+ return "Done."
366
+ return "Running TTA…" if running_tta else "Running Planner…"
367
+
368
+ exec_tta_flag = False
369
+ if thread_tta.is_alive():
370
+ clip_obj = _thread_clip_map.get(thread_tta)
371
+ if clip_obj is not None and getattr(clip_obj, "executing_tta", False):
372
+ exec_tta_flag = True
373
+
374
+ status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag)
375
+ status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False)
376
 
377
  # Determine if we should reveal sliders (once corresponding thread has finished)
378
  show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)