Davidsv commited on
Commit
1724828
·
verified ·
1 Parent(s): 6827f69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -34
app.py CHANGED
@@ -32,13 +32,13 @@ class BallTrackerLinkedIn:
32
  self.smooth_window = 5
33
  self.max_interpolate_gap = 30
34
 
35
- def process_batch(self, frames, progress=gr.Progress()):
36
  """Process un batch de frames pour le tracking"""
37
  positions = []
38
 
39
  for i, frame in enumerate(frames):
40
- if progress:
41
- progress((i + 1) / len(frames), desc=f"Detecting ball... {i+1}/{len(frames)}")
42
 
43
  self.frame_idx = i
44
  results = self.ball_model.track(
@@ -168,7 +168,7 @@ class VideoProcessorLinkedIn:
168
  cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)
169
 
170
  def process_video(self, video_path, player1_name="PLAYER 1", player2_name="PLAYER 2",
171
- max_duration=30, progress=gr.Progress()):
172
  """Traiter la vidéo et retourner la version annotée"""
173
 
174
  if video_path is None:
@@ -207,7 +207,7 @@ class VideoProcessorLinkedIn:
207
 
208
  # Phase 1: Tracking de la balle
209
  progress(0.1, desc="Tracking ball...")
210
- positions = self.tracker.process_batch(frames, progress)
211
 
212
  # Phase 2: Interpolation
213
  progress(0.4, desc="Interpolating missing positions...")
@@ -233,9 +233,8 @@ class VideoProcessorLinkedIn:
233
  trail_positions = []
234
 
235
  for frame_idx, (_, ball_pos) in enumerate(positions):
236
- if progress:
237
- progress(0.6 + 0.3 * (frame_idx / len(frames)),
238
- desc=f"Rendering... {frame_idx+1}/{len(frames)}")
239
 
240
  annotated = frames[frame_idx].copy()
241
 
@@ -318,32 +317,30 @@ class VideoProcessorLinkedIn:
318
  return None, f"❌ Erreur: {str(e)}"
319
 
320
 
321
- # Télécharger les modèles au démarrage
322
- def download_models():
323
- """Télécharge les modèles YOLO nécessaires"""
324
- logger.info("Downloading YOLO models...")
325
-
326
- # Modèles de base (téléchargés automatiquement par ultralytics)
327
- _ = YOLO('yolov8m.pt')
328
- _ = YOLO('yolov8m-pose.pt')
329
-
330
- logger.info("✅ Models ready!")
331
-
332
 
333
- # Initialiser le processeur
334
- logger.info("Initializing app...")
335
- download_models()
336
-
337
- # Note: Pour le modèle de balle de tennis personnalisé, vous devrez l'uploader
338
- # Pour l'instant, on utilise le modèle YOLO standard
339
- ball_model_path = 'yolov8m.pt' # Remplacer par le chemin de votre modèle custom
340
- processor = VideoProcessorLinkedIn(ball_model_path)
 
 
 
 
 
 
341
 
342
 
343
  # Interface Gradio
344
- def process_video_gradio(video, player1, player2, max_duration):
345
  """Wrapper pour Gradio"""
346
- return processor.process_video(video, player1, player2, max_duration)
 
347
 
348
 
349
  # Créer l'interface
@@ -415,8 +412,4 @@ with gr.Blocks(title="🎾 CourtSide-CV - Tennis Analysis", theme=gr.themes.Soft
415
 
416
  # Lancer l'application
417
  if __name__ == "__main__":
418
- demo.launch(
419
- server_name="0.0.0.0", # Écoute sur toutes les interfaces
420
- server_port=7860, # Port par défaut HF Spaces
421
- share=False # Pas besoin de share sur HF
422
- )
 
32
  self.smooth_window = 5
33
  self.max_interpolate_gap = 30
34
 
35
+ def process_batch(self, frames, progress_callback=None):
36
  """Process un batch de frames pour le tracking"""
37
  positions = []
38
 
39
  for i, frame in enumerate(frames):
40
+ if progress_callback:
41
+ progress_callback((i + 1) / len(frames), desc=f"Detecting ball... {i+1}/{len(frames)}")
42
 
43
  self.frame_idx = i
44
  results = self.ball_model.track(
 
168
  cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)
169
 
170
  def process_video(self, video_path, player1_name="PLAYER 1", player2_name="PLAYER 2",
171
+ max_duration=30, progress=gr.Progress(track_tqdm=True)):
172
  """Traiter la vidéo et retourner la version annotée"""
173
 
174
  if video_path is None:
 
207
 
208
  # Phase 1: Tracking de la balle
209
  progress(0.1, desc="Tracking ball...")
210
+ positions = self.tracker.process_batch(frames, progress_callback=progress)
211
 
212
  # Phase 2: Interpolation
213
  progress(0.4, desc="Interpolating missing positions...")
 
233
  trail_positions = []
234
 
235
  for frame_idx, (_, ball_pos) in enumerate(positions):
236
+ progress(0.6 + 0.3 * (frame_idx / len(frames)),
237
+ desc=f"Rendering... {frame_idx+1}/{len(frames)}")
 
238
 
239
  annotated = frames[frame_idx].copy()
240
 
 
317
  return None, f"❌ Erreur: {str(e)}"
318
 
319
 
320
+ # Variable globale pour le processeur (initialisé paresseusement)
321
+ processor = None
 
 
 
 
 
 
 
 
 
322
 
323
+ def get_processor():
324
+ """Initialise le processeur de manière paresseuse"""
325
+ global processor
326
+ if processor is None:
327
+ logger.info("Initializing processor...")
328
+ # Télécharger les modèles
329
+ logger.info("Downloading YOLO models...")
330
+ _ = YOLO('yolov8m.pt')
331
+ _ = YOLO('yolov8m-pose.pt')
332
+ logger.info("✅ Models ready!")
333
+
334
+ ball_model_path = 'yolov8m.pt'
335
+ processor = VideoProcessorLinkedIn(ball_model_path)
336
+ return processor
337
 
338
 
339
  # Interface Gradio
340
+ def process_video_gradio(video, player1, player2, max_duration, progress=gr.Progress(track_tqdm=True)):
341
  """Wrapper pour Gradio"""
342
+ proc = get_processor()
343
+ return proc.process_video(video, player1, player2, max_duration, progress)
344
 
345
 
346
  # Créer l'interface
 
412
 
413
  # Lancer l'application
414
  if __name__ == "__main__":
415
+ demo.launch()