# Author: YUGOROU
# Commit: Fix label outputs with Japanese names
# Revision: 503c8b2
import os
import warnings

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoImageProcessor, ViTForImageClassification, pipeline
# Hugging Face Hub repository that hosts the fine-tuned weights.
MODEL_REPO = "YUGOROU/Soba-or-Udon-vit"
# File inside MODEL_REPO that is downloaded and loaded as a safetensors state dict.
MODEL_FILE = "model.safetensors"
# Checkpoint used for the image processor and as the starting ViT architecture.
BASE_MODEL = "google/vit-base-patch16-224"
def load_model():
    """Build the image-classification pipeline used by the demo.

    The model starts from the ``BASE_MODEL`` ViT checkpoint with a
    three-class head, then the fine-tuned weights downloaded from
    ``MODEL_REPO`` are copied on top. Only tensors whose name and shape
    match the freshly built model are copied; anything left out is reported
    with a warning instead of being dropped silently.

    Returns:
        A ``transformers`` ``"image-classification"`` pipeline whose labels
        are the Japanese class names (udon, soba, neither).
    """
    # Class names shown to the user: udon, soba, and "neither".
    id2label = {
        0: "うどん",
        1: "蕎麦",
        2: "どちらでもない",
    }

    processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
    # ignore_mismatched_sizes lets the classifier head be rebuilt for
    # num_labels even if the base checkpoint's head has another shape.
    model = ViTForImageClassification.from_pretrained(
        BASE_MODEL,
        num_labels=len(id2label),
        ignore_mismatched_sizes=True,
    )

    weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    state_dict = load_file(weights_path)
    model_state = model.state_dict()
    filtered_state = {
        key: tensor
        for key, tensor in state_dict.items()
        if key in model_state and tensor.shape == model_state[key].shape
    }

    # strict=False tolerates a partial checkpoint, but a partial load usually
    # means wrong predictions, so surface it rather than hiding it.
    load_result = model.load_state_dict(filtered_state, strict=False)
    skipped = [key for key in state_dict if key not in filtered_state]
    if skipped:
        warnings.warn(
            f"{len(skipped)} tensor(s) in {MODEL_REPO}/{MODEL_FILE} were "
            f"ignored (unknown name or shape mismatch): {skipped[:5]}",
            stacklevel=2,
        )
    if load_result.missing_keys:
        warnings.warn(
            f"{len(load_result.missing_keys)} model tensor(s) were not "
            f"provided by {MODEL_REPO}/{MODEL_FILE} and keep their initial "
            f"values: {list(load_result.missing_keys)[:5]}",
            stacklevel=2,
        )
    model.eval()

    model.config.id2label = id2label
    model.config.label2id = {name: idx for idx, name in id2label.items()}

    # The pipeline API takes a CUDA device index, or -1 for CPU.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "image-classification",
        model=model,
        image_processor=processor,
        device=device,
    )
# Build the pipeline once at import time; predict() reuses this module-level object.
pipe = load_model()
def predict(image):
    """Classify an uploaded image and return the confidence of each label.

    Args:
        image: The PIL image supplied by the Gradio input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A dict mapping each label returned by the pipeline to its score, or
        an empty dict when no image was given.
    """
    if image is None:
        # Nothing to classify; an empty mapping avoids calling the pipeline.
        return {}
    # Return every prediction, not just the first, so the Label output can
    # rank them and display its configured number of top classes.
    return {entry["label"]: entry["score"] for entry in pipe(image)}
# Gradio UI: a single PIL image goes in, a label/confidence widget comes out.
demo = gr.Interface(
    title="Soba or Udon Classifier",
    description="Upload an image of noodles to find out whether it is Soba or Udon.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()