Abhishek Gola committed on
Commit
89138dc
·
1 Parent(s): 71905ee

Added vit tracker to opencv spaces

Browse files
Files changed (4) hide show
  1. README.md +6 -0
  2. app.py +185 -0
  3. requirements.txt +4 -0
  4. vittrack.py +39 -0
README.md CHANGED
@@ -7,6 +7,12 @@ sdk: gradio
7
  sdk_version: 5.34.2
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  sdk_version: 5.34.2
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Object tracking with ViTtracker using OpenCV
11
+ tags:
12
+ - opencv
13
+ - object-tracking
14
+ - vit
15
+ - vittracker
16
  ---
17
 
18
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2 as cv
2
+ import numpy as np
3
+ import gradio as gr
4
+ from vittrack import VitTrack
5
+ from huggingface_hub import hf_hub_download
6
+ import os
7
+ import tempfile
8
+
9
+ # Download ONNX model at startup
10
+ MODEL_PATH = hf_hub_download(
11
+ repo_id="opencv/object_tracking_vittrack",
12
+ filename="object_tracking_vittrack_2023sep.onnx"
13
+ )
14
+
15
+ backend_id = cv.dnn.DNN_BACKEND_OPENCV
16
+ target_id = cv.dnn.DNN_TARGET_CPU
17
+
18
+ # Global state
19
+ state = {
20
+ "points": [],
21
+ "bbox": None,
22
+ "video_path": None,
23
+ "first_frame": None
24
+ }
25
+
26
def load_first_frame(video_path):
    """Load a video, cache its first frame, and reset the selection state.

    Parameters
    ----------
    video_path : str or None
        Path to the uploaded video file. ``None`` when the file input is
        cleared (e.g. via the "Clear All" button, which fires ``.change``).

    Returns
    -------
    numpy.ndarray or None
        The first frame converted BGR->RGB for display in the Gradio image
        component, or ``None`` when no video is available or the first
        frame cannot be read.
    """
    # Bug fix: clearing the gr.File input triggers .change with None;
    # cv.VideoCapture(None) would raise. Reset state and bail out instead.
    if not video_path:
        state["video_path"] = None
        state["first_frame"] = None
        state["points"].clear()
        state["bbox"] = None
        return None

    state["video_path"] = video_path
    cap = cv.VideoCapture(video_path)
    has_frame, frame = cap.read()
    cap.release()
    if not has_frame:
        # Unreadable/empty video: don't leave a stale frame from a
        # previous upload behind.
        state["first_frame"] = None
        state["points"].clear()
        state["bbox"] = None
        return None
    state["first_frame"] = frame.copy()
    state["points"].clear()
    state["bbox"] = None
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
38
+
39
+
40
def select_point(img, evt: gr.SelectData):
    """Record a click on the first frame (up to four) and redraw the overlay.

    Each click is drawn as a dot; two or more clicks are joined by a
    polyline, and the fourth click locks in an axis-aligned bounding box
    (stored in ``state["bbox"]``) drawn in red.
    """
    frame = state["first_frame"]
    if frame is None:
        return None

    click = (int(evt.index[0]), int(evt.index[1]))
    clicks = state["points"]
    # Extra clicks beyond the fourth are ignored.
    if len(clicks) < 4:
        clicks.append(click)

    canvas = frame.copy()

    # Dots for every recorded click.
    for pt in clicks:
        cv.circle(canvas, pt, 5, (0, 255, 0), -1)

    # Open polyline joining the clicks in order.
    if len(clicks) >= 2:
        contour = np.array(clicks, dtype=np.int32)
        cv.polylines(canvas, [contour], isClosed=False,
                     color=(255, 255, 0), thickness=2)

    # Fourth click: derive and draw the tracking bounding box.
    if len(clicks) == 4:
        contour = np.array(clicks, dtype=np.int32)
        bx, by, bw, bh = cv.boundingRect(contour)
        state["bbox"] = (bx, by, bw, bh)
        cv.rectangle(canvas, (bx, by), (bx + bw, by + bh), (0, 0, 255), 2)

    return cv.cvtColor(canvas, cv.COLOR_BGR2RGB)
66
+
67
+
68
def clear_points():
    """Discard the clicked points and derived bbox, keeping the frame.

    Returns the untouched first frame (RGB) for redisplay, or ``None``
    when no video has been loaded yet.
    """
    state["points"].clear()
    state["bbox"] = None
    frame = state["first_frame"]
    return None if frame is None else cv.cvtColor(frame, cv.COLOR_BGR2RGB)
75
+
76
+
77
def clear_all():
    """Fully reset the app: selection, cached frame, and video path.

    Returns one ``None`` per wired output component
    (video input, frame preview, result video) to clear them all.
    """
    state["points"].clear()
    for key in ("bbox", "video_path", "first_frame"):
        state[key] = None
    return (None,) * 3
84
+
85
+
86
def track_video():
    """Run VitTrack over the whole uploaded video and write an annotated copy.

    Uses the bounding box previously selected on the first frame. Each
    output frame carries an FPS overlay plus either the tracked box with
    its confidence score, or a "Target lost!" banner when the score drops
    below 0.3 or the tracker reports failure.

    Returns
    -------
    str or None
        Path to the annotated MP4 in the temp directory, or ``None`` when
        no video/bbox is available or the video cannot be read.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id,
    )

    cap = cv.VideoCapture(state["video_path"])
    fps = cap.get(cv.CAP_PROP_FPS)
    # Bug fix: some containers report 0 fps; a VideoWriter opened with
    # fps=0 produces a broken file. Fall back to a sane default.
    if not fps or fps <= 0:
        fps = 30.0
    w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # Output goes to a fixed name in the temp directory.
    out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
    writer = cv.VideoWriter(
        out_path,
        cv.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h),
    )

    try:
        # Bug fix: the original discarded the success flag of the first
        # read; an unreadable video would crash model.init with frame=None.
        has_frame, first_frame_img = cap.read()
        if not has_frame:
            return None
        model.init(first_frame_img, state["bbox"])

        tm = cv.TickMeter()
        while True:
            has_frame, frame = cap.read()
            if not has_frame:
                break
            tm.start()
            is_located, bbox, score = model.infer(frame)
            tm.stop()

            vis = frame.copy()
            # Per-frame FPS overlay (TickMeter is reset every iteration).
            cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            # Draw the tracked box, or a loss banner below 0.3 confidence.
            if is_located and score >= 0.3:
                x, y, bw, bh = bbox
                cv.rectangle(vis, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
                cv.putText(vis, f"{score:.2f}", (x, y - 10),
                           cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                cv.putText(vis, "Target lost!",
                           (w // 2, h // 4),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            writer.write(vis)
            tm.reset()
    finally:
        # Bug fix: release capture/writer even if inference raises,
        # so the partial output file is finalized and handles are freed.
        cap.release()
        writer.release()
    return out_path
147
+
148
+
149
# Build the Gradio UI. The `demo` name is kept — Hugging Face Spaces
# looks for it when auto-launching the app.
with gr.Blocks() as demo:
    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    gr.Markdown(
        """
    **How to use this tool:**

    1. **Upload a video** file (e.g., `.mp4` or `.avi`).
    2. The **first frame** of the video will appear.
    3. **Click exactly 4 points** on the object you want to track. These points should outline the object as closely as possible.
    4. A **bounding box** will be drawn around the selected region automatically.
    5. Click the **Track** button to start object tracking across the entire video.
    6. The output video with tracking overlay will appear below.

    You can also use:
    - 🧹 **Clear Points** to reset the 4-point selection on the first frame.
    - 🔄 **Clear All** to reset the uploaded video, frame, and selections.
    """
    )

    # Top row: upload, clickable first-frame preview, tracking result.
    with gr.Row():
        video_in = gr.File(label="Upload Video", file_types=[".mp4", ".avi"])
        first_frame = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    # Bottom row: action buttons.
    with gr.Row():
        track_btn = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    # Event wiring: upload -> preview, clicks -> point selection,
    # buttons -> reset helpers / tracking run.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ opencv-python
2
+ gradio
3
+ numpy
4
+ huggingface_hub
vittrack.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is part of OpenCV Zoo project.
2
+ # It is subject to the license terms in the LICENSE file found in the same directory.
3
+
4
+ import numpy as np
5
+ import cv2 as cv
6
+
7
class VitTrack:
    """Thin wrapper around OpenCV's ViT-based tracker (``cv.TrackerVit``)."""

    def __init__(self, model_path, backend_id=0, target_id=0):
        # Path to the ONNX weights plus the DNN backend/target to run on.
        self.model_path = model_path
        self.backend_id = backend_id
        self.target_id = target_id

        self.params = self._build_params()
        self.model = cv.TrackerVit_create(self.params)

    def _build_params(self):
        # Assemble a TrackerVit parameter struct from the current settings.
        params = cv.TrackerVit_Params()
        params.net = self.model_path
        params.backend = self.backend_id
        params.target = self.target_id
        return params

    @property
    def name(self):
        """Human-readable identifier for this tracker wrapper."""
        return type(self).__name__

    def setBackendAndTarget(self, backend_id, target_id):
        """Switch the DNN backend/target and rebuild the tracker."""
        self.backend_id = backend_id
        self.target_id = target_id
        self.params = self._build_params()
        self.model = cv.TrackerVit_create(self.params)

    def init(self, image, roi):
        """Start tracking the region *roi* (x, y, w, h) in *image*."""
        self.model.init(image, roi)

    def infer(self, image):
        """Track on the next frame.

        Returns a tuple ``(is_located, bbox, score)`` where *bbox* is the
        updated (x, y, w, h) box and *score* the tracker's confidence.
        """
        is_located, bbox = self.model.update(image)
        return is_located, bbox, self.model.getTrackingScore()
+ return is_located, bbox, score