dynamic-hfspaces/pipeline_tabs/diffusion_tab.py
LPX55
Add Gradio interface for multi-model diffusion and text generation tasks, including model loading/unloading functionality and shared state management. Introduce new tabs for text and diffusion models, enhancing user interaction and modularity.
a5723a0
import gradio as gr
import torch
from diffusers import DiffusionPipeline
import gc
def diffusion_tab(
    model_cache,
    unload_all_models,
    *,
    model_id="LPX55/FLUX.1-merged_lightning_v2",
    device="cpu",
):
    """Build the "Diffusion" tab and wire its load / unload / generate buttons.

    Creates a ``gr.Tab`` with its components in the current Gradio layout
    context (the caller is assumed to have opened a ``gr.Blocks``/``gr.Tabs``
    around this call).

    Args:
        model_cache: Mutable mapping shared with other tabs; the loaded
            pipeline is stored under the ``"diffusion"`` key.
        unload_all_models: Zero-argument callable invoked before loading
            (presumably frees every other cached model first — confirm
            against the caller).
        model_id: Hub id or local path passed to
            ``DiffusionPipeline.from_pretrained``.
        device: Device the pipeline is moved to after loading.

    Returns:
        None. Components and event handlers are registered as a side effect.
    """

    def load_diffusion_model():
        """Load the pipeline into ``model_cache["diffusion"]``; return a status line."""
        unload_all_models()
        # float32 because the default device is CPU, where half precision is
        # generally slow or unsupported.
        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
        pipe = pipe.to(device)
        # Trade some speed for lower peak memory during attention.
        pipe.enable_attention_slicing()
        model_cache["diffusion"] = pipe
        return "Diffusion model loaded!"

    def unload_diffusion_model():
        """Drop the cached pipeline (if any) and release memory; return a status line."""
        if "diffusion" in model_cache:
            del model_cache["diffusion"]
        # Collect unconditionally: memory may still be held by garbage even
        # when the cache entry was already gone.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "Diffusion model unloaded!"

    def run_diffusion(prompt, width, height, steps):
        """Generate one image from ``prompt`` with the cached pipeline.

        Returns an ``(image, message)`` pair; ``image`` is ``None`` when no
        model has been loaded yet.
        """
        if "diffusion" not in model_cache:
            return None, "Diffusion model not loaded!"
        pipe = model_cache["diffusion"]
        # Slider values may arrive as floats; image dimensions and the number
        # of scheduler steps must be integers.
        image = pipe(
            prompt=prompt,
            width=int(width),
            height=int(height),
            num_inference_steps=int(steps),
        ).images[0]
        return image, "Success!"

    with gr.Tab("Diffusion"):
        status = gr.Markdown("Model not loaded.")
        load_btn = gr.Button("Load Diffusion Model")
        unload_btn = gr.Button("Unload Model")
        prompt = gr.Textbox(label="Prompt", value="A cat holding a sign that says hello world")
        # step=64 keeps dimensions at multiples of 64.
        width = gr.Slider(256, 1536, value=768, step=64, label="Width")
        height = gr.Slider(256, 1536, value=1152, step=64, label="Height")
        steps = gr.Slider(1, 50, value=8, step=1, label="Inference Steps")
        run_btn = gr.Button("Generate Image")
        output_img = gr.Image(label="Output Image")
        output_msg = gr.Textbox(label="Status", interactive=False)

        load_btn.click(load_diffusion_model, None, status)
        unload_btn.click(unload_diffusion_model, None, status)
        run_btn.click(run_diffusion, [prompt, width, height, steps], [output_img, output_msg])