maxhuber commited on
Commit
9ad2f73
·
1 Parent(s): d893f57

Added error handling for video loading, added output flagging

Browse files
Files changed (2) hide show
  1. app.py +12 -1
  2. helpers.py +39 -29
app.py CHANGED
@@ -10,6 +10,8 @@ theme = gr.themes.Default(
10
  font=[gr.themes.GoogleFont("IBM Plex Mono"), "system-ui"]
11
  )
12
 
 
 
13
  with gr.Blocks(theme=theme) as demo:
14
  # DEFINE COMPONENTS
15
 
@@ -49,11 +51,20 @@ with gr.Blocks(theme=theme) as demo:
49
  visible=False
50
  )
51
 
 
 
 
52
  # DEFINE FUNCTIONS
53
  # Load video from URL, display sample frames, and enable prediction button
54
  loadVideoBtn.click(fn=load_video_from_url, inputs=[urlInput], outputs=[videoTitle, sampleFrames, predVideoBtn, predOutput])
55
 
56
  # Generate video prediction
57
- predVideoBtn.click(fn=detect_deepfake, outputs=[predOutput])
 
 
 
 
 
 
58
 
59
  demo.launch()
 
10
  font=[gr.themes.GoogleFont("IBM Plex Mono"), "system-ui"]
11
  )
12
 
13
+ callback = gr.CSVLogger()
14
+
15
  with gr.Blocks(theme=theme) as demo:
16
  # DEFINE COMPONENTS
17
 
 
51
  visible=False
52
  )
53
 
54
+ # Button for flagging the output
55
+ flagBtn = gr.Button(value="Flag Output", visible=False)
56
+
57
  # DEFINE FUNCTIONS
58
  # Load video from URL, display sample frames, and enable prediction button
59
  loadVideoBtn.click(fn=load_video_from_url, inputs=[urlInput], outputs=[videoTitle, sampleFrames, predVideoBtn, predOutput])
60
 
61
  # Generate video prediction
62
+ predVideoBtn.click(fn=detect_deepfake, outputs=[predOutput, flagBtn])
63
+
64
+ # Define flag callback
65
+ callback.setup([urlInput], "flagged_data_points")
66
+
67
+ # Flag output
68
+ flagBtn.click(fn=lambda *args: callback.flag(args), inputs=[urlInput], outputs=None)
69
 
70
  demo.launch()
helpers.py CHANGED
@@ -9,40 +9,47 @@ import pickle
9
 
10
 
11
  def load_video_from_url(youtube_url):
12
- # DOWNLOAD THE VIDEO USING THE GIVEN URL
13
- yt = YouTube(youtube_url)
14
- yt_stream = yt.streams.filter(file_extension='mp4').first()
15
- title = yt_stream.title
16
- src = yt_stream.download()
17
- capture = cv2.VideoCapture(src)
18
-
19
- # SAMPLE FRAMES FROM VIDEO FILE
20
- sampled_frames = sample_frames_from_video_file(capture)
21
-
22
- # PICK EXAMPLE FRAME FROM THE MIDDLE OF THE SAMPLED FRAMES
23
- example_frames = [
24
- sampled_frames[len(sampled_frames) // 4],
25
- sampled_frames[len(sampled_frames) // 2],
26
- sampled_frames[3 * len(sampled_frames) // 4],
27
- ]
28
-
29
- # DELETE VIDEO FILE
30
- if os.path.exists(src):
31
- os.remove(src)
32
-
33
- # CONVERT SAMPLED FRAMES TO TENSOR
34
- frames_tensor = tf.expand_dims(tf.convert_to_tensor(sampled_frames, dtype=tf.float32), axis=0)
35
-
36
- # SAVE TENSOR TO FILE
37
- pickle.dump(frames_tensor, open("frames_tf.pkl", "wb"))
 
 
 
 
 
 
 
38
 
39
  # Define visible prediction components to show upon video loaded
40
- predVideoBtn = gr.Button(value="Classify Video", visible=True)
41
 
42
  predOutput = gr.Label(
43
  label="DETECTED LABEL (AND CONFIDENCE LEVEL)",
44
  num_top_classes=2,
45
- visible=True
46
  )
47
 
48
  return title, example_frames, predVideoBtn, predOutput
@@ -66,8 +73,11 @@ def detect_deepfake():
66
  fake_confidence = 1 - real_confidence
67
  confidence_dict = {"FAKE": fake_confidence, "REAL": real_confidence}
68
 
 
 
 
69
  # RETURN THE OUTPUT LABEL AND EXAMPLE FRAMES
70
- return confidence_dict
71
 
72
 
73
  def sample_frames_from_video_file(capture, sample_count=10, frames_per_sample=10, frame_step=10,
 
9
 
10
 
11
  def load_video_from_url(youtube_url):
12
+ visible = True
13
+ try:
14
+ # DOWNLOAD THE VIDEO USING THE GIVEN URL
15
+ yt = YouTube(youtube_url)
16
+ yt_stream = yt.streams.filter(file_extension='mp4').first()
17
+ title = yt_stream.title
18
+ src = yt_stream.download()
19
+ capture = cv2.VideoCapture(src)
20
+
21
+ # SAMPLE FRAMES FROM VIDEO FILE
22
+ sampled_frames = sample_frames_from_video_file(capture)
23
+
24
+ # PICK EXAMPLE FRAME FROM THE MIDDLE OF THE SAMPLED FRAMES
25
+ example_frames = [
26
+ sampled_frames[len(sampled_frames) // 4],
27
+ sampled_frames[len(sampled_frames) // 2],
28
+ sampled_frames[3 * len(sampled_frames) // 4],
29
+ ]
30
+
31
+ # DELETE VIDEO FILE
32
+ if os.path.exists(src):
33
+ os.remove(src)
34
+
35
+ # CONVERT SAMPLED FRAMES TO TENSOR
36
+ frames_tensor = tf.expand_dims(tf.convert_to_tensor(sampled_frames, dtype=tf.float32), axis=0)
37
+
38
+ # SAVE TENSOR TO FILE
39
+ pickle.dump(frames_tensor, open("frames_tf.pkl", "wb"))
40
+
41
+ except Exception as e:
42
+ title = "Error while loading video: " + str(e)
43
+ visible = False
44
+ example_frames = [np.zeros((256, 256, 3)) for _ in range(3)]
45
 
46
  # Define visible prediction components to show upon video loaded
47
+ predVideoBtn = gr.Button(value="Classify Video", visible=visible)
48
 
49
  predOutput = gr.Label(
50
  label="DETECTED LABEL (AND CONFIDENCE LEVEL)",
51
  num_top_classes=2,
52
+ visible=visible
53
  )
54
 
55
  return title, example_frames, predVideoBtn, predOutput
 
73
  fake_confidence = 1 - real_confidence
74
  confidence_dict = {"FAKE": fake_confidence, "REAL": real_confidence}
75
 
76
+ # MAKE FLAG BUTTON VISIBLE
77
+ flagBtn = gr.Button(value="Flag Output", visible=True)
78
+
79
  # RETURN THE OUTPUT LABEL AND EXAMPLE FRAMES
80
+ return confidence_dict, flagBtn
81
 
82
 
83
  def sample_frames_from_video_file(capture, sample_count=10, frames_per_sample=10, frame_step=10,