Spaces:
Runtime error
Runtime error
LPX55
Add Gradio interface for multi-model diffusion and text generation tasks, including model loading/unloading functionality and shared state management. Introduce new tabs for text and diffusion models, enhancing user interaction and modularity.
a5723a0
import gradio as gr | |
import torch | |
import gc | |
import json | |
from pipeline_tabs.text_tab import text_tab | |
from pipeline_tabs.diffusion_tab import diffusion_tab | |
# Shared model registry. The same dict object is handed to every tab below;
# presumably the tabs insert loaded models into it keyed by name (verify in
# pipeline_tabs). It must be mutated in place, never rebound, so that all
# holders keep seeing the same cache.
model_cache: dict = {}


def unload_all_models():
    """Drop every cached model and ask Python/PyTorch to release memory.

    Empties ``model_cache`` in place (callers sharing the dict see it empty),
    forces a garbage-collection pass so unreferenced model objects are freed,
    and, only when CUDA is available, releases PyTorch's unused cached GPU
    memory. Memory is only actually freed if nothing else still references
    the models.

    Returns:
        None
    """
    model_cache.clear()
    gc.collect()
    # Nothing to release on CPU-only hosts.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
with gr.Blocks() as demo:
    with gr.Tabs():
        # Every tab receives the shared cache dict plus the global unload
        # callback, so they all operate on the same set of loaded models.
        text_tab(model_cache, unload_all_models)
        diffusion_tab(model_cache, unload_all_models)

    # Shared state display: the keys currently present in model_cache.
    def pretty_json():
        """Return the keys of ``model_cache`` as an indented JSON list string."""
        return json.dumps(list(model_cache), indent=2, ensure_ascii=False)

    state_box = gr.Textbox(
        label="Loaded Models",
        lines=4,
        interactive=False,
        value=pretty_json(),  # snapshot taken once, at UI build time
    )

    # NOTE: this display is NOT refreshed automatically when a tab loads or
    # unloads a model. It is recomputed only when the page is (re)loaded in a
    # browser (demo.load) or when the button below is clicked.
    demo.load(fn=pretty_json, inputs=None, outputs=state_box)

    # Manual refresh for changes made after the page was loaded.
    refresh_btn = gr.Button("Refresh Model State")
    refresh_btn.click(fn=pretty_json, inputs=None, outputs=state_box)

# Start the Gradio server for the UI assembled above.
demo.launch()