import os

import gradio as gr
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Read the Hugging Face access token from the environment
# (set HF_TOKEN as a Space secret; MedGemma is a gated model).
hf_token = os.environ.get("HF_TOKEN")

# Load the model and processor with the auth token.
# MedGemma 4B is a multimodal (image-text-to-text) model, so it must be
# loaded with AutoModelForImageTextToText, not AutoModelForCausalLM.
model_id = "google/medgemma-4b-it"
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bfloat16 is the recommended dtype for Gemma-family models
    device_map="auto",
    token=hf_token,
)

# Prediction function
def diagnose(image, query):
    if image is None or not query:
        return "Please upload an image and enter a question."
    # Build a chat-formatted prompt; the processor's chat template inserts
    # the special image tokens the model expects.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt.
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response

# Gradio interface
demo = gr.Interface(
    fn=diagnose,
    inputs=[
        gr.Image(type="pil", label="Upload Skin Image"),
        gr.Textbox(lines=2, label="Your Query (e.g., What is the disease?)"),
    ],
    outputs=gr.Textbox(label="Diagnosis & Advice"),
    title="🩺 DermaScan with MedGemma",
    description="Upload a skin image and ask a medical question. Powered by Google MedGemma 4B.",
)

if __name__ == "__main__":
    demo.launch()