import torch
import gradio as gr
import numpy as np
import cv2
from transformers import (
    VideoMAEImageProcessor,
    VideoMAEForPreTraining,
    VideoMAEForVideoClassification,
)

# Use the publicly available MCG-NJU checkpoint for both pretraining and classification.
model_name = "MCG-NJU/videomae-base"

# Example labels for classification (replace with the full NTU action list if needed).
labels = [
    "drink water", "eat meal/snack", "brush teeth", "clapping", "writing",
    "reading", "wear jacket", "take off jacket", "put on a shoe", "take off a shoe",
]

# Load the processor and models. Note that videomae-base ships without a
# fine-tuned classification head, so the head below is randomly initialized
# (sized to our label list) and its predictions are placeholders until the
# model is fine-tuned.
processor = VideoMAEImageProcessor.from_pretrained(model_name)
model_pretrain = VideoMAEForPreTraining.from_pretrained(model_name)
model_classify = VideoMAEForVideoClassification.from_pretrained(
    model_name, num_labels=len(labels)
)

NUM_FRAMES = 16  # VideoMAE expects 16-frame clips.


def random_mask(num_frames=NUM_FRAMES):
    """Random boolean mask over the tube-patch sequence, as in the VideoMAE pretraining example."""
    num_patches_per_frame = (model_pretrain.config.image_size // model_pretrain.config.patch_size) ** 2
    seq_length = (num_frames // model_pretrain.config.tubelet_size) * num_patches_per_frame
    return torch.randint(0, 2, (1, seq_length)).bool()


def preprocess_video(video_path):
    """Read a clip with OpenCV, convert BGR -> RGB, and pad/trim to 16 frames."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    if not frames:
        return frames
    if len(frames) < NUM_FRAMES:
        frames += [frames[-1]] * (NUM_FRAMES - len(frames))  # pad by repeating the last frame
    return frames[:NUM_FRAMES]


def predict_video(video):
    if video is None:
        return "", {}
    # Recent Gradio versions pass a filepath string; older ones pass a tempfile object.
    video_path = video if isinstance(video, str) else video.name
    frames = preprocess_video(video_path)
    if not frames:
        return "Could not read any frames from the video.", {}
    pixel_values = processor(frames, return_tensors="pt").pixel_values

    # Masked reconstruction loss from the pretraining objective.
    bool_masked_pos = random_mask()
    with torch.no_grad():
        outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
        loss = outputs.loss.item()

    # Top-3 action predictions (from the untrained head; see note above).
    with torch.no_grad():
        logits = model_classify(pixel_values).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        top_probs, top_ids = torch.topk(probs, k=3)
        top_actions = {labels[i.item()]: float(p) for p, i in zip(top_probs, top_ids)}

    return f"Reconstruction Loss: {loss:.4f}", top_actions


def preprocess_image(image):
    """Tile a still image into a 16-frame pseudo-video, since VideoMAE expects clips."""
    image = np.array(image.convert("RGB").resize((224, 224)))
    frames = [image] * NUM_FRAMES
    return processor(frames, return_tensors="pt").pixel_values


def predict_image(image):
    if image is None:
        return "No image provided."
    pixel_values = preprocess_image(image)
    bool_masked_pos = random_mask()
    with torch.no_grad():
        outputs = model_pretrain(pixel_values, bool_masked_pos=bool_masked_pos)
        loss = outputs.loss.item()
    return f"Image Reconstruction Loss: {loss:.4f}"


with gr.Blocks() as demo:
    gr.Markdown("# VideoMAE Demo: Image and Video Input")
    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload Video (short clip)")
        video_loss = gr.Textbox(label="Reconstruction Loss")
        video_preds = gr.Label(num_top_classes=3, label="Top 3 Action Predictions")
        video_btn = gr.Button("Predict Video")
        video_btn.click(predict_video, inputs=video_input, outputs=[video_loss, video_preds])
    with gr.Tab("Image Input"):
        # type="pil" so predict_image receives a PIL.Image with .convert()/.resize().
        image_input = gr.Image(type="pil", label="Upload Image")
        image_loss = gr.Textbox(label="Reconstruction Loss")
        image_btn = gr.Button("Predict Image")
        image_btn.click(predict_image, inputs=image_input, outputs=image_loss)

if __name__ == "__main__":
    demo.launch()
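# ---------------------------------------------------------------------------
# Usage sketch (added note; the filename "app.py" is an assumption, and the
# package list uses the standard PyPI names with versions left unpinned):
#
#   pip install torch transformers gradio opencv-python numpy
#   python app.py
#
# Gradio then serves the demo locally, at http://127.0.0.1:7860 by default.
# ---------------------------------------------------------------------------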