LPX55
Add Gradio interface for multi-model diffusion and text generation tasks, including model loading/unloading functionality and shared state management. Introduce new tabs for text and diffusion models, enhancing user interaction and modularity.
a5723a0
import gradio as gr
import torch
from transformers import pipeline
import gc
import json
# Catalogue of the model/task pairs shown as tabs in the UI.
# Keys of every entry:
#   name        - tab label
#   task        - transformers pipeline task id passed to load_model
#   model       - model id passed to load_model
#   input_type  - "text" or "image"; picks the input widget (anything else
#                 falls back to a plain textbox)
#   output_type - informational; not read by the UI code in this file
MODEL_CONFIGS = [
    dict(
        name="Text Generation (GPT-2)",
        task="text-generation",
        model="gpt2",
        input_type="text",
        output_type="text",
    ),
    dict(
        name="Image Classification (ViT)",
        task="image-classification",
        model="google/vit-base-patch16-224",
        input_type="image",
        output_type="label",
    ),
    # Append further dicts with the same keys to add more tabs.
]

# Gradio state shared by every tab: the key of the most recently loaded
# model and the raw result of the most recent pipeline call.
# NOTE(review): created at module level, outside any gr.Blocks context, so
# it has to be rendered into the Blocks before events can use it.
shared_state = gr.State({"active_model": None, "last_result": None})

# Loaded pipelines keyed by "task|model"; populated lazily on "Load Model"
# and emptied per key by unload_model.
model_cache = dict()
def load_model(task, model_name, device=-1):
    """Build a transformers pipeline for ``task`` backed by ``model_name``.

    Args:
        task: pipeline task id, e.g. "text-generation".
        model_name: model id handed to ``pipeline(model=...)``.
        device: forwarded to ``pipeline``; ``-1`` (the default, and the
            previous hard-coded value) selects CPU, ``0`` the first CUDA
            device.

    Returns:
        The constructed pipeline object.
    """
    return pipeline(task, model=model_name, device=device)
def unload_model(model_key):
    """Evict ``model_key`` from ``model_cache`` and reclaim memory.

    An unknown key is silently ignored. Python garbage collection always
    runs afterwards; the CUDA caching allocator is also emptied when CUDA
    is available.
    """
    model_cache.pop(model_key, None)
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def _build_model_tab(config, state_comp):
    """Create one tab's widgets and wire its load/unload/run events.

    Must be called inside an active ``gr.Tab`` context. Each call has its
    own scope, so the handlers below capture *this* tab's ``config`` and
    ``model_key``. Defining them directly in the caller's ``for`` loop made
    every tab's buttons act on the last MODEL_CONFIGS entry (late binding).

    Args:
        config: one MODEL_CONFIGS entry.
        state_comp: the shared ``gr.State`` component used as both input
            and output of every event.
    """
    task = config["task"]
    model_name = config["model"]
    model_key = f"{task}|{model_name}"

    status = gr.Markdown(f"**Model:** {config['model']}<br>**Task:** {config['task']}")
    load_btn = gr.Button("Load Model")
    unload_btn = gr.Button("Unload Model")
    if config["input_type"] == "text":
        input_comp = gr.Textbox(label="Input Text")
    elif config["input_type"] == "image":
        # transformers image pipelines accept PIL images / paths / URLs, not
        # the numpy arrays gr.Image produces by default.
        input_comp = gr.Image(label="Input Image", type="pil")
    else:
        input_comp = gr.Textbox(label="Input")
    run_btn = gr.Button("Run Model")
    output_comp = gr.Textbox(label="Output", lines=4)

    def do_load(state):
        """Load (or reuse) this tab's pipeline and mark it active."""
        if model_key not in model_cache:
            model_cache[model_key] = load_model(task, model_name)
        state = dict(state)  # copy so the previous state value isn't mutated
        state["active_model"] = model_key
        return f"Loaded: {model_key}", state

    def do_unload(state):
        """Drop this tab's pipeline; clear the active marker only if it was ours."""
        unload_model(model_key)
        state = dict(state)
        if state.get("active_model") == model_key:
            state["active_model"] = None
        return f"Unloaded: {model_key}", state

    def do_run(inp, state):
        """Run this tab's pipeline on ``inp`` and record the raw result."""
        if model_key not in model_cache:
            return "Model not loaded!", state
        # An empty textbox gives "" and an empty image widget gives None;
        # neither is valid pipeline input.
        if inp is None or (isinstance(inp, str) and not inp.strip()):
            return "No input provided!", state
        result = model_cache[model_key](inp)
        state = dict(state)
        state["last_result"] = result
        return str(result), state

    load_btn.click(do_load, state_comp, [status, state_comp])
    unload_btn.click(do_unload, state_comp, [status, state_comp])
    run_btn.click(do_run, [input_comp, state_comp], [output_comp, state_comp])


def pretty_json(state):
    """Render the shared state as indented JSON for display.

    ``default=str`` is deliberate: this is display-only, and pipeline
    results may contain values the json module cannot encode.
    """
    return json.dumps(state, indent=2, ensure_ascii=False, default=str)


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model, Multi-Task Gradio Demo\n_Switch between models and tasks in one Space!_")
    # shared_state is constructed at module level, outside any Blocks
    # context, so register it with this Blocks before events reference it.
    shared_state.render()
    with gr.Tabs():
        for config in MODEL_CONFIGS:
            with gr.Tab(config["name"]):
                _build_model_tab(config, shared_state)
    # Shared state display
    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    shared_state.change(pretty_json, shared_state, shared_state_box)
    # Populate the box on page load instead of leaving it blank until the
    # first state change.
    demo.load(pretty_json, shared_state, shared_state_box)

demo.launch()