#!/usr/bin/env python
"""
CourtSide-CV - Tennis Analysis Space
Hugging Face Gradio App
"""

import logging
import os
import subprocess
import tempfile

import cv2
import gradio as gr
import numpy as np
from scipy import interpolate
from ultralytics import YOLO

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class BallTrackerLinkedIn:
    """Tracker optimized for tennis ball detection."""

    def __init__(self, model_path):
        self.ball_model = YOLO(model_path)
        self.tracks = {}
        self.frame_idx = 0
        self.all_positions = []
        self.conf_thresh = 0.05        # low threshold: the ball is small and fast
        self.smooth_window = 5         # median-filter window, in frames
        self.max_interpolate_gap = 30  # largest gap (frames) filled by interpolation

    def process_batch(self, frames, progress_callback=None):
        """Run ball detection and tracking over a batch of frames."""
        positions = []
        for i, frame in enumerate(frames):
            if progress_callback:
                progress_callback((i + 1) / len(frames),
                                  desc=f"Detecting ball... {i + 1}/{len(frames)}")
            self.frame_idx = i
            # NOTE: classes=[0] assumes a custom ball-detection model where the
            # ball is class 0. With stock COCO weights (yolov8m.pt), class 0 is
            # "person" and "sports ball" is class 32.
            results = self.ball_model.track(
                source=frame,
                conf=self.conf_thresh,
                classes=[0],
                imgsz=640,
                iou=0.5,
                persist=True,
                tracker='bytetrack.yaml',  # ByteTrack, as stated in the UI notes
                verbose=False
            )

            ball_pos = None
            if results[0].boxes is not None and len(results[0].boxes) > 0:
                # Keep only the highest-confidence detection in each frame
                best_idx = results[0].boxes.conf.argmax()
                x1, y1, x2, y2 = results[0].boxes.xyxy[best_idx].tolist()
                cx = (x1 + x2) / 2
                cy = (y1 + y2) / 2
                conf = float(results[0].boxes.conf[best_idx])
                ball_pos = (cx, cy, conf)
            positions.append((i, ball_pos))
        return positions

    def interpolate_missing(self, positions):
        """Linearly interpolate ball positions for frames with no detection."""
        detected_frames = []
        detected_x = []
        detected_y = []
        for frame_idx, pos in positions:
            if pos is not None:
                detected_frames.append(frame_idx)
                detected_x.append(pos[0])
                detected_y.append(pos[1])

        if len(detected_frames) < 2:
            return positions

        fx = interpolate.interp1d(detected_frames, detected_x,
                                  kind='linear', fill_value='extrapolate')
        fy = interpolate.interp1d(detected_frames, detected_y,
                                  kind='linear', fill_value='extrapolate')

        interpolated = []
        for frame_idx, pos in positions:
            if pos is None:
                prev_detected = max((f for f in detected_frames if f < frame_idx),
                                    default=float('-inf'))
                next_detected = min((f for f in detected_frames if f > frame_idx),
                                    default=float('inf'))
                # Only fill gaps bounded by detections on both sides within the
                # allowed window; longer gaps stay empty.
                if (frame_idx - prev_detected <= self.max_interpolate_gap and
                        next_detected - frame_idx <= self.max_interpolate_gap):
                    ix = float(fx(frame_idx))
                    iy = float(fy(frame_idx))
                    # conf == 0.0 marks the point as interpolated, not detected
                    interpolated.append((frame_idx, (ix, iy, 0.0)))
                else:
                    interpolated.append((frame_idx, None))
            else:
                interpolated.append((frame_idx, pos))
        return interpolated

    def smooth_trajectory(self, positions):
        """Smooth the trajectory with a sliding median filter."""
        smoothed = []
        for i, (frame_idx, pos) in enumerate(positions):
            if pos is None:
                smoothed.append((frame_idx, None))
                continue
            window_start = max(0, i - self.smooth_window // 2)
            window_end = min(len(positions), i + self.smooth_window // 2 + 1)
            window_x = []
            window_y = []
            for j in range(window_start, window_end):
                if positions[j][1] is not None:
                    window_x.append(positions[j][1][0])
                    window_y.append(positions[j][1][1])
            if window_x:
                smooth_x = np.median(window_x)
                smooth_y = np.median(window_y)
                conf = pos[2] if len(pos) > 2 else 0.0
                smoothed.append((frame_idx, (smooth_x, smooth_y, conf)))
            else:
                smoothed.append((frame_idx, pos))
        return smoothed
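
# Minimal standalone sketch of the tracker (illustration only; not used by the
# app). Assumptions: "clip.mp4" is a hypothetical local file, and model_path
# points at a ball-capable YOLO checkpoint (with stock COCO weights, class 0 is
# "person", so a custom ball model is expected here).
def track_ball_standalone(video_path="clip.mp4", model_path="yolov8m.pt"):
    """Run detection, interpolation and smoothing on a clip; return coverage %."""
    tracker = BallTrackerLinkedIn(model_path)
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    positions = tracker.process_batch(frames)
    positions = tracker.interpolate_missing(positions)
    positions = tracker.smooth_trajectory(positions)
    detected = sum(1 for _, p in positions if p is not None)
    return 100.0 * detected / max(len(positions), 1)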

class VideoProcessorLinkedIn:
    """Video processor backing the Gradio app."""

    def __init__(self, ball_model_path):
        self.tracker = BallTrackerLinkedIn(ball_model_path)
        self.person_model = YOLO('yolov8m.pt')  # loaded but currently unused
        self.pose_model = YOLO('yolov8m-pose.pt')
        # COCO keypoint indices: 0=nose, 1/2=eyes, 3/4=ears, 5/6=shoulders,
        # 7/8=elbows, 9/10=wrists, 11/12=hips, 13/14=knees, 15/16=ankles.
        self.skeleton_connections = [
            (5, 6),                          # shoulders
            (5, 7), (7, 9),                  # left arm
            (6, 8), (8, 10),                 # right arm
            (5, 11), (6, 12), (11, 12),      # torso
            (11, 13), (13, 15),              # left leg
            (12, 14), (14, 16),              # right leg
            (0, 1), (0, 2), (1, 3), (2, 4)   # head
        ]

    def draw_skeleton(self, frame, keypoints, conf_threshold=0.5):
        """Draw the pose skeleton (bones, then joints) on the frame."""
        joint_color = (0, 255, 0)
        bone_color = (0, 255, 255)

        for connection in self.skeleton_connections:
            kp1_idx, kp2_idx = connection
            if kp1_idx < len(keypoints) and kp2_idx < len(keypoints):
                kp1 = keypoints[kp1_idx]
                kp2 = keypoints[kp2_idx]
                if len(kp1) > 2 and len(kp2) > 2:
                    if kp1[2] > conf_threshold and kp2[2] > conf_threshold:
                        pt1 = (int(kp1[0]), int(kp1[1]))
                        pt2 = (int(kp2[0]), int(kp2[1]))
                        cv2.line(frame, pt1, pt2, bone_color, 2, cv2.LINE_AA)

        for keypoint in keypoints:
            if len(keypoint) > 2 and keypoint[2] > conf_threshold:
                x, y = int(keypoint[0]), int(keypoint[1])
                cv2.circle(frame, (x, y), 4, joint_color, -1, cv2.LINE_AA)
                cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)
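
    # Hedged sketch: the UI collects player names, but process_video below does
    # not yet render them. This optional helper (an addition, not part of the
    # original pipeline) shows one way to overlay them; call it on `annotated`
    # inside the render loop if you want the labels drawn.
    def draw_player_labels(self, frame, player1_name, player2_name):
        """Draw player name tags in the top corners of the frame."""
        w = frame.shape[1]
        cv2.putText(frame, player1_name, (15, 35),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
        (text_w, _), _ = cv2.getTextSize(player2_name, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
        cv2.putText(frame, player2_name, (w - text_w - 15, 35),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)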

    def process_video(self, video_path, player1_name="PLAYER 1",
                      player2_name="PLAYER 2", max_duration=30,
                      progress=gr.Progress(track_tqdm=True)):
        """Process the video and return the annotated version.

        Note: player1_name / player2_name are accepted from the UI but not yet
        rendered on the output (see draw_player_labels above).
        """
        if video_path is None:
            return None, "❌ Please upload a video"

        try:
            logger.info(f"Processing video: {video_path}")

            # Open the video
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                return None, "❌ Could not open the video"

            fps = int(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Cap the processed duration
            max_frames = min(total_frames, int(fps * max_duration))

            # Read all frames into memory
            progress(0, desc="Loading video...")
            frames = []
            for i in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
            cap.release()

            if len(frames) == 0:
                return None, "❌ No frames could be read"

            logger.info(f"Loaded {len(frames)} frames ({width}x{height} @ {fps}fps)")

            # Phase 1: ball tracking (mapped into the 0.1-0.4 progress range)
            progress(0.1, desc="Tracking ball...")
            positions = self.tracker.process_batch(
                frames,
                progress_callback=lambda frac, desc="": progress(0.1 + 0.3 * frac,
                                                                 desc=desc)
            )

            # Phase 2: interpolation
            progress(0.4, desc="Interpolating missing positions...")
            positions = self.tracker.interpolate_missing(positions)

            # Phase 3: smoothing
            progress(0.5, desc="Smoothing trajectory...")
            positions = self.tracker.smooth_trajectory(positions)

            # Stats
            detected = sum(1 for _, p in positions if p is not None)
            coverage = (detected / len(positions)) * 100

            # Phase 4: video rendering
            progress(0.6, desc="Rendering annotated video...")

            # Create a temporary output file
            temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

            trail_length = 15
            trail_positions = []

            for frame_idx, (_, ball_pos) in enumerate(positions):
                progress(0.6 + 0.3 * (frame_idx / len(frames)),
                         desc=f"Rendering... {frame_idx + 1}/{len(frames)}")
                annotated = frames[frame_idx].copy()

                # Draw the ball and its trail
                if ball_pos is not None:
                    x, y, conf = ball_pos
                    trail_positions.append((int(x), int(y)))
                    if len(trail_positions) > trail_length:
                        trail_positions.pop(0)

                    # Trail, thickening toward the current position
                    for i in range(1, len(trail_positions)):
                        alpha = i / len(trail_positions)
                        thickness = int(2 + alpha * 2)
                        cv2.line(annotated, trail_positions[i - 1], trail_positions[i],
                                 (0, 255, 255), thickness, cv2.LINE_AA)

                    # Ball
                    radius = 8
                    cv2.circle(annotated, (int(x), int(y)), radius + 3,
                               (0, 255, 255), -1, cv2.LINE_AA)
                    cv2.circle(annotated, (int(x), int(y)), radius,
                               (0, 255, 0), -1, cv2.LINE_AA)
                    cv2.circle(annotated, (int(x), int(y)), radius,
                               (255, 255, 255), 2, cv2.LINE_AA)

                # Pose detection (up to two players per frame)
                pose_results = self.pose_model(frames[frame_idx], conf=0.3, verbose=False)
                if pose_results[0].keypoints is not None:
                    for keypoints in pose_results[0].keypoints.data[:2]:
                        keypoints_np = keypoints.cpu().numpy()
                        keypoints_with_conf = [[kp[0], kp[1], kp[2]] for kp in keypoints_np]
                        self.draw_skeleton(annotated, keypoints_with_conf,
                                           conf_threshold=0.3)

                # Bottom overlay banner
                cv2.rectangle(annotated, (0, height - 45), (width, height),
                              (0, 0, 0), -1)
                cv2.putText(annotated, "CourtSide-CV", (15, height - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2, cv2.LINE_AA)
                if ball_pos is not None:
                    # Interpolated points carry conf == 0.0, so anything above
                    # the 0.1 floor is treated as a real detection
                    status = "TRACKING" if conf > 0.1 else "PREDICTED"
                    color_status = (0, 255, 255) if conf > 0.1 else (255, 200, 0)
                    cv2.putText(annotated, f"Ball: {status}",
                                (width // 2 - 60, height - 15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color_status, 2,
                                cv2.LINE_AA)

                out.write(annotated)

            out.release()

            # Final H.264 re-encode for browser compatibility (falls back to
            # the raw mp4v file if ffmpeg is unavailable or fails)
            progress(0.95, desc="Finalizing video...")
            final_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
            cmd = [
                'ffmpeg', '-i', temp_output,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '22',
                '-pix_fmt', 'yuv420p', '-movflags', '+faststart',
                final_output, '-y', '-loglevel', 'error'
            ]
            try:
                subprocess.run(cmd, check=True)
                os.remove(temp_output)
            except (FileNotFoundError, subprocess.CalledProcessError):
                logger.warning("ffmpeg re-encode failed; returning mp4v output")
                final_output = temp_output

            message = f"""
✅ **Video processed successfully!**

📊 **Statistics:**
- Frames processed: {len(frames)}
- Ball coverage: {coverage:.1f}%
- Resolution: {width}x{height}
- FPS: {fps}

🎾 Ready for LinkedIn!
"""
            logger.info(f"✅ Processing complete: {final_output}")
            return final_output, message

        except Exception as e:
            logger.error(f"Error processing video: {e}", exc_info=True)
            return None, f"❌ Error: {str(e)}"


# Global processor (initialized lazily)
processor = None


def get_processor():
    """Initialize the processor lazily, on first use."""
    global processor
    if processor is None:
        logger.info("Initializing processor...")
        # Download / cache the models
        logger.info("Downloading YOLO models...")
        _ = YOLO('yolov8m.pt')
        _ = YOLO('yolov8m-pose.pt')
        logger.info("✅ Models ready!")
        # NOTE: this points the ball tracker at the generic COCO model; swap in
        # a custom ball-detection checkpoint for real ball tracking (see the
        # note in BallTrackerLinkedIn.process_batch).
        ball_model_path = 'yolov8m.pt'
        processor = VideoProcessorLinkedIn(ball_model_path)
    return processor


# Gradio interface
def process_video_gradio(video, player1, player2, max_duration,
                         progress=gr.Progress(track_tqdm=True)):
    """Gradio wrapper around the processor."""
    proc = get_processor()
    return proc.process_video(video, player1, player2, max_duration, progress)
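
# A minimal headless sketch (an addition for illustration; "clip.mp4" is a
# hypothetical local file). It bypasses the UI entirely and disables progress
# reporting with a no-op callback.
def run_headless(video_path="clip.mp4", max_duration=30):
    """Process one clip without Gradio; returns (output_path, message)."""
    proc = get_processor()
    return proc.process_video(video_path, max_duration=max_duration,
                              progress=lambda *args, **kwargs: None)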

# Build the interface
with gr.Blocks(title="🎾 CourtSide-CV - Tennis Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎾 CourtSide-CV - Tennis Analysis

    Analyze your tennis matches with AI!

    This application uses computer vision to:
    - 🎯 **Track the ball** in real time with smart interpolation
    - 🤸 **Detect player poses** with skeleton visualization
    - 📊 **Analyze trajectories** with advanced smoothing

    ---
    """)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="📹 Upload your tennis video")
            with gr.Row():
                player1_input = gr.Textbox(
                    label="👤 Player 1 name (left)",
                    value="PLAYER 1",
                    max_lines=1
                )
                player2_input = gr.Textbox(
                    label="👤 Player 2 name (right)",
                    value="PLAYER 2",
                    max_lines=1
                )
            max_duration_input = gr.Slider(
                minimum=5,
                maximum=60,
                value=30,
                step=5,
                label="⏱️ Maximum duration (seconds)",
                info="For performance reasons, keep the clip short"
            )
            submit_btn = gr.Button("🚀 Analyze video", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="🎬 Annotated video")
            status_output = gr.Markdown(label="📊 Results")

    gr.Markdown("""
    ---
    ### 💡 Tips
    - Use **good-quality** footage for best results
    - The **ball should be visible** in most frames
    - For better performance, keep clips under **30 seconds**

    ### 🔧 Technologies
    - **YOLOv8** for object and pose detection
    - **ByteTrack** for object tracking
    - **OpenCV** for video processing
    - **SciPy** for interpolation

    ---
    Made with ❤️ by CourtSide-CV
    """)

    submit_btn.click(
        fn=process_video_gradio,
        inputs=[video_input, player1_input, player2_input, max_duration_input],
        outputs=[video_output, status_output]
    )


# Launch the app
if __name__ == "__main__":
    demo.launch()
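
# Deployment notes (assumptions, not part of the original app): on Hugging Face
# Spaces the plain launch() above is sufficient. For heavier traffic you could
# enable request queuing with demo.queue().launch(), and for LAN access when
# running locally, demo.launch(server_name="0.0.0.0", server_port=7860).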