# Hugging Face Spaces page header captured with the file (Space status: Sleeping).
import os
import warnings

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoImageProcessor, ViTForImageClassification, pipeline
# Hub repository that hosts the fine-tuned checkpoint.
MODEL_REPO = "YUGOROU/Soba-or-Udon-vit"
# Name of the safetensors weight file downloaded from MODEL_REPO.
MODEL_FILE = "model.safetensors"
# Checkpoint used for the image processor and for the initial ViT weights.
BASE_MODEL = "google/vit-base-patch16-224"
def load_model():
    """Build the image-classification pipeline for the fine-tuned ViT.

    Loads the image processor and a 3-label ``ViTForImageClassification``
    from ``BASE_MODEL``, downloads ``MODEL_FILE`` from ``MODEL_REPO``, copies
    every checkpoint tensor whose name and shape match the model, and wraps
    the result in a ``transformers`` pipeline.

    Returns:
        The ``"image-classification"`` pipeline, placed on CUDA device 0 when
        CUDA is available and on CPU (``device=-1``) otherwise.

    Warns:
        UserWarning: if any tensor in the downloaded checkpoint is skipped
            because its name or shape does not match the model.
    """
    processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
    # The base checkpoint's classification head has a different size, so
    # mismatched sizes are tolerated and a fresh 3-label head is created.
    model = ViTForImageClassification.from_pretrained(
        BASE_MODEL,
        num_labels=3,
        ignore_mismatched_sizes=True,
    )

    weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    state_dict = load_file(weights_path)

    # Keep only tensors the model can actually accept.
    model_state = model.state_dict()
    filtered_state = {
        k: v for k, v in state_dict.items()
        if k in model_state and v.shape == model_state[k].shape
    }

    # Surface dropped tensors instead of silently leaving those parameters at
    # their base/random initialisation.
    skipped = sorted(k for k in state_dict if k not in filtered_state)
    if skipped:
        warnings.warn(
            f"Skipped {len(skipped)} checkpoint tensor(s) with unknown names "
            f"or mismatched shapes: {skipped}",
            stacklevel=2,
        )

    model.load_state_dict(filtered_state, strict=False)
    model.eval()

    device = 0 if torch.cuda.is_available() else -1

    # Display labels. The previous literals were mojibake (UTF-8 Japanese
    # bytes decoded with a single-byte Greek code page); restored here.
    model.config.id2label = {
        0: "うどん",  # udon
        1: "蕎麦",  # soba
        2: "どちらでもない",  # neither
    }
    model.config.label2id = {v: k for k, v in model.config.id2label.items()}

    pipe = pipeline(
        "image-classification",
        model=model,
        image_processor=processor,
        device=device,
    )
    return pipe
# Build the pipeline once at import time; predict() reuses this instance.
pipe = load_model()
def predict(image):
    """Classify an uploaded image and return per-label confidences.

    Args:
        image: The image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        A dict mapping every label returned by the pipeline to its score, so
        the ``gr.Label`` output can rank the classes itself. Empty when
        ``image`` is ``None``.
    """
    # Nothing was uploaded: return an empty result instead of handing None to
    # the pipeline.
    if image is None:
        return {}
    # Return all predictions, not just the top one, so the Label component
    # has more than a single class to display.
    return {result["label"]: result["score"] for result in pipe(image)}
# Gradio UI: a single image input wired to predict(), with a label output
# that shows the two highest-scoring classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Soba or Udon Classifier",
    description="Upload an image of noodles to find out whether it is Soba or Udon.",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()