Vision-GPT: Multimodal Image Captioning
A lightweight multimodal model combining GPT-2 and Vision Transformer for image captioning, built for Smart India Hackathon 2025.
🎯 Architecture
Sparse Cross-Attention Fusion (inspired by Llama 3.2 Vision)
- GPT-2 (124M params) - Language model backbone ❄️ frozen
- ViT-B/16 (87M params) - Visual encoder ❄️ frozen
- Cross-Attention + Perceiver Resampler (11M params) - Vision-language fusion 🔥 trainable
- Cross-attention inserted at layers 3, 6, 9
- Perceiver Resampler for efficient visual token compression
Total: 222M params | Trainable: 11M params (5%)
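The fusion described above can be pictured with a short PyTorch sketch (module names, head counts, and initialization are illustrative, not the exact training code): a fixed set of learned latent queries cross-attends over the ViT patch embeddings to produce 64 visual tokens, and a gated cross-attention block injects those tokens into the hidden states of the selected frozen GPT-2 layers.

import torch
import torch.nn as nn

class PerceiverResampler(nn.Module):
    """Compress ViT-B/16 patch tokens (B, 197, 768) into a fixed set of 64 latents."""
    def __init__(self, dim=768, num_latents=64, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, vis_tokens):
        queries = self.latents.unsqueeze(0).expand(vis_tokens.size(0), -1, -1)
        out, _ = self.attn(queries, vis_tokens, vis_tokens)
        return self.norm(out)  # (B, 64, 768)

class GatedCrossAttention(nn.Module):
    """Cross-attention block inserted after frozen GPT-2 layers 3, 6, and 9."""
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0, so the frozen GPT-2 output is unchanged at init

    def forward(self, text_hidden, vis_tokens):
        attn_out, _ = self.attn(self.norm(text_hidden), vis_tokens, vis_tokens)
        return text_hidden + torch.tanh(self.gate) * attn_out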
📊 Training Details
- Dataset: Flickr8k
- Epochs: 2
- Final Loss: 2.632
- Strategy: Freeze pretrained models, train only cross-attention layers
- Hardware: Single GPU
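In practice, the freezing strategy amounts to disabling gradients on both backbones and handing only the fusion parameters to the optimizer. A minimal sketch, assuming the backbones are loaded with Hugging Face transformers (the exact checkpoint names and learning rate are assumptions, and the ModuleList is a placeholder for the real fusion layers):

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, ViTModel

gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")                       # 124M params, frozen
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")  # 87M params, frozen
for p in list(gpt2.parameters()) + list(vit.parameters()):
    p.requires_grad = False

# Placeholder for the trainable fusion layers (resampler + cross-attention, ~11M params in this model)
fusion = nn.ModuleList([nn.MultiheadAttention(768, 8, batch_first=True) for _ in range(3)])

# Only the fusion parameters receive gradient updates
optimizer = torch.optim.AdamW(fusion.parameters(), lr=1e-4)  # learning rate is an assumption, not from the card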
📦 Model Versions
FP32 (Full Precision)
- Size: 0.89 GB
- Precision: 32-bit floating point
- Use case: Maximum accuracy, research
- Path: model_fp32/model_checkpoint.pth
FP16 (Half Precision)
- Size: 0.52 GB
- Precision: 16-bit floating point
- Use case: Faster inference, 2x memory reduction
- Path: model_fp16/model_checkpoint.pth
- Space saved: 0.37 GB (41.9% reduction)
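The FP16 copy can be reproduced from the FP32 checkpoint by casting the floating-point tensors in the state dict to half precision. A minimal sketch, assuming the checkpoint layout shown in the loading examples below (a 'model_state_dict' key):

import os
import torch

ckpt = torch.load("model_fp32/model_checkpoint.pth", map_location="cpu")

# Cast only floating-point tensors; integer buffers keep their dtype
fp16_state = {
    k: (v.half() if torch.is_floating_point(v) else v)
    for k, v in ckpt["model_state_dict"].items()
}

os.makedirs("model_fp16", exist_ok=True)
torch.save({"model_state_dict": fp16_state}, "model_fp16/model_checkpoint.pth")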
🚀 Quick Start
Installation
pip install torch torchvision transformers pillow huggingface-hub
Load Model (FP32)
import torch
from huggingface_hub import hf_hub_download
# Download checkpoint
checkpoint_path = hf_hub_download(
repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
filename="model_fp32/model_checkpoint.pth"
)
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_state_dict = checkpoint['model_state_dict']
# Load your model architecture and weights
# model.load_state_dict(model_state_dict)
# model.eval()
Load Model (FP16 - Faster Inference)
import torch
from huggingface_hub import hf_hub_download
# Download FP16 checkpoint
checkpoint_path = hf_hub_download(
repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
filename="model_fp16/model_checkpoint.pth"
)
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Load model
# model.load_state_dict(checkpoint['model_state_dict'])
# model.half() # Ensure model is in FP16
# model.eval()
# For GPU inference with FP16
# model = model.to('cuda')
# images = images.to('cuda').half()
Generate Caption
from PIL import Image
from torchvision import transforms
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
# Load and preprocess image
image = Image.open("your_image.jpg").convert('RGB')
image_tensor = transform(image).unsqueeze(0)
# Generate caption
with torch.no_grad():
    caption = model.generate(image_tensor, max_length=50)
print(f"Caption: {caption}")
🎨 Demo
Try it live: Multimodal GPT-2 Demo
🏗️ Model Architecture
Input Image (224×224)
↓
ViT-B/16 Encoder ❄️
(87M params)
↓
Perceiver Resampler 🔥
(compress to 64 tokens)
↓
Cross-Attention Layers 🔥
(at layers 3, 6, 9)
↓
GPT-2 ❄️
(124M params)
↓
Generated Caption
⚠️ Limitations
- Trained only on Flickr8k (limited domain coverage)
- English captions only
- Best for images similar to Flickr8k dataset (people, activities, scenes)
- May generate generic captions for out-of-domain images
📄 Citation
@misc{vision-gpt-flickr8k-2025,
  author = {gurumurthy3},
  title = {Vision-GPT: Multimodal Image Captioning with Sparse Cross-Attention},
  year = {2025},
  publisher = {Hugging Face},
  journal = {Hugging Face Model Hub},
  howpublished = {\url{https://huggingface.co/gurumurthy3/vision-gpt-flickr8k_v2}}
}
🙏 Acknowledgments
- Llama 3.2 Vision for sparse cross-attention inspiration
- OpenAI for GPT-2
- Google Research for Vision Transformer (ViT)
- Flickr8k Dataset by Hodosh et al.
- Smart India Hackathon 2025
📜 License
MIT License
Built with ❤️ for Smart India Hackathon 2025