Spaces: Running on Zero
| import spaces | |
| import gradio as gr | |
| import torch | |
| import os | |
| import traceback | |
| from diffusers import ZImagePipeline | |
| from huggingface_hub import list_repo_files | |
| from PIL import Image | |
# ============================================================
# CONFIG
# ============================================================
# Base diffusion model, and the LoRA repository pre-filled in the UI.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
DEFAULT_LORA_REPO = "rahul7star/ZImageLora"

# Pipeline precision and placement: bf16 weights, on the GPU when one is visible.
DTYPE = torch.bfloat16
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ============================================================
# GLOBAL STATE
# ============================================================
# Shared base pipeline; assigned by the build step below and left as None
# if building it fails.
pipe = None
# LoRA currently selected in the UI (Hub repo id + weight filename).
# None means no LoRA is applied and the base model is used.
CURRENT_LORA_REPO = CURRENT_LORA_FILE = None
# ============================================================
# LOGGING
# ============================================================
def log(msg):
    """Print ``msg`` to stdout and hand the same object back.

    Returning the argument lets a caller log a message and reuse it as a
    value in a single expression.
    """
    echoed = msg
    print(echoed)
    return echoed
# ============================================================
# PIPELINE BUILD (ONCE)
# ============================================================
# Build the shared base pipeline exactly once, at import time. On any failure
# the traceback is logged and `pipe` is reset to None; downstream code checks
# `pipe is None` instead of the app crashing at startup.
try:
    pipe = ZImagePipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
    pipe.to(DEVICE)
    log("✅ Pipeline built successfully")
except Exception:
    log("❌ Pipeline build failed")
    # format_exc() must be called inside the handler to see the active exception.
    log(traceback.format_exc())
    pipe = None
# ============================================================
# HELPERS
# ============================================================
def list_loras_from_repo(repo_id: str):
    """Return the ``.safetensors`` files found in a Hugging Face Hub repo.

    Args:
        repo_id: Repo id such as ``"username/repo"``. ``None`` is tolerated.
            Surrounding whitespace is stripped, because the value comes from a
            multi-line UI textbox where a stray newline is easy to enter.

    Returns:
        The file paths reported by ``list_repo_files`` that end in
        ``.safetensors``, in the order the Hub returned them. Returns an empty
        list when ``repo_id`` is blank or the listing fails. A failure is
        logged, never raised, so the UI callback always gets a list.
    """
    repo_id = (repo_id or "").strip()
    if not repo_id:
        # Nothing to look up; skip a Hub request that cannot succeed.
        return []
    try:
        files = list_repo_files(repo_id)
    except Exception as e:
        log(f"❌ Failed to list LoRAs: {e}")
        return []
    return [f for f in files if f.endswith(".safetensors")]
# ============================================================
# IMAGE GENERATION (SAFE LORA LOGIC)
# ============================================================
def _seeded_generator(seed):
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``."""
    return torch.Generator().manual_seed(seed)


def _unload_lora():
    """Best-effort ``pipe.unload_lora_weights()``; return True if it succeeded.

    Any exception is deliberately swallowed: unloading is cleanup and must
    never abort a generation request.
    """
    try:
        pipe.unload_lora_weights()
    except Exception:
        return False
    return True


# Required for GPU access on ZeroGPU Spaces (the reason `spaces` is imported);
# the `spaces` docs state the decorator has no effect on other hardware.
@spaces.GPU
def generate_image(prompt, height, width, steps, seed, guidance_scale):
    """Generate an image, streaming quarter-resolution previews first.

    This is a generator. Each yielded ``(final_image, gallery_images, logs)``
    tuple is an update for the final-image, gallery and log outputs.

    The LoRA selected in the UI (module globals ``CURRENT_LORA_REPO`` /
    ``CURRENT_LORA_FILE``) is loaded for this run only. It is always unloaded
    afterwards, so the shared pipeline goes back to the base model.

    Args:
        prompt: Text prompt.
        height: Output height in pixels (may arrive as a float from the UI).
        width: Output width in pixels (may arrive as a float from the UI).
        steps: Inference steps for the final image; also caps the preview count at 5.
        seed: RNG seed. Every pipeline call gets its own generator seeded
            with it, so the final image depends only on ``seed``.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Yields:
        ``(None, previews, logs)`` after each preview, then one
        ``(final_or_placeholder, previews, logs)``. If the pipeline was never
        built, yields a single ``(None, [], error)`` instead.
    """
    logs = []
    print(prompt)

    if pipe is None:
        # A generator's `return value` is discarded, so yield the error instead.
        yield None, [], "❌ Pipeline not initialized"
        return

    # Slider/Number values may be floats; PIL sizes, range() and the
    # pipeline's size/step arguments need ints.
    height, width, steps, seed = int(height), int(width), int(steps), int(seed)
    placeholder = Image.new("RGB", (width, height), (255, 255, 255))
    previews = []

    # Snapshot the selection once: the UI callbacks reassign these globals,
    # possibly while this run is in progress.
    lora_repo, lora_file = CURRENT_LORA_REPO, CURRENT_LORA_FILE

    # ---- Always start clean ----
    _unload_lora()

    lora_loaded = False
    final_img = placeholder
    try:
        # ---- Load LoRA for this run only ----
        if lora_file:
            try:
                pipe.load_lora_weights(lora_repo, weight_name=lora_file)
                lora_loaded = True
                logs.append(f"🧩 LoRA loaded: {lora_file}")
            except Exception as e:
                logs.append(f"❌ LoRA load failed: {e}")

        # ---- Preview steps (lightweight, quarter resolution) ----
        try:
            for i in range(min(5, steps)):
                out = pipe(
                    prompt=prompt,
                    height=height // 4,
                    width=width // 4,
                    num_inference_steps=i + 1,
                    guidance_scale=guidance_scale,
                    # Fresh generator per call: a shared one would be advanced
                    # by the previews and change the final image for this seed.
                    generator=_seeded_generator(seed),
                )
                previews.append(out.images[0].resize((width, height)))
                yield None, previews, "\n".join(logs)
        except Exception as e:
            logs.append(f"⚠️ Preview failed: {e}")

        # ---- Final image ----
        try:
            out = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=_seeded_generator(seed),
            )
            final_img = out.images[0]
            previews.append(final_img)
            logs.append("✅ Image generated")
        except Exception as e:
            logs.append(f"❌ Generation failed: {e}")
    finally:
        # ---- CRITICAL: unload after run ----
        # This also runs if the generator is closed mid-preview (e.g. the
        # request is cancelled), so a LoRA never lingers on the shared pipe.
        unloaded = _unload_lora()
        if lora_loaded and unloaded:
            logs.append("🧹 LoRA unloaded")

    # Emitted after cleanup so the unload message actually reaches the UI.
    yield final_img, previews, "\n".join(logs)
# ============================================================
# GRADIO UI
# ============================================================
# Page styling passed to gr.Blocks(css=...). `.generate-btn` and
# `.secondary-btn` are attached to buttons through `elem_classes` in the UI
# below. `.section` appears unused (no elem_classes="section" in this file's
# UI) — confirm before removing.
css = """
.gradio-container {
max-width: 100% !important;
padding: 16px 32px !important;
}
.section {
margin-bottom: 12px;
}
.generate-btn {
background: linear-gradient(90deg, #4b6cb7, #182848) !important;
color: white !important;
font-weight: 600;
height: 46px;
border-radius: 10px;
}
.secondary-btn {
height: 42px;
border-radius: 10px;
}
textarea, input {
border-radius: 10px !important;
}
"""
# Build the Gradio Blocks app: inputs in the left column, results in the right
# column, then the event wiring. `demo.launch()` at the end starts the server.
with gr.Blocks(
    title="Z-Image-Turbo (Runtime LoRA)",
    css=css,
) as demo:
    gr.Markdown(
        """
# 🎨 Z-Image-Turbo LORA
**Runtime LoRA · Safe Mode · Full-Width UI**
"""
    )
    # ======================================================
    # MAIN LAYOUT
    # ======================================================
    # NOTE: component creation order below defines the on-page layout.
    with gr.Row():
        # ================= LEFT PANEL =================
        with gr.Column(scale=5):
            # -------- Prompt --------
            prompt = gr.Textbox(
                label="Prompt",
                value="boat in ocean",
                lines=4,
                placeholder="Describe the image you want to generate…",
            )
            # -------- LoRA Controls (NEXT TO PROMPT) --------
            gr.Markdown("### 🧩 LoRA Controls")
            lora_repo = gr.Textbox(
                label="LoRA Repository",
                value=DEFAULT_LORA_REPO,
                lines=2,
                placeholder="username/repo (e.g. rahul7star/ZImageLora)",
            )
            # Starts empty; populated by the "Refresh LoRA List" button.
            lora_dropdown = gr.Dropdown(
                label="LoRA File",
                choices=[],
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh LoRA List", elem_classes="secondary-btn")
                clear_lora_btn = gr.Button("❌ Clear LoRA", elem_classes="secondary-btn")
            # -------- Generation Controls --------
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Row():
                width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            with gr.Row():
                steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance")
            seed = gr.Number(value=42, label="Seed", precision=0)
            run_btn = gr.Button("🚀 Generate Image", elem_classes="generate-btn")
            # Status text written by the select / clear / generate callbacks.
            logs_box = gr.Textbox(
                label="Logs",
                lines=10,
                interactive=False,
            )
        # ================= RIGHT PANEL =================
        with gr.Column(scale=7):
            final_image = gr.Image(
                label="Final Image",
                height=520,
            )
            # Shows the preview images and, once done, the final image too.
            gallery = gr.Gallery(
                label="Generation Steps",
                columns=4,
                height=260,
            )
    # ======================================================
    # CALLBACKS
    # ======================================================
    def refresh_loras(repo):
        """List the repo's .safetensors files into the dropdown, preselecting the first.

        Setting `value` presumably also fires `lora_dropdown.change` (and so
        `select_lora`); confirm for the Gradio version in use.
        """
        files = list_loras_from_repo(repo)
        return gr.update(
            choices=files,
            value=files[0] if files else None,
        )
    refresh_btn.click(
        refresh_loras,
        inputs=[lora_repo],
        outputs=[lora_dropdown],
    )
    def select_lora(lora_file, repo):
        """Store the chosen LoRA file and the current repo text in module globals.

        NOTE(review): these globals are process-wide, so every browser session
        shares one selection. Also, editing the repo textbox afterwards is not
        picked up until the dropdown value changes again.
        """
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = lora_file
        CURRENT_LORA_REPO = repo
        return f"🧩 Selected LoRA: {lora_file}"
    lora_dropdown.change(
        select_lora,
        inputs=[lora_dropdown, lora_repo],
        outputs=[logs_box],
    )
    def clear_lora():
        """Forget the LoRA selection, best-effort unload it, and blank the dropdown."""
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = None
        CURRENT_LORA_REPO = None
        # Broad except also covers `pipe` being None (failed build).
        try:
            pipe.unload_lora_weights()
        except Exception:
            pass
        return (
            gr.update(value=None),
            "🧹 LoRA cleared — base model will be used."
        )
    clear_lora_btn.click(
        clear_lora,
        outputs=[lora_dropdown, logs_box],
    )
    # Input order must match generate_image(prompt, height, width, steps, seed,
    # guidance_scale); its yielded tuples map onto these three outputs.
    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed, guidance],
        outputs=[final_image, gallery, logs_box],
    )
demo.launch()