# Text-generation tab for the multi-model Gradio app: provides model
# load/unload controls and shares state through the app-wide model_cache.
import gradio as gr
from transformers import pipeline


def text_tab(model_cache, unload_all_models):
    """Build the "Text Generation" tab; the pipeline is shared via model_cache."""

    def load_text_model():
        # Free any other model first so only one pipeline is resident at a time.
        unload_all_models()
        model_cache["text"] = pipeline("text-generation", model="gpt2", device=-1)
        return "Text model loaded!"

    def unload_text_model():
        if "text" in model_cache:
            del model_cache["text"]
        return "Text model unloaded!"

    def run_text(prompt):
        if "text" not in model_cache:
            return "Text model not loaded!"
        return model_cache["text"](prompt)[0]["generated_text"]

    with gr.Tab("Text Generation"):
        status = gr.Markdown("Model not loaded.")
        load_btn = gr.Button("Load Text Model")
        unload_btn = gr.Button("Unload Model")
        prompt = gr.Textbox(label="Prompt", value="Hello world")
        run_btn = gr.Button("Generate")
        output = gr.Textbox(label="Output")

        # Wire the buttons: load/unload update the status line, Generate fills output.
        load_btn.click(load_text_model, None, status)
        unload_btn.click(unload_text_model, None, status)
        run_btn.click(run_text, prompt, output)
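

# --- Usage sketch (assumption) ----------------------------------------------
# This module only defines the tab; the parent app presumably owns the shared
# model_cache dict and an unload_all_models() helper it passes in. The wiring
# below is a minimal illustrative guess at that setup, not the repository's
# actual app entry point.
if __name__ == "__main__":
    demo_cache = {}

    def unload_all_models():
        # Drop every cached pipeline so at most one model occupies memory.
        demo_cache.clear()

    with gr.Blocks() as demo:
        text_tab(demo_cache, unload_all_models)

    demo.launch()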