import gc
import json

import gradio as gr
import torch
from transformers import pipeline

# Define available models/tasks
MODEL_CONFIGS = [
    {
        "name": "Text Generation (GPT-2)",
        "task": "text-generation",
        "model": "gpt2",
        "input_type": "text",
        "output_type": "text",
    },
    {
        "name": "Image Classification (ViT)",
        "task": "image-classification",
        "model": "google/vit-base-patch16-224",
        "input_type": "image",
        "output_type": "label",
    },
    # Add more models/tasks as needed
]

# Model cache for lazy loading
model_cache = {}


def load_model(task, model_name):
    # Prefer the GPU when one is available; device=-1 keeps the pipeline on CPU
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=model_name, device=device)


def unload_model(model_key):
    if model_key in model_cache:
        del model_cache[model_key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model, Multi-Task Gradio Demo\n_Switch between models and tasks in one Space!_")

    # Shared state for the demo; gr.State must be created inside the Blocks context
    shared_state = gr.State({"active_model": None, "last_result": None})

    with gr.Tabs():
        for config in MODEL_CONFIGS:
            with gr.Tab(config["name"]):
                status = gr.Markdown(f"**Model:** {config['model']}\n**Task:** {config['task']}")
                load_btn = gr.Button("Load Model")
                unload_btn = gr.Button("Unload Model")

                if config["input_type"] == "text":
                    input_comp = gr.Textbox(label="Input Text")
                elif config["input_type"] == "image":
                    # type="pil" so the transformers pipeline receives a PIL image
                    input_comp = gr.Image(label="Input Image", type="pil")
                else:
                    input_comp = gr.Textbox(label="Input")

                run_btn = gr.Button("Run Model")
                output_comp = gr.Textbox(label="Output", lines=4)

                model_key = f"{config['task']}|{config['model']}"

                # Bind the current loop values as default arguments; otherwise
                # every closure would see only the *last* iteration's
                # config/model_key by the time a button is clicked.
                def do_load(state, model_key=model_key, config=config):
                    if model_key not in model_cache:
                        model_cache[model_key] = load_model(config["task"], config["model"])
                    state = dict(state)
                    state["active_model"] = model_key
                    return f"Loaded: {model_key}", state

                def do_unload(state, model_key=model_key):
                    unload_model(model_key)
                    state = dict(state)
                    state["active_model"] = None
                    return f"Unloaded: {model_key}", state

                def do_run(inp, state, model_key=model_key):
                    if model_key not in model_cache:
                        return "Model not loaded!", state
                    pipe = model_cache[model_key]
                    result = pipe(inp)
                    state = dict(state)
                    state["last_result"] = result
                    return str(result), state

                load_btn.click(do_load, shared_state, [status, shared_state])
                unload_btn.click(do_unload, shared_state, [status, shared_state])
                run_btn.click(do_run, [input_comp, shared_state], [output_comp, shared_state])

    # Shared state display
    def pretty_json(state):
        # default=str guards against pipeline results that aren't JSON-serializable
        return json.dumps(state, indent=2, ensure_ascii=False, default=str)

    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    # Note: gr.State only supports .change as an event trigger in recent Gradio (4.29+)
    shared_state.change(pretty_json, shared_state, shared_state_box)

demo.launch()
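
# A minimal sketch of how a third tab could be added, since every tab is
# generated from MODEL_CONFIGS: append an entry like the (hypothetical) one
# below to the list above and the loop builds its tab, buttons, and handlers
# automatically. The task name "summarization" is a stock transformers
# pipeline task; the model id is one commonly used summarization checkpoint,
# not something this demo depends on.
#
# {
#     "name": "Summarization (DistilBART)",
#     "task": "summarization",
#     "model": "sshleifer/distilbart-cnn-12-6",
#     "input_type": "text",
#     "output_type": "text",
# },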