|
import gradio as gr |
|
import torch |
|
import timm |
|
from PIL import Image |
|
from torchvision import transforms |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
|
|
|
|
# Pascal VOC category names. The position of each name is the index of the
# matching logit / probability column produced by the classifier head
# (predict() pairs class_names[i] with probs[i]).
class_names = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
|
|
|
|
|
# Hugging Face Hub location of the fine-tuned checkpoint.
REPO_ID = "fylex/swin-s3-base-pascal_test"
MODEL_FILENAME = "model.safetensors"

# Resolve the checkpoint to a local file path (hf_hub_download downloads it
# into the local HF cache, or reuses an already cached copy).
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
)

# Read the raw tensors from the safetensors file.
state_dict = load_file(model_path)

# Build the Swin-S3 architecture without ImageNet weights. The classifier
# head is sized to len(class_names) so it matches the fine-tuned checkpoint;
# load_state_dict is strict by default and raises on any key/shape mismatch.
model = timm.create_model(
    "swin_s3_base_224",
    pretrained=False,
    num_classes=len(class_names),
)
model.load_state_dict(state_dict)

# Switch layers such as dropout to evaluation behaviour for inference.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): mean/std of 0.5 (mapping pixels to [-1, 1]) presumably mirror
# the fine-tuning recipe rather than timm's ImageNet defaults — confirm.
_preprocess_steps = [
    transforms.Resize((224, 224)),  # fixed 224x224; aspect ratio is not kept
    transforms.ToTensor(),          # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
def predict(image):
    """Classify a single uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        Dict mapping each class name to its softmax probability, or an empty
        dict when no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty image input; there is nothing to
        # classify, and transform(None) would raise.
        return {}

    # Uploads may be RGBA, palette ("P") or grayscale ("L"). The Normalize
    # step uses 3-channel mean/std, so coerce to RGB first.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(batch)

    # NOTE(review): softmax treats the task as single-label; confirm this
    # matches how the checkpoint was trained (VOC is often multi-label).
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    return dict(zip(class_names, probs.tolist()))
|
|
|
|
|
# Gradio UI: one PIL image in, the five most probable classes shown out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Swin S3 Base - Pascal VOC Classifier",
    description="A Swin Transformer model fine-tuned on Pascal VOC for multi-class image classification.",
)
|
|
|
def main():
    """Launch the Gradio demo with a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()