Vision-GPT: Multimodal Image Captioning

A lightweight multimodal model combining GPT-2 and a Vision Transformer (ViT) for image captioning, built for Smart India Hackathon 2025.

🎯 Architecture

Sparse Cross-Attention Fusion (inspired by Llama 3.2)

  • GPT-2 (124M params) - Language model backbone ❄️ frozen
  • ViT-B/16 (87M params) - Visual encoder ❄️ frozen
  • Cross-Attention + Perceiver Resampler (11M params) - Vision-language fusion 🔥 trainable (sketched below)
    • Cross-attention inserted at layers 3, 6, 9
    • Perceiver Resampler for efficient visual token compression

Total: 222M params | Trainable: 11M params (5%)
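
The fusion mechanism is the only trainable piece, so it is worth seeing concretely. Below is a minimal sketch of one gated cross-attention block, assuming GPT-2 small's hidden size (768); the class and attribute names are illustrative, not the checkpoint's actual module names.

import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Hypothetical fusion block: text hidden states attend to visual tokens."""

    def __init__(self, d_model: int = 768, n_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Zero-initialized gate (a Flamingo-style assumption) leaves the frozen
        # LM's behavior unchanged at the start of training
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden, visual_tokens):
        # text_hidden: (B, T, 768) GPT-2 hidden states at layers 3, 6, 9
        # visual_tokens: (B, 64, 768) output of the Perceiver Resampler
        attended, _ = self.attn(self.norm(text_hidden), visual_tokens, visual_tokens)
        return text_hidden + torch.tanh(self.gate) * attended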

📊 Training Details

  • Dataset: Flickr8k
  • Epochs: 2
  • Final Loss: 2.632
  • Strategy: Freeze the pretrained models; train only the cross-attention layers (see the sketch below)
  • Hardware: Single GPU
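
A minimal sketch of the freezing strategy above, assuming the model exposes its backbones as model.gpt2 and model.vit (hypothetical attribute names; the learning rate is illustrative):

import torch

# ❄️ keep the pretrained GPT-2 and ViT weights fixed
for module in (model.gpt2, model.vit):
    for param in module.parameters():
        param.requires_grad = False

# 🔥 only the cross-attention / resampler parameters reach the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)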

📦 Model Versions

FP32 (Full Precision)

  • Size: 0.89 GB
  • Precision: 32-bit floating point
  • Use case: Maximum accuracy, research
  • Path: model_fp32/model_checkpoint.pth

FP16 (Half Precision)

  • Size: 0.52 GB
  • Precision: 16-bit floating point
  • Use case: Faster inference, 2x memory reduction
  • Path: model_fp16/model_checkpoint.pth
  • Space saved: 0.37 GB (41.9% reduction; see the conversion sketch below)
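
For reference, a minimal sketch of how an FP16 checkpoint like this can be derived from the FP32 one (the paths match this repository's layout; casting every floating-point tensor is an assumption about how the provided file was produced):

import torch

ckpt = torch.load("model_fp32/model_checkpoint.pth", map_location="cpu")

# Cast only floating-point tensors to half precision; leave integer buffers as-is
ckpt["model_state_dict"] = {
    k: v.half() if torch.is_floating_point(v) else v
    for k, v in ckpt["model_state_dict"].items()
}

torch.save(ckpt, "model_fp16/model_checkpoint.pth")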

🚀 Quick Start

Installation

pip install torch torchvision transformers pillow huggingface-hub

Load Model (FP32)

import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
checkpoint_path = hf_hub_download(
    repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
    filename="model_fp32/model_checkpoint.pth"
)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_state_dict = checkpoint['model_state_dict']

# Load your model architecture and weights
# model.load_state_dict(model_state_dict)
# model.eval()
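
The checkpoint ships weights only, not the architecture class. A hedged sketch of rebuilding the skeleton with Hugging Face transformers, where the VisionGPT wrapper and its wiring are assumptions rather than this repository's actual code:

from transformers import GPT2LMHeadModel, ViTModel

gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")                 # 124M LM backbone
vit = ViTModel.from_pretrained("google/vit-base-patch16-224")  # 87M visual encoder

# model = VisionGPT(gpt2, vit)  # hypothetical wrapper adding resampler + cross-attention
# model.load_state_dict(model_state_dict)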

Load Model (FP16 - Faster Inference)

import torch
from huggingface_hub import hf_hub_download

# Download FP16 checkpoint
checkpoint_path = hf_hub_download(
    repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
    filename="model_fp16/model_checkpoint.pth"
)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load model
# model.load_state_dict(checkpoint['model_state_dict'])
# model.half()  # Ensure model is in FP16
# model.eval()

# For GPU inference with FP16
# model = model.to('cuda')
# images = images.to('cuda').half()

Generate Caption

import torch
from PIL import Image
from torchvision import transforms

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Load and preprocess image
image = Image.open("your_image.jpg").convert('RGB')
image_tensor = transform(image).unsqueeze(0)

# Generate caption
with torch.no_grad():
    caption = model.generate(image_tensor, max_length=50)
    print(f"Caption: {caption}")

🎨 Demo

Try it live: Multimodal GPT-2 Demo

πŸ—οΈ Model Architecture

Input Image (224Γ—224)
        ↓
   ViT-B/16 Encoder ❄️
    (87M params)
        ↓
  Perceiver Resampler 🔥
   (compress to 64 tokens)
        ↓
  Cross-Attention Layers 🔥
   (at layers 3, 6, 9)
        ↓
      GPT-2 ❄️
    (124M params)
        ↓
  Generated Caption
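
The resampler stage is what keeps the fusion cheap: 64 learned latent queries cross-attend to the ViT patch sequence (197 tokens for ViT-B/16 at 224×224, including [CLS]) and return a fixed-size visual summary. A minimal sketch with illustrative module names:

import torch
import torch.nn as nn

class PerceiverResampler(nn.Module):
    """Hypothetical resampler: compress ViT tokens into 64 latent tokens."""

    def __init__(self, d_model: int = 768, n_latents: int = 64, n_heads: int = 8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, d_model) * 0.02)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, vit_tokens):                          # (B, 197, 768)
        queries = self.latents.unsqueeze(0).expand(vit_tokens.size(0), -1, -1)
        compressed, _ = self.attn(queries, vit_tokens, vit_tokens)
        return compressed                                   # (B, 64, 768)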

⚠️ Limitations

  • Trained only on Flickr8k (limited domain coverage)
  • English captions only
  • Best for images similar to Flickr8k dataset (people, activities, scenes)
  • May generate generic captions for out-of-domain images

πŸ“ Citation

@misc{vision-gpt-flickr8k-2025,
  author = {gurumurthy3},
  title = {Vision-GPT: Multimodal Image Captioning with Sparse Cross-Attention},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/gurumurthy3/vision-gpt-flickr8k_v2}}
}

πŸ™ Acknowledgments

  • Llama 3.2 Vision for sparse cross-attention inspiration
  • OpenAI for GPT-2
  • Google Research for Vision Transformer (ViT)
  • Flickr8k Dataset by Hodosh et al.
  • Smart India Hackathon 2025

📄 License

MIT License


Built with ❤️ for Smart India Hackathon 2025
