"""Gradio demo: Pascal VOC image classification with a fine-tuned Swin S3 model.

At import time this module downloads the model weights from the Hugging Face
Hub, builds the timm model, loads the weights, and constructs the Gradio
interface ``demo``. Running the file as a script launches the interface.
"""

import gradio as gr
import torch
import timm
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Pascal VOC classes. `predict` labels output index i with class_names[i].
class_names = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# 🧠 Load model from HF Hub
REPO_ID = "fylex/swin-s3-base-pascal_test"  # 🔁 Update this
MODEL_FILENAME = "model.safetensors"
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# Build the architecture with a len(class_names)-way head, then load the
# fine-tuned weights (no ImageNet-pretrained download needed).
model = timm.create_model(
    "swin_s3_base_224", pretrained=False, num_classes=len(class_names)
)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
model.eval()  # inference mode for dropout / norm layers

# Preprocessing: fixed 224x224 resize, then per-channel mean/std 0.5.
# NOTE(review): presumably this matches the normalization used during
# fine-tuning — confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


# Prediction function
def predict(image):
    """Classify an image and return per-class softmax probabilities.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the user submits without uploading anything.

    Returns:
        Dict mapping every Pascal VOC class name to its softmax probability.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to
            the user instead of a generic error.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # The Normalize step above expects exactly 3 channels. RGBA, grayscale
    # ("L") or palette ("P") images would fail with a channel-count
    # mismatch. gr.Image may already convert depending on its image_mode;
    # converting here also keeps direct callers of predict() safe.
    img = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(img)
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    return {name: float(p) for name, p in zip(class_names, probs.tolist())}


# Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Swin S3 Base - Pascal VOC Classifier",
    description="A Swin Transformer model fine-tuned on Pascal VOC for multi-class image classification.",
)

if __name__ == "__main__":
    demo.launch(share=True)