Update app.py
app.py CHANGED
@@ -32,13 +32,13 @@ class BallTrackerLinkedIn:
         self.smooth_window = 5
         self.max_interpolate_gap = 30

-    def process_batch(self, frames,
+    def process_batch(self, frames, progress_callback=None):
         """Process a batch of frames for tracking"""
         positions = []

         for i, frame in enumerate(frames):
-            if
-
+            if progress_callback:
+                progress_callback((i + 1) / len(frames), desc=f"Detecting ball... {i+1}/{len(frames)}")

             self.frame_idx = i
             results = self.ball_model.track(
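Reviewer note: the new progress_callback parameter only assumes "a callable taking a float plus a desc keyword", which a gr.Progress instance happens to satisfy, so process_batch stays usable outside Gradio. A minimal sketch with a print-based stand-in (every name below is illustrative, not taken from app.py):

    def print_progress(fraction, desc=None):
        """Stand-in for gr.Progress: report completion on stdout."""
        print(f"[{fraction:6.1%}] {desc or ''}")

    frames = list(range(5))  # placeholder "frames" instead of real video frames
    for i, _frame in enumerate(frames):
        # mirrors the call added inside process_batch
        print_progress((i + 1) / len(frames), desc=f"Detecting ball... {i+1}/{len(frames)}")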
@@ -168,7 +168,7 @@ class VideoProcessorLinkedIn:
         cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)

     def process_video(self, video_path, player1_name="PLAYER 1", player2_name="PLAYER 2",
-                      max_duration=30, progress=gr.Progress()):
+                      max_duration=30, progress=gr.Progress(track_tqdm=True)):
         """Process the video and return the annotated version"""

         if video_path is None:
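Besides exposing the bar to manual progress(...) calls, track_tqdm=True also lets Gradio mirror any tqdm loop running inside the function onto the same bar, so helpers that already wrap their loops in tqdm report progress with no extra wiring. A hedged, self-contained sketch of that mechanism (assumes tqdm is installed; the function and values are made up):

    import gradio as gr
    from tqdm import tqdm

    def count_slowly(n, progress=gr.Progress(track_tqdm=True)):
        total = 0
        for i in tqdm(range(int(n)), desc="Counting"):  # picked up by Gradio automatically
            total += i
        return total

    demo_sketch = gr.Interface(fn=count_slowly, inputs=gr.Number(value=1000), outputs="number")
    # demo_sketch.launch()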
@@ -207,7 +207,7 @@ class VideoProcessorLinkedIn:

         # Phase 1: Ball tracking
         progress(0.1, desc="Tracking ball...")
-        positions = self.tracker.process_batch(frames, progress)
+        positions = self.tracker.process_batch(frames, progress_callback=progress)

         # Phase 2: Interpolation
         progress(0.4, desc="Interpolating missing positions...")
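The pattern across this hunk and the previous one: declare progress=gr.Progress(track_tqdm=True) in the Gradio-facing signature, then hand the same object down to the tracker as an ordinary keyword argument, since the tracker only relies on it being callable as progress(fraction, desc=...). A small sketch of that wiring with invented names:

    import gradio as gr

    def slow_worker(n_steps, progress_callback=None):
        """Worker that knows nothing about Gradio beyond the callable contract."""
        for i in range(n_steps):
            if progress_callback:
                progress_callback((i + 1) / n_steps, desc=f"Step {i+1}/{n_steps}")
        return "done"

    def run(n_steps, progress=gr.Progress(track_tqdm=True)):
        # Gradio injects a live Progress object when this runs from the UI;
        # the worker receives it as a plain keyword argument.
        return slow_worker(int(n_steps), progress_callback=progress)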
@@ -233,9 +233,8 @@ class VideoProcessorLinkedIn:
         trail_positions = []

         for frame_idx, (_, ball_pos) in enumerate(positions):
-
-
-                     desc=f"Rendering... {frame_idx+1}/{len(frames)}")
+            progress(0.6 + 0.3 * (frame_idx / len(frames)),
+                     desc=f"Rendering... {frame_idx+1}/{len(frames)}")

             annotated = frames[frame_idx].copy()

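The expression 0.6 + 0.3 * (frame_idx / len(frames)) simply maps rendering progress into the 60%..90% band of the overall bar, after tracking and interpolation used the earlier portion. A tiny illustrative helper (not in app.py) that makes the mapping explicit:

    def band_progress(fraction, start=0.6, span=0.3):
        """Map a sub-task fraction in [0, 1] into a band of the overall progress bar."""
        return start + span * fraction

    print(band_progress(0.0))  # 0.6  -> first rendered frame shows 60%
    print(band_progress(1.0))  # ~0.9 -> last rendered frame shows 90%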
@@ -318,32 +317,30 @@ class VideoProcessorLinkedIn:
             return None, f"❌ Erreur: {str(e)}"


-#
-
-    """Download the required YOLO models"""
-    logger.info("Downloading YOLO models...")
-
-    # Base models (downloaded automatically by ultralytics)
-    _ = YOLO('yolov8m.pt')
-    _ = YOLO('yolov8m-pose.pt')
-
-    logger.info("✅ Models ready!")
-
+# Global variable for the processor (lazily initialized)
+processor = None

-
-
-
-
-
-#
-
-
+def get_processor():
+    """Lazily initialize the processor"""
+    global processor
+    if processor is None:
+        logger.info("Initializing processor...")
+        # Download the models
+        logger.info("Downloading YOLO models...")
+        _ = YOLO('yolov8m.pt')
+        _ = YOLO('yolov8m-pose.pt')
+        logger.info("✅ Models ready!")
+
+        ball_model_path = 'yolov8m.pt'
+        processor = VideoProcessorLinkedIn(ball_model_path)
+    return processor


 # Gradio interface
-def process_video_gradio(video, player1, player2, max_duration):
+def process_video_gradio(video, player1, player2, max_duration, progress=gr.Progress(track_tqdm=True)):
     """Wrapper for Gradio"""
-
+    proc = get_processor()
+    return proc.process_video(video, player1, player2, max_duration, progress)


 # Create the interface
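The lazy get_processor() keeps the YOLO downloads and the VideoProcessorLinkedIn construction off the import path, so the Space can start serving the UI before the heavy objects exist and only pays the cost on the first request. The same idea can be written without the global/None dance; a sketch using functools.lru_cache, with a placeholder dict standing in for VideoProcessorLinkedIn:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def get_heavy_resource():
        """Built once, on first call; later calls return the cached object."""
        print("Initializing heavy resource...")  # runs a single time
        return {"model": "placeholder"}          # stand-in for VideoProcessorLinkedIn(...)

    a = get_heavy_resource()
    b = get_heavy_resource()
    assert a is b  # same cached instance, no re-initialization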
@@ -415,8 +412,4 @@ with gr.Blocks(title="🎾 CourtSide-CV - Tennis Analysis", theme=gr.themes.Soft

 # Launch the application
 if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",  # Listen on all interfaces
-        server_port=7860,       # Default HF Spaces port
-        share=False             # No need for share on HF
-    )
+    demo.launch()
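Dropping the explicit server_name/server_port/share arguments should be safe on Spaces: to the best of my knowledge Gradio also honors the GRADIO_SERVER_NAME and GRADIO_SERVER_PORT environment variables, so a bare launch() works both locally and on Hugging Face without code changes. A minimal sketch with a placeholder UI:

    import gradio as gr

    with gr.Blocks(title="placeholder") as demo_sketch:
        gr.Markdown("Placeholder UI")

    if __name__ == "__main__":
        # Host/port can still be overridden without touching the code, e.g.
        #   GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python app.py
        demo_sketch.launch()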