LPX55
Add Gradio interface for multi-model diffusion and text generation tasks, including model loading/unloading functionality and shared state management. Introduce new tabs for text and diffusion models, enhancing user interaction and modularity.
a5723a0
import gradio as gr
from transformers import pipeline
def text_tab(model_cache, unload_all_models, *, model_name="gpt2", device=-1):
    """Build the "Text Generation" tab and wire its load/unload/generate buttons.

    Must be called inside an enclosing Gradio ``Blocks`` context, because it
    creates a ``gr.Tab`` and attaches event listeners to its components.

    Args:
        model_cache: Mutable mapping shared with other tabs; the loaded
            text-generation pipeline is stored under the ``"text"`` key.
        unload_all_models: Zero-argument callable invoked before loading,
            presumably to free models held by other tabs (confirm against
            the caller).
        model_name: Model id passed to ``transformers.pipeline``. Defaults to
            ``"gpt2"``, the model this tab originally hard-coded.
        device: Device index passed to ``transformers.pipeline``; ``-1``
            selects the CPU.
    """

    def load_text_model():
        """Free any loaded models, then load the text-generation pipeline."""
        unload_all_models()
        model_cache["text"] = pipeline(
            "text-generation", model=model_name, device=device
        )
        return "Text model loaded!"

    def unload_text_model():
        """Drop the cached text pipeline, if present, and reclaim its memory."""
        import gc

        # pop() with a default removes the entry in one call; the previous
        # check-then-del could raise KeyError if the entry vanished between
        # the membership test and the deletion.
        if model_cache.pop("text", None) is not None:
            # Collect now so the model's weights are released promptly even
            # if the pipeline object is part of a reference cycle.
            gc.collect()
        return "Text model unloaded!"

    def run_text(prompt):
        """Generate a continuation of ``prompt`` with the cached pipeline."""
        # Fetch the pipeline once so it cannot be removed between the
        # "is it loaded?" check and the actual call.
        generator = model_cache.get("text")
        if generator is None:
            return "Text model not loaded!"
        return generator(prompt)[0]["generated_text"]

    with gr.Tab("Text Generation"):
        status = gr.Markdown("Model not loaded.")
        load_btn = gr.Button("Load Text Model")
        unload_btn = gr.Button("Unload Model")
        prompt = gr.Textbox(label="Prompt", value="Hello world")
        run_btn = gr.Button("Generate")
        output = gr.Textbox(label="Output")

        # Load/unload report their result in the status Markdown; generation
        # reads the prompt textbox and writes to the output textbox.
        load_btn.click(load_text_model, None, status)
        unload_btn.click(unload_text_model, None, status)
        run_btn.click(run_text, prompt, output)