SuriRaja committed
Commit 760aaa1 · verified · 1 Parent(s): 45209ed

Update app.py

Files changed (1):
  1. app.py +21 -48

app.py CHANGED
@@ -2,21 +2,17 @@ import torch
 import gradio as gr
 import numpy as np
 import cv2
-from transformers import (
-    VideoMAEImageProcessor,
-    VideoMAEForPreTraining,
-    VideoMAEForVideoClassification,
-)
+from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEForVideoClassification
 
-# Initialize model and processor for pretraining (reconstruction) and classification
-model_name_pretrain = "MCG-NJU/videomae-base"
-model_name_classify = "MCG-NTU/videomae-base"
+# Use the publicly available MCG-NJU model for both pretraining and classification
+model_name = "MCG-NJU/videomae-base"
 
-processor = VideoMAEImageProcessor.from_pretrained(model_name_pretrain)
-model_pretrain = VideoMAEForPreTraining.from_pretrained(model_name_pretrain)
-model_classify = VideoMAEForVideoClassification.from_pretrained(model_name_classify)
+# Load processor and models
+processor = VideoMAEImageProcessor.from_pretrained(model_name)
+model_pretrain = VideoMAEForPreTraining.from_pretrained(model_name)
+model_classify = VideoMAEForVideoClassification.from_pretrained(model_name)
 
-# Some example labels for NTU dataset (replace with full list as needed)
+# Example labels for classification (replace with full NTU action list if needed)
 labels = [
     "drink water", "eat meal/snack", "brush teeth", "clapping", "writing",
     "reading", "wear jacket", "take off jacket", "put on a shoe", "take off a shoe"
@@ -37,20 +33,20 @@ def preprocess_video(video_path):
     return frames[:16]
 
 def predict_video(video):
+    if video is None:
+        return "", {}
     frames = preprocess_video(video.name)
     pixel_values = processor(frames, return_tensors="pt").pixel_values
 
-    # For pretraining: random mask
+    # Masked positions for pretraining
     num_patches_per_frame = (model_pretrain.config.image_size // model_pretrain.config.patch_size) ** 2
     seq_length = (16 // model_pretrain.config.tubelet_size) * num_patches_per_frame
     bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
 
     with torch.no_grad():
         outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
-
     loss = outputs.loss.item()
 
-    # For classification: get logits and predict top 3 classes
     with torch.no_grad():
         outputs_class = model_classify(pixel_values)
         logits = outputs_class.logits
@@ -58,64 +54,41 @@
     top5_prob, top5_catid = torch.topk(probs, 3)
     top_actions = {labels[catid]: float(prob) for prob, catid in zip(top5_prob, top5_catid)}
 
-    return {
-        "Reconstruction Loss": f"{loss:.4f}",
-        "Top 3 Action Predictions": top_actions
-    }
+    return f"Reconstruction Loss: {loss:.4f}", top_actions
 
 def preprocess_image(image):
-    # Convert PIL image to numpy RGB array and resize
-    image = np.array(image.convert("RGB").resize((224,224)))
-    # Add batch and channel dimension
+    # Resize and convert to RGB numpy array
+    image = np.array(image.convert("RGB").resize((224, 224)))
     pixel_values = processor(image, return_tensors="pt").pixel_values
     return pixel_values
 
 def predict_image(image):
+    if image is None:
+        return "No image provided."
     pixel_values = preprocess_image(image)
-
-    # For pretraining (masked autoencoding), mask patches randomly
     num_patches = (model_pretrain.config.image_size // model_pretrain.config.patch_size) ** 2
     bool_masked_pos = torch.randint(0, 2, (1, num_patches)).bool()
 
     with torch.no_grad():
         outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
     loss = outputs.loss.item()
-
     return f"Image Reconstruction Loss: {loss:.4f}"
 
-# Gradio interface with Tabs for Image and Video
 with gr.Blocks() as demo:
     gr.Markdown("# VideoMAE Demo: Image and Video Input")
 
     with gr.Tab("Video Input"):
         video_input = gr.Video(label="Upload Video (short clip)")
-        video_output_loss = gr.Textbox(label="Reconstruction Loss")
-        video_output_preds = gr.Label(num_top_classes=3, label="Top 3 Action Predictions")
+        video_loss = gr.Textbox(label="Reconstruction Loss")
+        video_preds = gr.Label(num_top_classes=3, label="Top 3 Action Predictions")
         video_btn = gr.Button("Predict Video")
-
-        def video_predict_fn(video):
-            if video is None:
-                return "", {}
-            results = predict_video(video)
-            return results["Reconstruction Loss"], results["Top 3 Action Predictions"]
-
-        video_btn.click(
-            fn=video_predict_fn,
-            inputs=video_input,
-            outputs=[video_output_loss, video_output_preds],
-        )
+        video_btn.click(predict_video, inputs=video_input, outputs=[video_loss, video_preds])
 
     with gr.Tab("Image Input"):
         image_input = gr.Image(label="Upload Image")
-        image_output = gr.Textbox(label="Reconstruction Loss")
+        image_loss = gr.Textbox(label="Reconstruction Loss")
         image_btn = gr.Button("Predict Image")
-
-        image_btn.click(
-            fn=predict_image,
-            inputs=image_input,
-            outputs=image_output,
-        )
+        image_btn.click(predict_image, inputs=image_input, outputs=image_loss)
 
 if __name__ == "__main__":
     demo.launch()
-
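A few notes on the committed code. The diff shows only the last line of preprocess_video (return frames[:16]), so the frame-extraction logic itself is not visible here. For orientation, a minimal sketch of what such a helper commonly looks like with cv2 — the sampling and padding choices below are assumptions, not the committed implementation:

```python
import cv2

def preprocess_video(video_path):
    # Hypothetical helper: the committed app.py defines its own version;
    # only its final line (`return frames[:16]`) appears in this diff.
    cap = cv2.VideoCapture(video_path)
    frames = []
    while len(frames) < 16:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR
        frames.append(cv2.resize(frame, (224, 224)))
    cap.release()
    # Pad short clips by repeating the last frame
    while frames and len(frames) < 16:
        frames.append(frames[-1])
    return frames[:16]
```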
 
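The masking arithmetic in predict_video is easier to follow with the actual config values of MCG-NJU/videomae-base (image_size 224, patch_size 16, tubelet_size 2, 16 input frames):

```python
# Worked example: shape of bool_masked_pos for videomae-base.
num_patches_per_frame = (224 // 16) ** 2  # 14 x 14 = 196 patches per frame
seq_length = (16 // 2) * 196              # 8 tubelets of 2 frames each -> 1568 positions
# torch.randint(0, 2, (1, 1568)).bool() masks roughly half of the tubelet
# patches at random; VideoMAEForPreTraining computes its reconstruction
# loss only over the masked positions. (The VideoMAE paper itself uses a
# much higher masking ratio, around 90%.)
```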
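One caveat on the classification path: MCG-NJU/videomae-base is a pretraining-only checkpoint, so VideoMAEForVideoClassification.from_pretrained attaches a randomly initialized classifier head, and the "Top 3 Action Predictions" are arbitrary until that head is fine-tuned; with the default two-label head, torch.topk(probs, 3) may even fail because there are fewer than three classes. A minimal sketch of two ways to make the head consistent with the demo's label list (the Kinetics-400 checkpoint named below is a published fine-tune of the same backbone):

```python
from transformers import VideoMAEForVideoClassification

labels = [
    "drink water", "eat meal/snack", "brush teeth", "clapping", "writing",
    "reading", "wear jacket", "take off jacket", "put on a shoe", "take off a shoe"
]

# Option 1: size the (still untrained) head to the demo's ten labels,
# so indexing and topk are at least shape-consistent.
model_classify = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-base",
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={label: i for i, label in enumerate(labels)},
)

# Option 2: load a checkpoint whose head is already trained and take the
# label names from its config:
# model_classify = VideoMAEForVideoClassification.from_pretrained(
#     "MCG-NJU/videomae-base-finetuned-kinetics"
# )
# labels = [model_classify.config.id2label[i]
#           for i in range(model_classify.config.num_labels)]
```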