Myogyi committed
Commit 135c1ed · 1 Parent(s): 47a7997

Update interfacetest2.py

Files changed (1)
  1. interfacetest2.py +116 -47
interfacetest2.py CHANGED
@@ -7,27 +7,19 @@ import torch
 import torch.backends.cudnn as cudnn
 from numpy import random
 import numpy as np
-from models.experimental import attempt_load
-from utils.datasets import LoadImages
-from utils.general import check_img_size, non_max_suppression, scale_coords, set_logging, increment_path
-from utils.plots import plot_one_box
-from utils.torch_utils import select_device, time_synchronized
-import gradio as gr
 import ffmpeg
+import gradio as gr
 from fastapi import FastAPI
 import uvicorn
+import shutil
+from models.experimental import attempt_load
+from utils.datasets import LoadStreams, LoadImages
+from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, \
+    scale_coords, strip_optimizer, set_logging, increment_path
+from utils.plots import plot_one_box
+from utils.torch_utils import select_device, time_synchronized, TracedModel
 
-def convert_to_h264(input_path):
-    output_path = str(Path(input_path).with_suffix('')) + "_h264.mp4"
-    try:
-        stream = ffmpeg.input(input_path)
-        stream = ffmpeg.output(stream, output_path, vcodec='libx264', acodec='aac', format='mp4', pix_fmt='yuv420p')
-        ffmpeg.run(stream, overwrite_output=True)
-        return output_path
-    except ffmpeg.Error as e:
-        print(f"FFmpeg conversion error: {e.stderr.decode()}")
-        return input_path
-
+# Function to compute IoU between two boxes
 def compute_iou(box1, box2):
     x1, y1, x2, y2 = box1
     x1_, y1_, x2_, y2_ = box2
@@ -43,35 +35,63 @@ def compute_iou(box1, box2):
     union_area = box1_area + box2_area - inter_area
     return inter_area / union_area if union_area != 0 else 0.0
 
-def is_scanner_moving(prev_centroids, curr_box, scanner_id, threshold=5.0):
+# Function to check if a scanner is moving based on centroid displacement
+def is_scanner_moving(prev_centroids, curr_box, scanner_id, threshold=2.0):
     x1, y1, x2, y2 = curr_box
     curr_centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
     if scanner_id in prev_centroids:
         prev_x, prev_y = prev_centroids[scanner_id]
         distance = np.sqrt((curr_centroid[0] - prev_x)**2 + (curr_centroid[1] - prev_y)**2)
         return distance > threshold
-    return False
+    return False  # Default to "not moving" if no previous centroid exists
 
-def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=640, device='', save_dir='runs/detect/exp'):
+# Function to convert video to H.264 format
+def convert_to_h264(input_path):
+    output_path = str(Path(input_path).with_suffix('')) + "_h264.mp4"
+    try:
+        stream = ffmpeg.input(input_path)
+        stream = ffmpeg.output(stream, output_path, vcodec='libx264', acodec='aac', format='mp4', pix_fmt='yuv420p')
+        ffmpeg.run(stream, cmd='/usr/bin/ffmpeg', overwrite_output=True)
+        return output_path
+    except ffmpeg.Error as e:
+        stderr = e.stderr.decode('utf-8') if e.stderr else "Unknown FFmpeg error"
+        print(f"FFmpeg error: {stderr}")
+        return input_path
+
+# Detection function adapted from the second script
+def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=640, device='', save_dir='runs/detect/exp', trace=False):
     save_dir = Path(increment_path(Path(save_dir), exist_ok=True))
     save_dir.mkdir(parents=True, exist_ok=True)
-
+
+    # Initialize
     set_logging()
     device = select_device(device)
     half = device.type != 'cpu'
+
+    # Load model
     model = attempt_load(weights, map_location=device)
     stride = int(model.stride.max())
     imgsz = check_img_size(img_size, s=stride)
+
+    if trace:
+        model = TracedModel(model, device, img_size)
     if half:
         model.half()
 
+    # Set Dataloader
     dataset = LoadImages(video_path, img_size=imgsz, stride=stride)
+
+    # Get names and colors
     names = model.module.names if hasattr(model, 'module') else model.names
     colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
 
+    # Initialize variables
    vid_path, vid_writer = None, None
    prev_centroids = {}
    scanner_id_counter = 0
+    product_scanning_status_global = ""
+    payment_scanning_status_global = ""
+    old_img_b, old_img_h, old_img_w = 0, 0, 0
 
     for path, img, im0s, vid_cap in dataset:
         img = torch.from_numpy(img).to(device)
@@ -80,10 +100,22 @@ def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=
         if img.ndimension() == 3:
             img = img.unsqueeze(0)
 
+        # Warmup
+        if device.type != 'cpu' and (old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]):
+            old_img_b = img.shape[0]
+            old_img_h = img.shape[2]
+            old_img_w = img.shape[3]
+            for _ in range(3):
+                model(img)[0]
+
+        # Inference
         with torch.no_grad():
-            pred = model(img)[0]
+            pred = model(img, augment=False)[0]
+
+        # Apply NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres)
 
+        # Process detections
         for i, det in enumerate(pred):
             p = Path(path)
             save_path = str(save_dir / p.name.replace('.mp4', '_output.mp4'))
@@ -94,6 +126,7 @@ def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=
             item_boxes, scanner_data, phone_boxes = [], [], []
             curr_scanner_boxes = []
 
+            # Process each detection
             for *xyxy, conf, cls in det:
                 x1, y1, x2, y2 = map(int, xyxy)
                 class_name = names[int(cls)]
@@ -106,21 +139,27 @@ def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=
                 curr_scanner_boxes.append([x1, y1, x2, y2])
                 plot_one_box(xyxy, im0, label=class_name, color=color, line_thickness=2)
 
+            # Match scanner boxes with previous frames
             new_prev_centroids = {}
             if prev_centroids and curr_scanner_boxes:
                 for curr_box in curr_scanner_boxes:
                     curr_centroid = ((curr_box[0] + curr_box[2]) / 2, (curr_box[1] + curr_box[3]) / 2)
-                    best_match_id = min(prev_centroids.keys(),
-                                        key=lambda k: np.sqrt((curr_centroid[0] - prev_centroids[k][0])**2 +
-                                                              (curr_centroid[1] - prev_centroids[k][1])**2),
-                                        default=None)
-                    if best_match_id is not None and np.sqrt((curr_centroid[0] - prev_centroids[best_match_id][0])**2 +
-                                                             (curr_centroid[1] - prev_centroids[best_match_id][1])**2) < 50:
-                        scanner_id = best_match_id
+                    best_match_id = min(prev_centroids.keys(),
+                                        key=lambda k: np.sqrt((curr_centroid[0] - prev_centroids[k][0])**2 +
+                                                              (curr_centroid[1] - prev_centroids[k][1])**2),
+                                        default=None)
+                    if best_match_id is not None:
+                        distance = np.sqrt((curr_centroid[0] - prev_centroids[best_match_id][0])**2 +
+                                           (curr_centroid[1] - prev_centroids[best_match_id][1])**2)
+                        if distance < 50:
+                            scanner_id = best_match_id
+                        else:
+                            scanner_id = scanner_id_counter
+                            scanner_id_counter += 1
                     else:
                         scanner_id = scanner_id_counter
                         scanner_id_counter += 1
-                    is_moving = is_scanner_moving(prev_centroids, curr_box, scanner_id)
+                    is_moving = is_scanner_moving(prev_centroids, curr_box, scanner_id, threshold=2.0)
                     movement_status = "Scanning" if is_moving else "Idle"
                     scanner_data.append([curr_box, movement_status, scanner_id])
                     new_prev_centroids[scanner_id] = curr_centroid
@@ -135,26 +174,35 @@ def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=
 
             prev_centroids = new_prev_centroids
 
+            # Redraw scanner boxes with movement status
             for scanner_box, movement_status, scanner_id in scanner_data:
                 x1, y1, x2, y2 = scanner_box
                 label = f"scanner {movement_status} (ID: {scanner_id})"
                 plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[names.index("scanner")], line_thickness=2)
 
-            product_scanning_status = ""
-            payment_scanning_status = ""
-            for scanner_box, movement_status, _ in scanner_data:
-                for item_box in item_boxes:
-                    if movement_status == "Scanning" and compute_iou(scanner_box, item_box) > 0.1:
-                        product_scanning_status = "Product scanning is finished"
-                for phone_box in phone_boxes:
-                    if movement_status == "Scanning" and compute_iou(scanner_box, phone_box) > 0.1:
-                        payment_scanning_status = "Payment scanning is finished"
-
-            if product_scanning_status:
-                cv2.putText(im0, product_scanning_status, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, colors[names.index("scanner")], 2)
-            if payment_scanning_status:
-                cv2.putText(im0, payment_scanning_status, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.9, colors[names.index("scanner")], 2)
+                # Check for overlaps only if scanning status hasn't been set
+                if not product_scanning_status_global:
+                    for item_box in item_boxes:
+                        iou = compute_iou(scanner_box, item_box)
+                        if movement_status == "Scanning" and iou > 0.02:
+                            product_scanning_status_global = "Product scanning is finished"
+                            print(f"Product scanning finished at frame {i}")
+                if not payment_scanning_status_global:
+                    for phone_box in phone_boxes:
+                        iou = compute_iou(scanner_box, phone_box)
+                        if movement_status == "Scanning" and iou > 0.02:
+                            payment_scanning_status_global = "Payment scanning is finished"
+                            print(f"Payment scanning finished at frame {i}")
 
+            # Display persistent labels
+            if product_scanning_status_global:
+                cv2.putText(im0, product_scanning_status_global, (10, 30),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, colors[names.index("scanner")], 2)
+            if payment_scanning_status_global:
+                cv2.putText(im0, payment_scanning_status_global, (10, 60),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, colors[names.index("scanner")], 2)
+
+            # Write frame to video
             if vid_path != save_path:
                 vid_path = save_path
                 if isinstance(vid_writer, cv2.VideoWriter):
@@ -164,27 +212,47 @@ def detect_video(video_path, weights, conf_thres=0.25, iou_thres=0.45, img_size=
                 vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
             vid_writer.write(im0)
 
+    # Cleanup
     if isinstance(vid_writer, cv2.VideoWriter):
         vid_writer.release()
 
+    # Convert to H.264
     output_h264 = str(Path(save_path).with_name(f"{Path(save_path).stem}_h264.mp4"))
     try:
         stream = ffmpeg.input(save_path)
         stream = ffmpeg.output(stream, output_h264, vcodec='libx264', acodec='aac', format='mp4', pix_fmt='yuv420p')
-        ffmpeg.run(stream, overwrite_output=True)
+        ffmpeg.run(stream, cmd='/usr/bin/ffmpeg', overwrite_output=True)
         os.remove(save_path)
         return output_h264
     except ffmpeg.Error as e:
-        print(f"FFmpeg error: {e.stderr.decode()}")
+        stderr = e.stderr.decode('utf-8') if e.stderr else "Unknown FFmpeg error"
+        print(f"FFmpeg error: {stderr}")
         return save_path
 
+# Gradio interface
 def gradio_interface(video, conf_thres, iou_thres):
     weights = "/home/myominhtet/Desktop/deepsortfromscratch/yolov7/best.pt"
     img_size = 640
-    video = convert_to_h264(video)
+
+    # Create a stable directory for video files
+    stable_dir = "/home/myominhtet/Desktop/deepsortfromscratch/videos"
+    os.makedirs(stable_dir, exist_ok=True)
+
+    # Copy the uploaded video to a stable path
+    stable_path = os.path.join(stable_dir, f"input_{Path(video).name}")
+    shutil.copy(video, stable_path)
+    print(f"Copied video to: {stable_path}")
+
+    # Verify the copied file
+    print(f"Stable path exists: {os.path.exists(stable_path)}")
+    print(f"Stable path readable: {os.access(stable_path, os.R_OK)}")
+
+    video = convert_to_h264(stable_path)
     output_video = detect_video(video, weights, conf_thres, iou_thres, img_size)
+
     return output_video if output_video else "Error processing video."
 
+# Set up Gradio interface
 interface = gr.Interface(
     fn=gradio_interface,
     inputs=[
@@ -197,6 +265,7 @@ interface = gr.Interface(
     description="Upload a video to run YOLO detection with custom parameters."
 )
 
+# Set up FastAPI app
 app = FastAPI()
 app = gr.mount_gradio_app(app, interface, path="/")
 
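
The hunks above end at line 271 of the new file, before any server-startup code is visible, so how the app is actually launched is not shown in this commit. A minimal sketch of one way the mounted app could be started, assuming the file is saved as interfacetest2.py and port 8000 is free (this snippet is not part of the commit):

# Hypothetical launch block, not present in the diff shown above.
# uvicorn serves the FastAPI app with the Gradio UI mounted at "/".
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Equivalently, from a shell: uvicorn interfacetest2:app --host 0.0.0.0 --port 8000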