Zimg-debug / app_lora1.py
rahul7star's picture
Update app_lora1.py
ba4f365 verified
import spaces
import gradio as gr
import torch
import os
import traceback
from diffusers import ZImagePipeline
from huggingface_hub import list_repo_files
from PIL import Image
# ============================================================
# CONFIG
# ============================================================
# Hugging Face model id handed to ZImagePipeline.from_pretrained below.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# Initial value of the "LoRA Repository" textbox in the UI.
DEFAULT_LORA_REPO = "rahul7star/ZImageLora"
# Passed as torch_dtype when the pipeline is loaded.
DTYPE = torch.bfloat16
# Device the pipeline is moved to: CUDA when torch reports it available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ============================================================
# GLOBAL STATE
# ============================================================
# The shared pipeline instance; reset to None if the build step below fails.
pipe = None
# Repo and file of the LoRA chosen in the UI. Written by the dropdown and
# clear-button callbacks, read by generate_image; a falsy file means "no LoRA".
CURRENT_LORA_REPO = None
CURRENT_LORA_FILE = None
# ============================================================
# LOGGING
# ============================================================
def log(msg):
    """Print ``msg`` to stdout and return it unchanged.

    Handing the same object back lets a call site both emit a message and
    reuse it (for example as a UI status string) in a single expression.
    """
    print(msg)
    return msg
# ============================================================
# PIPELINE BUILD (ONCE)
# ============================================================
def _build_pipeline():
    """Load the Z-Image pipeline and move it to ``DEVICE``.

    Returns the ready pipeline, or ``None`` when loading or the device move
    raises; in that case the traceback is logged instead of propagated so the
    module still imports and the UI can report the problem.
    """
    try:
        built = ZImagePipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
        built.to(DEVICE)
        log("✅ Pipeline built successfully")
        return built
    except Exception:
        log("❌ Pipeline build failed")
        log(traceback.format_exc())
        return None


# Built exactly once, at import time.
pipe = _build_pipeline()
# ============================================================
# HELPERS
# ============================================================
def list_loras_from_repo(repo_id: str):
    """Return the ``.safetensors`` file names found in a Hugging Face repo.

    Args:
        repo_id: Repository identifier such as ``"user/repo"``. Surrounding
            whitespace is ignored, since the value is typed into a
            multi-line textbox and may carry a stray space or newline.

    Returns:
        A list of repo-relative file names ending in ``.safetensors``. The
        list is empty when ``repo_id`` is blank/``None`` or when the listing
        fails; failures are logged rather than raised.
    """
    repo_id = (repo_id or "").strip()
    if not repo_id:
        # Nothing to look up; skip the network round-trip entirely.
        return []
    try:
        return [
            name
            for name in list_repo_files(repo_id)
            if name.endswith(".safetensors")
        ]
    except Exception as e:
        log(f"❌ Failed to list LoRAs: {e}")
        return []
# ============================================================
# IMAGE GENERATION (SAFE LORA LOGIC)
# ============================================================
@spaces.GPU()
def generate_image(prompt, height, width, steps, seed, guidance_scale):
    """Stream low-resolution previews, then the final image, for ``prompt``.

    This is a generator: every yielded item is a
    ``(final_image, preview_images, log_text)`` tuple. While previews are
    being produced ``final_image`` is ``None``; the last item carries the
    finished image (or a white placeholder if generation failed).

    The LoRA named by the module-level ``CURRENT_LORA_REPO`` /
    ``CURRENT_LORA_FILE`` is loaded for this run only and unloaded again
    before the last item is yielded, including when the consumer stops
    iterating early.

    Args:
        prompt: Text prompt passed to the pipeline.
        height: Output height in pixels (coerced to ``int``).
        width: Output width in pixels (coerced to ``int``).
        steps: Inference steps for the final image (coerced to ``int``); up
            to five previews are rendered with 1..5 steps at quarter size.
        seed: Seed for the random generator (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.

    Yields:
        ``(final_image_or_None, list_of_preview_images, newline_joined_logs)``.
    """
    LOGS = []
    print(prompt)

    if pipe is None:
        # In a generator, `return <value>` is never delivered to whoever is
        # iterating; the error has to be yielded to be seen.
        yield None, [], "❌ Pipeline not initialized"
        return

    # UI sliders may hand over floats; range() and Image.new() need ints.
    height, width, steps, seed = int(height), int(width), int(steps), int(seed)

    previews = []
    final_img = None

    # One try/finally around the whole run so the LoRA is unloaded even if
    # the consumer closes the generator during the preview phase.
    try:
        # ---- Always start clean ----
        try:
            pipe.unload_lora_weights()
        except Exception:
            # Deliberate best effort: nothing may be loaded yet.
            pass

        # ---- Load LoRA for this run only ----
        if CURRENT_LORA_FILE:
            try:
                pipe.load_lora_weights(
                    CURRENT_LORA_REPO,
                    weight_name=CURRENT_LORA_FILE,
                )
                LOGS.append(f"🧩 LoRA loaded: {CURRENT_LORA_FILE}")
            except Exception as e:
                LOGS.append(f"❌ LoRA load failed: {e}")

        # ---- Preview steps (lightweight) ----
        preview_rng = torch.Generator().manual_seed(seed)
        try:
            for i in range(min(5, steps)):
                out = pipe(
                    prompt=prompt,
                    height=height // 4,
                    width=width // 4,
                    num_inference_steps=i + 1,
                    guidance_scale=guidance_scale,
                    generator=preview_rng,
                )
                previews.append(out.images[0].resize((width, height)))
                yield None, previews, "\n".join(LOGS)
        except Exception as e:
            LOGS.append(f"⚠️ Preview failed: {e}")

        # ---- Final image ----
        try:
            out = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                # Fresh generator: the final image depends only on the seed,
                # not on how many previews consumed random numbers before it.
                generator=torch.Generator().manual_seed(seed),
            )
            final_img = out.images[0]
            previews.append(final_img)
            LOGS.append("✅ Image generated")
        except Exception as e:
            LOGS.append(f"❌ Generation failed: {e}")
            final_img = Image.new("RGB", (width, height), (255, 255, 255))
    finally:
        # ---- CRITICAL: unload after run ----
        try:
            pipe.unload_lora_weights()
            LOGS.append("🧹 LoRA unloaded")
        except Exception as e:
            LOGS.append(f"⚠️ LoRA unload failed: {e}")

    # Yield after cleanup so the unload status is part of the visible log.
    yield final_img, previews, "\n".join(LOGS)
# ============================================================
# GRADIO UI
# ============================================================
# Custom stylesheet handed to gr.Blocks(css=...): widens the container and
# styles the classes the buttons opt into via elem_classes
# ("generate-btn", "secondary-btn"), plus rounded text inputs.
css = """
.gradio-container {
max-width: 100% !important;
padding: 16px 32px !important;
}
.section {
margin-bottom: 12px;
}
.generate-btn {
background: linear-gradient(90deg, #4b6cb7, #182848) !important;
color: white !important;
font-weight: 600;
height: 46px;
border-radius: 10px;
}
.secondary-btn {
height: 42px;
border-radius: 10px;
}
textarea, input {
border-radius: 10px !important;
}
"""
# Two-column layout: prompt, LoRA and generation controls on the left,
# the final image and the preview gallery on the right.
with gr.Blocks(
    title="Z-Image-Turbo (Runtime LoRA)",
    css=css,
) as demo:
    gr.Markdown(
        """
# 🎨 Z-Image-Turbo LORA
**Runtime LoRA · Safe Mode · Full-Width UI**
"""
    )
    # ======================================================
    # MAIN LAYOUT
    # ======================================================
    with gr.Row():
        # ================= LEFT PANEL =================
        with gr.Column(scale=5):
            # -------- Prompt --------
            prompt = gr.Textbox(
                label="Prompt",
                value="boat in ocean",
                lines=4,
                placeholder="Describe the image you want to generate…",
            )
            # -------- LoRA Controls (NEXT TO PROMPT) --------
            gr.Markdown("### 🧩 LoRA Controls")
            lora_repo = gr.Textbox(
                label="LoRA Repository",
                value=DEFAULT_LORA_REPO,
                lines=2,
                placeholder="username/repo (e.g. rahul7star/ZImageLora)",
            )
            lora_dropdown = gr.Dropdown(
                label="LoRA File",
                choices=[],
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh LoRA List", elem_classes="secondary-btn")
                clear_lora_btn = gr.Button("❌ Clear LoRA", elem_classes="secondary-btn")
            # -------- Generation Controls --------
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Row():
                width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            # NOTE(review): the source's nesting was lost; seed is assumed to
            # share this row with steps and guidance — confirm against the
            # intended layout.
            with gr.Row():
                steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance")
                seed = gr.Number(value=42, label="Seed", precision=0)
            run_btn = gr.Button("🚀 Generate Image", elem_classes="generate-btn")
            logs_box = gr.Textbox(
                label="Logs",
                lines=10,
                interactive=False,
            )
        # ================= RIGHT PANEL =================
        with gr.Column(scale=7):
            final_image = gr.Image(
                label="Final Image",
                height=520,
            )
            gallery = gr.Gallery(
                label="Generation Steps",
                columns=4,
                height=260,
            )

    # ======================================================
    # CALLBACKS
    # ======================================================
    def refresh_loras(repo):
        """Repopulate the LoRA dropdown from ``repo`` and preselect the first file."""
        # The repo box is multi-line, so trim stray whitespace/newlines.
        files = list_loras_from_repo((repo or "").strip())
        return gr.update(
            choices=files,
            value=files[0] if files else None,
        )

    refresh_btn.click(
        refresh_loras,
        inputs=[lora_repo],
        outputs=[lora_dropdown],
    )

    def select_lora(lora_file, repo):
        """Record the chosen LoRA in module state and return a status line.

        An empty selection is treated as "no LoRA". This matters because the
        clear button resets the dropdown to ``None``, which fires this change
        handler too; without the guard it would re-store the repo and replace
        the "cleared" message with "Selected LoRA: None".
        """
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        if not lora_file:
            CURRENT_LORA_FILE = None
            CURRENT_LORA_REPO = None
            return "🧹 LoRA cleared — base model will be used."
        CURRENT_LORA_FILE = lora_file
        CURRENT_LORA_REPO = (repo or "").strip()
        return f"🧩 Selected LoRA: {lora_file}"

    lora_dropdown.change(
        select_lora,
        inputs=[lora_dropdown, lora_repo],
        outputs=[logs_box],
    )

    def clear_lora():
        """Forget the selected LoRA, unload it if possible, and reset the dropdown."""
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = None
        CURRENT_LORA_REPO = None
        if pipe is not None:
            try:
                pipe.unload_lora_weights()
            except Exception:
                # Best effort: generate_image unloads again before each run.
                pass
        return (
            gr.update(value=None),
            "🧹 LoRA cleared — base model will be used.",
        )

    clear_lora_btn.click(
        clear_lora,
        outputs=[lora_dropdown, logs_box],
    )

    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed, guidance],
        outputs=[final_image, gallery, logs_box],
    )

demo.launch()