Spaces:
Runtime error
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.dirname(__file__))) | |
| from utils import read_video, save_video | |
| from trackers import Tracker | |
| import cv2 | |
| import numpy as np | |
| from team_assigner import TeamAssigner | |
| from player_ball_assigner import PlayerBallAssigner | |
| from camera_movement_estimator import CameraMovementEstimator | |
| from view_transformer import ViewTransformer | |
| from speed_and_distance_estimator import SpeedAndDistance_Estimator | |
| # RAG + GPT 관련 모듈 | |
| from commentary_ai.generator.frame_analyzer import analyze_frame | |
| from commentary_ai.generator.retriever import load_vector_store, search_similar_sentences | |
| from commentary_ai.generator.openai_captioner import generate_caption | |
def _resolve_path(*parts):
    """Return an absolute path for *parts* relative to this script's directory."""
    # Anchor data/model/output paths at the script dir (as the sys.path setup
    # at the top of the file already does for imports) instead of the CWD.
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), *parts)


def _detect_events(previous_player, assigned_player, previous_team,
                   current_team, player_speed, ball_speed, ball_bbox,
                   goal_area):
    """Return the rule-based event messages for a single frame.

    Args:
        previous_player: id of the last player seen with the ball, or -1 if
            no player has had it yet.
        assigned_player: id of the player assigned the ball in this frame,
            or -1 when no player was assigned.
        previous_team: team credited with the ball before this frame, or None.
        current_team: team credited with the ball in this frame.
        player_speed: speed of ``assigned_player`` (0 when none is assigned).
        ball_speed: speed recorded for the ball (0 when absent).
        ball_bbox: ball bounding box; only indices 0 and 1 are read.
        goal_area: ``((x_min, y_min), (x_max, y_max))`` region tested against
            ``ball_bbox[0]`` / ``ball_bbox[1]`` with strict inequalities.

    Returns:
        A (possibly empty) list of event strings, in the order the rules run.
    """
    events = []
    if previous_player != -1 and assigned_player != previous_player:
        # Possession moved to a different player. While the ball is loose
        # (assigned_player == -1) nothing is reported yet.
        if assigned_player != -1:
            events.append(f"패스 성공! 플레이어 {previous_player} ➡ 플레이어 {assigned_player}")
    elif assigned_player != -1:
        # Same player (or first possessor) keeps the ball: dribble if moving.
        if player_speed > 1.5:
            events.append(f"플레이어 {assigned_player}이 드리블 중입니다.")
    if previous_team is not None and current_team != previous_team:
        events.append("태클 성공! 상대 팀이 볼을 차단했습니다.")
    if ball_speed > 8:
        events.append("슛! 볼이 빠른 속도로 움직입니다.")
    (x_min, y_min), (x_max, y_max) = goal_area
    # NOTE(review): tests the bbox's first two coordinates (likely the
    # top-left corner), not the ball centre — confirm this is intended.
    if x_min < ball_bbox[0] < x_max and y_min < ball_bbox[1] < y_max:
        events.append("골! 볼이 골대에 들어갔습니다!")
    return events


def main():
    """Run the analysis pipeline and write an annotated, commentated video.

    Steps (all file paths are resolved relative to this script's directory):

    1. Load the vector store used to retrieve example commentary sentences.
    2. Read the input video and get object tracks (``read_from_stub=True``).
    3. Apply camera-movement adjustment and the view transformer, interpolate
       ball positions, and add speed/distance to the tracks.
    4. Assign a team and team colour to every player in every frame.
    5. For each frame: assign ball possession, run the rule-based event
       detectors, and every ``SUBTITLE_INTERVAL_SEC`` seconds build a caption
       from ``analyze_frame`` + ``search_similar_sentences`` +
       ``generate_caption`` (other frames repeat the last caption).
    6. Draw annotations and save to ``output_videos/output_video.avi``.

    Raises:
        ValueError: if no frames could be read from the input video.
    """
    FPS = 30  # hard-coded frame rate of the input video (not read from file)
    SUBTITLE_INTERVAL_SEC = 2  # generate a new caption every 2 seconds
    GOAL_AREA = ((100, 50), (200, 100))  # placeholder goal region (example)

    vector_store_path = _resolve_path("commentary_ai", "generator", "vector_store.pkl")
    print("벡터 저장소 절대 경로:", vector_store_path)
    load_vector_store(vector_store_path)

    input_video_path = _resolve_path('input_videos', 'sample_1.mp4')
    video_frames = read_video(input_video_path)
    if not video_frames:
        # Every later stage indexes video_frames[0]; fail with a clear message.
        raise ValueError(f"No frames could be read from {input_video_path}")

    tracker = Tracker(_resolve_path('models', 'best.pt'))
    tracks = tracker.get_object_tracks(
        video_frames,
        read_from_stub=True,
        stub_path=_resolve_path('stubs', 'track_stubs.pkl'),
    )
    tracker.add_positions_to_tracks(tracks)

    camera_movement_estimator = CameraMovementEstimator(video_frames[0])
    camera_movement_per_frame = camera_movement_estimator.get_camera_movement(
        video_frames,
        read_from_stub=True,
        stub_path=_resolve_path('stubs', 'camera_movement_stub.pkl'),
    )
    camera_movement_estimator.add_adjust_positions_to_tracks(tracks, camera_movement_per_frame)

    view_transformer = ViewTransformer()
    view_transformer.add_transformed_position_to_tracks(tracks)

    tracks["ball"] = tracker.interpolate_ball_positions(tracks["ball"])

    speed_and_distance_estimator = SpeedAndDistance_Estimator()
    speed_and_distance_estimator.add_speed_and_distance_to_tracks(tracks)

    team_assigner = TeamAssigner()
    team_assigner.assign_team_color(video_frames[0], tracks['players'][0])
    for frame_num, player_track in enumerate(tracks['players']):
        for player_id, track in player_track.items():
            # `track` is the same dict as tracks['players'][frame_num][player_id].
            team = team_assigner.get_player_team(video_frames[frame_num], track['bbox'], player_id)
            track['team'] = team
            track['team_color'] = team_assigner.team_colors[team]

    player_assigner = PlayerBallAssigner()
    team_ball_control = []
    subtitle_data = []
    event_data = []
    previous_player_with_ball = -1
    previous_team_with_ball = None
    subtitle_every_n_frames = FPS * SUBTITLE_INTERVAL_SEC

    for frame_num, player_track in enumerate(tracks['players']):
        # Ball tracks are keyed by track id 1 (same level as 'bbox'), so the
        # speed must be read there, not from the per-frame dict.
        ball_track = tracks['ball'][frame_num][1]
        ball_bbox = ball_track['bbox']
        ball_speed = ball_track.get('speed', 0)

        assigned_player = player_assigner.assign_ball_to_player(player_track, ball_bbox)
        if assigned_player != -1:
            player = player_track[assigned_player]
            player['has_ball'] = True
            current_team_with_ball = player['team']
            tracker.update_ball_owner(assigned_player, current_team_with_ball)
            player_speed = player.get('speed', 0)
        else:
            # Nobody has the ball: keep crediting the last team in possession.
            current_team_with_ball = previous_team_with_ball
            player_speed = 0
        team_ball_control.append(current_team_with_ball)

        event_texts = _detect_events(
            previous_player_with_ball,
            assigned_player,
            previous_team_with_ball,
            current_team_with_ball,
            player_speed,
            ball_speed,
            ball_bbox,
            GOAL_AREA,
        )
        event_data.append("\n".join(event_texts))

        # Update possession state on EVERY frame, before the subtitle throttle
        # below skips the rest of the iteration. The last possessor is kept
        # while the ball is loose, mirroring how team possession persists, so
        # a pass with in-flight frames is still detected.
        if assigned_player != -1:
            previous_player_with_ball = assigned_player
        previous_team_with_ball = current_team_with_ball

        if frame_num % subtitle_every_n_frames != 0:
            # Between caption frames, repeat the previous caption (or "").
            subtitle_data.append(subtitle_data[-1] if subtitle_data else "")
            continue

        frame_info = {
            "frame_num": frame_num,
            "assigned_player": assigned_player,
            "player_speed": player_speed,
            "ball_speed": ball_speed,
            "team_with_ball": current_team_with_ball,
            "ball_position": ball_bbox,
            "events": event_texts,
        }
        frame_description = analyze_frame(frame_info)
        retrieved_examples = search_similar_sentences(frame_description, top_k=5)
        subtitle_data.append(generate_caption(frame_description, retrieved_examples))

    output_video_frames = tracker.draw_annotations(
        video_frames, tracks, team_ball_control, subtitle_data, event_data
    )
    output_video_path = _resolve_path('output_videos', 'output_video.avi')
    # Make sure the output directory exists before handing the path to save_video.
    os.makedirs(os.path.dirname(output_video_path), exist_ok=True)
    save_video(output_video_frames, output_video_path)
if __name__ == '__main__':
    # Only run the pipeline when executed as a script, not when imported.
    main()