# medgemma-app / app.py
# Author: janhvi145
# Last commit: 30ec883 ("Update app.py", verified)
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor
# Read the Hugging Face access token from the Space secret. MedGemma is gated
# on the Hub, so downloads need a token for an account that accepted its terms;
# if HF_TOKEN is unset (None), huggingface_hub falls back to any cached login.
hf_token = os.environ.get("HF_TOKEN")

# Load model and processor with auth token.
model_id = "google/medgemma-4b-it"
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
# MedGemma 4B-it takes images + text, so load it through the image-text-to-text
# auto class (as the model card does) rather than the text-only causal-LM class.
# bfloat16 matches the model card; Gemma-family activations can overflow in float16.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token,
)
# Prediction function
def diagnose(image, query):
    """Answer a user question about an uploaded image with MedGemma.

    Args:
        image: PIL image from the Gradio upload widget, or ``None`` when the
            user did not upload anything.
        query: Free-text question from the user; may be empty or ``None``.

    Returns:
        The model's generated answer as plain text, without the echoed prompt,
        or a short instruction message when neither input was provided.
    """
    question = (query or "").strip()
    if image is None and not question:
        return "Please upload a skin image and/or enter a question."
    if not question:
        question = "Describe this skin image."

    content = [{"type": "text", "text": question}]
    if image is not None:
        # The Gemma 3 processor rejects an image unless the prompt contains its
        # image placeholder token; the chat template inserts that token for us.
        content.append({"type": "image", "image": image})
    messages = [{"role": "user", "content": content}]

    # dtype= only casts floating tensors (pixel_values) so they match the
    # model's weights; token ids stay integer.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding gives reproducible answers for the same input.
        outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)

    # generate() returns prompt + completion; decode only the new tokens so the
    # answer shown to the user does not repeat the prompt.
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
# Gradio interface: an image upload plus a question box, wired to diagnose().
# The description carries a disclaimer because the output is raw model text,
# not a clinical diagnosis.
demo = gr.Interface(
    fn=diagnose,
    inputs=[
        gr.Image(type="pil", label="Upload Skin Image"),
        gr.Textbox(lines=2, label="Your Query (e.g., What is the disease?)"),
    ],
    outputs=gr.Textbox(label="Diagnosis & Advice"),
    title="🩺 DermaScan with MedGemma",
    description=(
        "Upload a skin image and ask a medical question. Powered by Google MedGemma 4B. "
        "⚠️ Responses are AI-generated for research and educational purposes only; "
        "they are not a medical diagnosis. Consult a qualified clinician for medical advice."
    ),
)
def _main():
    """Start the Gradio web server for the DermaScan demo."""
    demo.launch()


# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    _main()