Add Gradio interface for multi-model diffusion and text generation tasks, including model loading/unloading functionality and shared state management. Introduce new tabs for text and diffusion models, enhancing user interaction and modularity.
a5723a0
import gradio as gr
import torch
from transformers import pipeline
import gc
import json

# Define available models/tasks
MODEL_CONFIGS = [
    {
        "name": "Text Generation (GPT-2)",
        "task": "text-generation",
        "model": "gpt2",
        "input_type": "text",
        "output_type": "text"
    },
    {
        "name": "Image Classification (ViT)",
        "task": "image-classification",
        "model": "google/vit-base-patch16-224",
        "input_type": "image",
        "output_type": "label"
    },
    # Add more models/tasks as needed
]
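# A hypothetical extra entry (not in the original list) would go inside
# MODEL_CONFIGS above; any task that transformers.pipeline() supports plugs
# straight into load_model() below, e.g.:
# {
#     "name": "Summarization (DistilBART)",
#     "task": "summarization",
#     "model": "sshleifer/distilbart-cnn-12-6",
#     "input_type": "text",
#     "output_type": "text"
# },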
# Model cache for lazy loading, keyed by "task|model".
# (The shared gr.State is created inside the Blocks context below, where
# Gradio expects session-scoped components to be defined.)
model_cache = {}

def load_model(task, model_name):
    # Use device_map="auto" or device=0 for GPU if available
    return pipeline(task, model=model_name, device=-1)
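# Sketch (assumption, not the original behaviour): pick the GPU automatically
# when one is present, as the comment above suggests.
# def load_model(task, model_name):
#     device = 0 if torch.cuda.is_available() else -1
#     return pipeline(task, model=model_name, device=device)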
def unload_model(model_key):
    # Drop the cached pipeline and release memory
    if model_key in model_cache:
        del model_cache[model_key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
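# Illustrative round-trip (assumption; mirrors what the Load/Run/Unload buttons
# below do programmatically):
# key = "text-generation|gpt2"
# model_cache[key] = load_model("text-generation", "gpt2")
# print(model_cache[key]("Hello")[0]["generated_text"])
# unload_model(key)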
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model, Multi-Task Gradio Demo\n_Switch between models and tasks in one Space!_")
    # Shared per-session state (defined inside the Blocks context)
    shared_state = gr.State({"active_model": None, "last_result": None})
    with gr.Tabs():
        for config in MODEL_CONFIGS:
            with gr.Tab(config["name"]):
                status = gr.Markdown(f"**Model:** {config['model']}<br>**Task:** {config['task']}")
                load_btn = gr.Button("Load Model")
                unload_btn = gr.Button("Unload Model")
                if config["input_type"] == "text":
                    input_comp = gr.Textbox(label="Input Text")
                elif config["input_type"] == "image":
                    # PIL images can be passed directly to transformers pipelines
                    input_comp = gr.Image(label="Input Image", type="pil")
                else:
                    input_comp = gr.Textbox(label="Input")
                run_btn = gr.Button("Run Model")
                output_comp = gr.Textbox(label="Output", lines=4)
                model_key = f"{config['task']}|{config['model']}"

                # Bind model_key/config as default arguments so each tab keeps
                # its own values (Python closures bind loop variables late).
                def do_load(state, model_key=model_key, config=config):
                    if model_key not in model_cache:
                        model_cache[model_key] = load_model(config["task"], config["model"])
                    state = dict(state)
                    state["active_model"] = model_key
                    return f"Loaded: {model_key}", state

                def do_unload(state, model_key=model_key):
                    unload_model(model_key)
                    state = dict(state)
                    state["active_model"] = None
                    return f"Unloaded: {model_key}", state

                def do_run(inp, state, model_key=model_key):
                    if model_key not in model_cache:
                        return "Model not loaded!", state
                    pipe = model_cache[model_key]
                    result = pipe(inp)
                    state = dict(state)
                    state["last_result"] = result
                    return str(result), state

                load_btn.click(do_load, shared_state, [status, shared_state])
                unload_btn.click(do_unload, shared_state, [status, shared_state])
                run_btn.click(do_run, [input_comp, shared_state], [output_comp, shared_state])

    # Shared state display
    def pretty_json(state):
        return json.dumps(state, indent=2, ensure_ascii=False)

    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    # .change on gr.State requires a recent Gradio release that supports State change events
    shared_state.change(pretty_json, shared_state, shared_state_box)

demo.launch()
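# Note (assumption, not in the original file): on shared Space hardware,
# enabling Gradio's request queue keeps concurrent users from running the
# pipelines in parallel:
# demo.queue().launch()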