yugangee's picture
add main.py only
99bfd8f
raw
history blame
5.92 kB
import sys
import os
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from utils import read_video, save_video
from trackers import Tracker
import cv2
import numpy as np
from team_assigner import TeamAssigner
from player_ball_assigner import PlayerBallAssigner
from camera_movement_estimator import CameraMovementEstimator
from view_transformer import ViewTransformer
from speed_and_distance_estimator import SpeedAndDistance_Estimator
# RAG + GPT 관련 모듈
from commentary_ai.generator.frame_analyzer import analyze_frame
from commentary_ai.generator.retriever import load_vector_store, search_similar_sentences
from commentary_ai.generator.openai_captioner import generate_caption
def detect_events(
    assigned_player,
    last_player_with_ball,
    current_team_with_ball,
    previous_team_with_ball,
    player_speed,
    ball_speed,
    ball_bbox,
    *,
    goal_area=((100, 50), (200, 100)),
    dribble_speed=1.5,
    shot_speed=8,
):
    """Return the commentary event strings detected for a single frame.

    Args:
        assigned_player: id of the player holding the ball this frame, or -1.
        last_player_with_ball: id of the most recent player who held the ball
            in an earlier frame, or -1 if nobody has held it yet.
        current_team_with_ball: team in possession this frame (may be None).
        previous_team_with_ball: team in possession on the previous frame
            (may be None).
        player_speed: speed of ``assigned_player``; only consulted when a
            player holds the ball. Units are whatever the speed estimator
            writes into the tracks (unconfirmed here).
        ball_speed: speed of the ball this frame (same unit caveat).
        ball_bbox: ball bounding box; only indices 0 and 1 are read.
        goal_area: ``((x_min, y_min), (x_max, y_max))`` rectangle for the goal
            check. Placeholder coordinates carried over from the original
            heuristic; calibrate per video.
        dribble_speed: player speed above which keeping the ball is reported
            as a dribble.
        shot_speed: ball speed above which a shot is reported.

    Returns:
        A (possibly empty) list of event strings, ordered pass/dribble,
        tackle, shot, goal.
    """
    events = []

    possession_changed = (
        last_player_with_ball != -1
        and assigned_player != -1
        and assigned_player != last_player_with_ball
    )
    if possession_changed:
        # A change of possessor within the same team is a pass. A change to the
        # other team is reported only by the tackle check below, so a turnover
        # is not also announced as a successful pass.
        if current_team_with_ball == previous_team_with_ball:
            events.append(
                f"패스 성공! 플레이어 {last_player_with_ball} ➡ 플레이어 {assigned_player}"
            )
    elif assigned_player != -1 and player_speed > dribble_speed:
        events.append(f"플레이어 {assigned_player}이 드리블 중입니다.")

    if previous_team_with_ball is not None and current_team_with_ball != previous_team_with_ball:
        events.append("태클 성공! 상대 팀이 볼을 차단했습니다.")

    if ball_speed > shot_speed:
        events.append("슛! 볼이 빠른 속도로 움직입니다.")

    # NOTE(review): ball_bbox[0] / ball_bbox[1] are presumably the bbox's
    # top-left corner (x1, y1); the bbox centre may be a better ball position.
    # Confirm the bbox format before changing this.
    (goal_x_min, goal_y_min), (goal_x_max, goal_y_max) = goal_area
    if goal_x_min < ball_bbox[0] < goal_x_max and goal_y_min < ball_bbox[1] < goal_y_max:
        events.append("골! 볼이 골대에 들어갔습니다!")

    return events


def main():
    """Run the analysis pipeline on the sample video and save the annotated output.

    Pipeline:
    1. Load the commentary vector store.
    2. Track objects, using stub caches.
    3. Apply camera-movement adjustment.
    4. Apply the view transform.
    5. Interpolate ball positions.
    6. Compute speed and distance.
    7. Assign teams.
    8. Walk every frame to assign ball possession and detect simple events.
       Generate an AI subtitle every ``SUBTITLE_INTERVAL_SEC`` seconds; frames
       in between repeat the last subtitle.
    9. Draw everything and write ``output_videos/output_video.avi``.
    """
    FPS = 30  # Frame rate of the input video.
    # A fresh AI subtitle is generated on every SUBTITLE_INTERVAL_SEC-th second.
    # (An older comment claimed 3 s, but the code has always used 2 s; 2 s is kept.)
    SUBTITLE_INTERVAL_SEC = 2
    subtitle_interval_frames = FPS * SUBTITLE_INTERVAL_SEC

    # Relative paths in this function are resolved against the current working directory.
    vector_store_path = os.path.abspath(
        os.path.join("commentary_ai", "generator", "vector_store.pkl")
    )
    print("벡터 저장소 절대 경로:", vector_store_path)
    # Load the vector store; presumably consulted later by search_similar_sentences().
    load_vector_store(vector_store_path)

    video_frames = read_video('input_videos/sample_1.mp4')

    tracker = Tracker('models/best.pt')
    tracks = tracker.get_object_tracks(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/track_stubs.pkl',
    )
    tracker.add_positions_to_tracks(tracks)

    camera_movement_estimator = CameraMovementEstimator(video_frames[0])
    camera_movement_per_frame = camera_movement_estimator.get_camera_movement(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/camera_movement_stub.pkl',
    )
    camera_movement_estimator.add_adjust_positions_to_tracks(tracks, camera_movement_per_frame)

    view_transformer = ViewTransformer()
    view_transformer.add_transformed_position_to_tracks(tracks)

    tracks["ball"] = tracker.interpolate_ball_positions(tracks["ball"])

    speed_and_distance_estimator = SpeedAndDistance_Estimator()
    speed_and_distance_estimator.add_speed_and_distance_to_tracks(tracks)

    # Team colours are fitted on the first frame, then every player in every frame is labelled.
    team_assigner = TeamAssigner()
    team_assigner.assign_team_color(video_frames[0], tracks['players'][0])
    for frame_num, player_track in enumerate(tracks['players']):
        for player_id, track in player_track.items():
            team = team_assigner.get_player_team(video_frames[frame_num], track['bbox'], player_id)
            # `track` is the same dict object as tracks['players'][frame_num][player_id].
            track['team'] = team
            track['team_color'] = team_assigner.team_colors[team]

    player_assigner = PlayerBallAssigner()
    team_ball_control = []
    last_player_with_ball = -1  # Most recent player seen with the ball; -1 = nobody yet.
    previous_team_with_ball = None
    subtitle_data = []
    event_data = []

    for frame_num, player_track in enumerate(tracks['players']):
        ball_track = tracks['ball'][frame_num][1]
        ball_bbox = ball_track['bbox']
        # Read speed from the ball's own track entry (the one the bbox comes from).
        # Stays 0 if the speed estimator does not annotate the ball.
        ball_speed = ball_track.get('speed', 0)

        assigned_player = player_assigner.assign_ball_to_player(player_track, ball_bbox)
        if assigned_player != -1:
            holder = player_track[assigned_player]
            holder['has_ball'] = True
            current_team_with_ball = holder['team']
            player_speed = holder.get('speed', 0)
            tracker.update_ball_owner(assigned_player, current_team_with_ball)
        else:
            # Nobody is assigned the ball: the last team in possession keeps it.
            current_team_with_ball = previous_team_with_ball
            player_speed = 0
        team_ball_control.append(current_team_with_ball)

        event_texts = detect_events(
            assigned_player,
            last_player_with_ball,
            current_team_with_ball,
            previous_team_with_ball,
            player_speed,
            ball_speed,
            ball_bbox,
        )
        event_data.append("\n".join(event_texts))

        # Possession state must advance on EVERY frame, before the subtitle gate
        # below skips the rest of the loop. Only real possessors are remembered,
        # so a pass is still detected across frames where the ball is in flight.
        if assigned_player != -1:
            last_player_with_ball = assigned_player
        previous_team_with_ball = current_team_with_ball

        if frame_num % subtitle_interval_frames != 0:
            # Between subtitle frames, keep showing the previous subtitle.
            subtitle_data.append(subtitle_data[-1] if subtitle_data else "")
            continue

        frame_info = {
            "frame_num": frame_num,
            "assigned_player": assigned_player,
            "player_speed": player_speed,
            "ball_speed": ball_speed,
            "team_with_ball": current_team_with_ball,
            "ball_position": ball_bbox,
            "events": event_texts,
        }
        frame_description = analyze_frame(frame_info)
        retrieved_examples = search_similar_sentences(frame_description, top_k=5)
        subtitle_data.append(generate_caption(frame_description, retrieved_examples))

    output_video_frames = tracker.draw_annotations(
        video_frames, tracks, team_ball_control, subtitle_data, event_data
    )
    save_video(output_video_frames, 'output_videos/output_video.avi')
if __name__ == '__main__':
    # Script entry point: run the pipeline only when executed directly, never on import.
    main()