fylexx commited on
Commit
eee112c
·
1 Parent(s): 4b2c0cf

Gradio app using Hub model

Browse files
Files changed (2) hide show
  1. app.py +54 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import timm
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from huggingface_hub import hf_hub_download
7
+ from safetensors.torch import load_file
8
+
9
+ # Pascal VOC classes
10
+ class_names = [
11
+ "aeroplane", "bicycle", "bird", "boat", "bottle",
12
+ "bus", "car", "cat", "chair", "cow",
13
+ "diningtable", "dog", "horse", "motorbike", "person",
14
+ "pottedplant", "sheep", "sofa", "train", "tvmonitor"
15
+ ]
16
+
17
+ # 🧠 Load model from HF Hub
18
+ REPO_ID = "fylex/swin-s3-base-pascal_test" # 🔁 Update this
19
+ MODEL_FILENAME = "model.safetensors"
20
+
21
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
22
+
23
+ # Build and load model
24
+ model = timm.create_model("swin_s3_base_224", pretrained=False, num_classes=len(class_names))
25
+ state_dict = load_file(model_path)
26
+ model.load_state_dict(state_dict)
27
+ model.eval()
28
+
29
+ # Preprocessing
30
+ transform = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize([0.5]*3, [0.5]*3),
34
+ ])
35
+
36
+ # Prediction function
37
+ def predict(image):
38
+ img = transform(image).unsqueeze(0)
39
+ with torch.no_grad():
40
+ logits = model(img)
41
+ probs = torch.nn.functional.softmax(logits, dim=1)[0]
42
+ return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
43
+
44
+ # Gradio interface
45
+ demo = gr.Interface(
46
+ fn=predict,
47
+ inputs=gr.Image(type="pil"),
48
+ outputs=gr.Label(num_top_classes=5),
49
+ title="Swin S3 Base - Pascal VOC Classifier",
50
+ description="A Swin Transformer model fine-tuned on Pascal VOC for multi-class image classification.",
51
+ )
52
+
53
+ if __name__ == "__main__":
54
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ timm
3
+ gradio
4
+ safetensors
5
+ Pillow
6
+ torchvision
7
+ huggingface_hub