"""Gradio demo that classifies a noodle photo as udon, soba, or neither.

A ViT image classifier is built from ``BASE_MODEL`` with a new label head, and
its weights are then overwritten by a fine-tuned checkpoint downloaded from the
Hugging Face Hub (``MODEL_REPO`` / ``MODEL_FILE``).
"""

import os
import warnings

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoImageProcessor, ViTForImageClassification, pipeline

MODEL_REPO = "YUGOROU/Soba-or-Udon-vit"
MODEL_FILE = "model.safetensors"
BASE_MODEL = "google/vit-base-patch16-224"

# Class index -> display label.
# NOTE(review): presumed to match the label order used when the checkpoint in
# MODEL_REPO was trained — confirm against the training code.
ID2LABEL = {
    0: "うどん",  # udon
    1: "蕎麦",  # soba
    2: "どちらでもない",  # neither
}


def _load_finetuned_weights(model):
    """Download the fine-tuned checkpoint and load it into *model* in place.

    Only checkpoint tensors whose name exists in *model* with an identical shape
    are loaded. Rather than dropping the rest silently, this function emits a
    ``warnings.warn`` for:
      * checkpoint tensors that were skipped (unknown name or shape mismatch);
      * model parameters the checkpoint did not provide. These keep whatever
        value ``from_pretrained`` gave them, and for a new head that value is a
        random initialization.
    """
    weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    state_dict = load_file(weights_path)
    model_state = model.state_dict()

    filtered_state = {
        k: v
        for k, v in state_dict.items()
        if k in model_state and v.shape == model_state[k].shape
    }

    skipped = sorted(set(state_dict) - set(filtered_state))
    if skipped:
        warnings.warn(
            f"Skipped {len(skipped)} checkpoint tensor(s) with an unknown name "
            f"or mismatched shape, e.g. {skipped[:5]}",
            stacklevel=2,
        )

    # strict=False tolerates gaps, but the returned report is inspected so that
    # an unloaded classifier head does not go unnoticed.
    incompatible = model.load_state_dict(filtered_state, strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f"{len(incompatible.missing_keys)} model parameter(s) were not found "
            f"in the checkpoint and keep their initial values, e.g. "
            f"{incompatible.missing_keys[:5]}",
            stacklevel=2,
        )


def load_model():
    """Build the image-classification pipeline used by the demo.

    Returns:
        A ``transformers`` image-classification pipeline whose labels come from
        ``ID2LABEL``. It runs on CUDA device 0 when CUDA is available and on the
        CPU otherwise.
    """
    processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
    # ignore_mismatched_sizes lets the base model's classification head be
    # replaced by a freshly sized one with len(ID2LABEL) outputs.
    model = ViTForImageClassification.from_pretrained(
        BASE_MODEL, num_labels=len(ID2LABEL), ignore_mismatched_sizes=True
    )
    _load_finetuned_weights(model)
    model.eval()

    # Make the pipeline report human-readable labels instead of LABEL_<i>.
    model.config.id2label = dict(ID2LABEL)
    model.config.label2id = {v: k for k, v in ID2LABEL.items()}

    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "image-classification",
        model=model,
        image_processor=processor,
        device=device,
    )


pipe = load_model()


def predict(image):
    """Classify *image* and return a ``{label: score}`` mapping over all classes.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        A dict of every class label to its score, suitable for ``gr.Label``.
        The dict is empty when *image* is None.
    """
    if image is None:
        return {}
    # Ask for every class so gr.Label can show the full distribution.
    # Indexing [0] would keep only the top-1 prediction.
    results = pipe(image, top_k=len(ID2LABEL))
    return {r["label"]: r["score"] for r in results}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(ID2LABEL)),
    title="Soba or Udon Classifier",
    description="Upload an image of noodles to find out whether it is Soba or Udon.",
)

if __name__ == "__main__":
    demo.launch()