Update app.py
app.py CHANGED
@@ -2,21 +2,17 @@ import torch
 import gradio as gr
 import numpy as np
 import cv2
-from transformers import (
-    VideoMAEImageProcessor,
-    VideoMAEForPreTraining,
-    VideoMAEForVideoClassification,
-)
+from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEForVideoClassification
 
-#
-
-model_name_classify = "MCG-NTU/videomae-base"
+# Use the publicly available MCG-NJU model for both pretraining and classification
+model_name = "MCG-NJU/videomae-base"
 
-processor
-
-
+# Load processor and models
+processor = VideoMAEImageProcessor.from_pretrained(model_name)
+model_pretrain = VideoMAEForPreTraining.from_pretrained(model_name)
+model_classify = VideoMAEForVideoClassification.from_pretrained(model_name)
 
-#
+# Example labels for classification (replace with full NTU action list if needed)
 labels = [
     "drink water", "eat meal/snack", "brush teeth", "clapping", "writing",
     "reading", "wear jacket", "take off jacket", "put on a shoe", "take off a shoe"
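A note on this hunk: `MCG-NJU/videomae-base` is a self-supervised pretraining checkpoint, so the head that `VideoMAEForVideoClassification.from_pretrained` attaches to it is freshly initialized and the hand-written NTU-style `labels` list does not correspond to trained classes. A minimal sketch of an alternative, assuming the publicly available Kinetics-400 fine-tuned checkpoint would be acceptable for the demo (the checkpoint choice and variable names are this note's assumption, not part of the commit):

from transformers import VideoMAEForVideoClassification

# Sketch: take the label set from a checkpoint that ships a trained classification head.
# "MCG-NJU/videomae-base-finetuned-kinetics" is a public fine-tuned checkpoint; using it
# here is an assumption of this note, not something the commit does.
clf = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
# id2label is populated for fine-tuned checkpoints, so no manual label list is needed.
labels = [clf.config.id2label[i] for i in range(clf.config.num_labels)]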
@@ -37,20 +33,20 @@ def preprocess_video(video_path):
     return frames[:16]
 
 def predict_video(video):
+    if video is None:
+        return "", {}
     frames = preprocess_video(video.name)
     pixel_values = processor(frames, return_tensors="pt").pixel_values
 
-    #
+    # Masked positions for pretraining
     num_patches_per_frame = (model_pretrain.config.image_size // model_pretrain.config.patch_size) ** 2
     seq_length = (16 // model_pretrain.config.tubelet_size) * num_patches_per_frame
     bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
 
     with torch.no_grad():
         outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
-
     loss = outputs.loss.item()
 
-    # For classification: get logits and predict top 3 classes
     with torch.no_grad():
         outputs_class = model_classify(pixel_values)
     logits = outputs_class.logits
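For context, the body of `preprocess_video` sits outside the changed hunks; only its closing `return frames[:16]` is visible here. A rough, illustrative sketch of a cv2-based reader consistent with that return value and with the `processor(frames, ...)` call above (the function name is hypothetical and this is not the file's actual implementation):

import cv2

def preprocess_video_sketch(video_path, num_frames=16, size=224):
    # Read frames with OpenCV, convert BGR -> RGB, and resize to the model input size.
    cap = cv2.VideoCapture(video_path)
    frames = []
    while len(frames) < num_frames:
        ok, frame = cap.read()
        if not ok:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(cv2.resize(frame, (size, size)))
    cap.release()
    # Repeat the last frame if the clip is shorter than num_frames.
    while frames and len(frames) < num_frames:
        frames.append(frames[-1])
    return frames[:num_frames]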
@@ -58,64 +54,41 @@ def predict_video(video):
     top5_prob, top5_catid = torch.topk(probs, 3)
     top_actions = {labels[catid]: float(prob) for prob, catid in zip(top5_prob, top5_catid)}
 
-    return {
-        "Reconstruction Loss": f"{loss:.4f}",
-        "Top 3 Action Predictions": top_actions
-    }
+    return f"Reconstruction Loss: {loss:.4f}", top_actions
 
 def preprocess_image(image):
-    #
-    image = np.array(image.convert("RGB").resize((224,224)))
-    # Add batch and channel dimension
+    # Resize and convert to RGB numpy array
+    image = np.array(image.convert("RGB").resize((224, 224)))
     pixel_values = processor(image, return_tensors="pt").pixel_values
     return pixel_values
 
 def predict_image(image):
+    if image is None:
+        return "No image provided."
     pixel_values = preprocess_image(image)
-
-    # For pretraining (masked autoencoding), mask patches randomly
     num_patches = (model_pretrain.config.image_size // model_pretrain.config.patch_size) ** 2
     bool_masked_pos = torch.randint(0, 2, (1, num_patches)).bool()
 
     with torch.no_grad():
         outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
     loss = outputs.loss.item()
-
     return f"Image Reconstruction Loss: {loss:.4f}"
 
-# Gradio interface with Tabs for Image and Video
 with gr.Blocks() as demo:
     gr.Markdown("# VideoMAE Demo: Image and Video Input")
 
     with gr.Tab("Video Input"):
         video_input = gr.Video(label="Upload Video (short clip)")
-
-
+        video_loss = gr.Textbox(label="Reconstruction Loss")
+        video_preds = gr.Label(num_top_classes=3, label="Top 3 Action Predictions")
         video_btn = gr.Button("Predict Video")
-
-        def video_predict_fn(video):
-            if video is None:
-                return "", {}
-            results = predict_video(video)
-            return results["Reconstruction Loss"], results["Top 3 Action Predictions"]
-
-        video_btn.click(
-            fn=video_predict_fn,
-            inputs=video_input,
-            outputs=[video_output_loss, video_output_preds],
-        )
+        video_btn.click(predict_video, inputs=video_input, outputs=[video_loss, video_preds])
 
     with gr.Tab("Image Input"):
         image_input = gr.Image(label="Upload Image")
-
+        image_loss = gr.Textbox(label="Reconstruction Loss")
         image_btn = gr.Button("Predict Image")
-
-        image_btn.click(
-            fn=predict_image,
-            inputs=image_input,
-            outputs=image_output,
-        )
+        image_btn.click(predict_image, inputs=image_input, outputs=image_loss)
 
 if __name__ == "__main__":
     demo.launch()
-
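One behavioural caveat in `predict_video`: `video.name` assumes Gradio passes a tempfile-like object, but depending on the Gradio version installed in the Space, `gr.Video` may hand the callback a plain filepath string instead, in which case `.name` raises an AttributeError. A small defensive sketch (the helper name is hypothetical):

def _resolve_video_path(video):
    # Accept either a plain filepath string (newer Gradio) or a tempfile-like object with .name.
    return video if isinstance(video, str) else video.name

`predict_video` could then call `preprocess_video(_resolve_video_path(video))` and behave the same under either API.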
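Finally, on the masking used in both predict functions: `torch.randint(0, 2, ...)` masks roughly half of the positions, whereas VideoMAE pretraining used a much higher masking ratio (around 0.9), so the reported reconstruction loss is sensitive to this choice. A sketch of a fixed-ratio mask, with the 0.9 default taken from the VideoMAE paper rather than from this commit (the helper name is hypothetical):

import torch

def make_bool_masked_pos(seq_length, mask_ratio=0.9):
    # Mark mask_ratio of the positions as masked (True), the rest as visible.
    num_masked = int(seq_length * mask_ratio)
    mask = torch.zeros(seq_length, dtype=torch.bool)
    mask[torch.randperm(seq_length)[:num_masked]] = True
    return mask.unsqueeze(0)  # shape (1, seq_length), as the model expects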