"""Gradio app for Z-Image-Turbo text-to-image generation with a runtime LoRA.

The pipeline is built once at import time. A LoRA chosen in the UI is loaded
at the start of each generation run and unloaded again when the run ends.
"""

import os
import traceback

# NOTE(review): `spaces` stays ahead of `torch`, as in the original import
# order — presumably required by ZeroGPU; confirm before reordering.
import spaces
import gradio as gr
import torch
from diffusers import ZImagePipeline
from huggingface_hub import list_repo_files
from PIL import Image

# ============================================================
# CONFIG
# ============================================================
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
DEFAULT_LORA_REPO = "rahul7star/ZImageLora"
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================
# GLOBAL STATE
# ============================================================
pipe = None  # ZImagePipeline instance, or None when the build failed
CURRENT_LORA_REPO = None  # repo id of the LoRA selected in the UI
CURRENT_LORA_FILE = None  # weight file name of the selected LoRA


# ============================================================
# LOGGING
# ============================================================
def log(msg):
    """Print ``msg`` to stdout and return it unchanged."""
    print(msg)
    return msg


# ============================================================
# PIPELINE BUILD (ONCE)
# ============================================================
try:
    pipe = ZImagePipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
    )
    pipe.to(DEVICE)
    log("✅ Pipeline built successfully")
except Exception:
    # Top-level boundary: keep the UI importable and report the failure;
    # generate_image() checks for ``pipe is None``.
    log("❌ Pipeline build failed")
    log(traceback.format_exc())
    pipe = None


# ============================================================
# HELPERS
# ============================================================
def list_loras_from_repo(repo_id: str):
    """Return the ``.safetensors`` file names found in a Hub repository.

    Args:
        repo_id: Repository id such as ``"user/repo"``. Surrounding
            whitespace is ignored because the value comes from a multi-line
            textbox.

    Returns:
        A list of file names ending in ``.safetensors``; an empty list when
        the id is blank or the listing fails (the failure is logged).
    """
    repo_id = (repo_id or "").strip()
    if not repo_id:
        log("❌ Failed to list LoRAs: empty repository id")
        return []
    try:
        files = list_repo_files(repo_id)
    except Exception as exc:
        log(f"❌ Failed to list LoRAs: {exc}")
        return []
    return [name for name in files if name.endswith(".safetensors")]


def _unload_lora():
    """Best-effort LoRA unload; return True when the call succeeded."""
    if pipe is None:
        return False
    try:
        pipe.unload_lora_weights()
    except Exception as exc:
        # Deliberately non-fatal: a failed unload must not abort a run.
        log(f"⚠️ LoRA unload failed: {exc}")
        return False
    return True


# ============================================================
# IMAGE GENERATION (SAFE LORA LOGIC)
# ============================================================
@spaces.GPU()
def generate_image(prompt, height, width, steps, seed, guidance_scale):
    """Generate an image, streaming low-resolution previews first.

    This is a generator: every ``yield`` produces a
    ``(final_image, previews, log_text)`` tuple for the three UI outputs.

    Args:
        prompt: Text prompt passed to the pipeline.
        height: Output height in pixels (coerced to ``int``).
        width: Output width in pixels (coerced to ``int``).
        steps: Number of inference steps for the final image (coerced to
            ``int``); also caps the number of previews at ``min(5, steps)``.
        seed: Seed for the ``torch.Generator`` (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.

    Yields:
        ``(None, previews, logs)`` after each preview, then either
        ``(final_image, previews, logs)`` on success or
        ``(placeholder, previews, logs)`` when the final render fails.
        A single error tuple is yielded when the pipeline is missing or the
        numeric settings cannot be converted.
    """
    logs = []
    print(prompt)

    if pipe is None:
        # A bare ``return value`` inside a generator is swallowed by
        # StopIteration, so the message has to be yielded to reach the UI.
        yield None, [], "❌ Pipeline not initialized"
        return

    try:
        height, width, steps, seed = int(height), int(width), int(steps), int(seed)
    except (TypeError, ValueError) as exc:
        # e.g. the seed field was cleared and arrives as None.
        yield None, [], f"❌ Invalid generation settings: {exc}"
        return

    # One generator is shared by the previews and the final render, so the
    # previews advance its state before the final image is sampled.
    generator = torch.Generator().manual_seed(seed)
    placeholder = Image.new("RGB", (width, height), (255, 255, 255))
    previews = []

    # ---- Always start clean ----
    _unload_lora()

    try:
        # ---- Load LoRA for this run only ----
        if CURRENT_LORA_FILE:
            try:
                pipe.load_lora_weights(
                    CURRENT_LORA_REPO,
                    weight_name=CURRENT_LORA_FILE,
                )
                logs.append(f"🧩 LoRA loaded: {CURRENT_LORA_FILE}")
            except Exception as exc:
                logs.append(f"❌ LoRA load failed: {exc}")

        # ---- Preview steps (lightweight) ----
        try:
            for i in range(min(5, steps)):
                out = pipe(
                    prompt=prompt,
                    height=height // 4,
                    width=width // 4,
                    num_inference_steps=i + 1,
                    guidance_scale=guidance_scale,
                    generator=generator,
                )
                previews.append(out.images[0].resize((width, height)))
                yield None, previews, "\n".join(logs)
        except Exception as exc:
            logs.append(f"⚠️ Preview failed: {exc}")

        # ---- Final image ----
        try:
            out = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            final_img = out.images[0]
        except Exception as exc:
            logs.append(f"❌ Generation failed: {exc}")
            yield placeholder, previews, "\n".join(logs)
        else:
            previews.append(final_img)
            logs.append("✅ Image generated")
            yield final_img, previews, "\n".join(logs)
    finally:
        # ---- CRITICAL: unload after run ----
        # Wrapping the whole run means this also fires when the generator is
        # closed mid-preview. Nothing is yielded afterwards, so the note goes
        # to stdout rather than to the (already emitted) UI log.
        if _unload_lora():
            log("🧹 LoRA unloaded")


# ============================================================
# GRADIO UI
# ============================================================
css = """
.gradio-container {
    max-width: 100% !important;
    padding: 16px 32px !important;
}
.section {
    margin-bottom: 12px;
}
.generate-btn {
    background: linear-gradient(90deg, #4b6cb7, #182848) !important;
    color: white !important;
    font-weight: 600;
    height: 46px;
    border-radius: 10px;
}
.secondary-btn {
    height: 42px;
    border-radius: 10px;
}
textarea, input {
    border-radius: 10px !important;
}
"""

with gr.Blocks(
    title="Z-Image-Turbo (Runtime LoRA)",
    css=css,
) as demo:
    gr.Markdown(
        "# 🎨 Z-Image-Turbo LORA\n"
        "**Runtime LoRA · Safe Mode · Full-Width UI**"
    )

    # ======================================================
    # MAIN LAYOUT
    # ======================================================
    with gr.Row():
        # ================= LEFT PANEL =================
        with gr.Column(scale=5):
            # -------- Prompt --------
            prompt = gr.Textbox(
                label="Prompt",
                value="boat in ocean",
                lines=4,
                placeholder="Describe the image you want to generate…",
            )

            # -------- LoRA Controls (NEXT TO PROMPT) --------
            gr.Markdown("### 🧩 LoRA Controls")
            lora_repo = gr.Textbox(
                label="LoRA Repository",
                value=DEFAULT_LORA_REPO,
                lines=2,
                placeholder="username/repo (e.g. rahul7star/ZImageLora)",
            )
            lora_dropdown = gr.Dropdown(
                label="LoRA File",
                choices=[],
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button(
                    "🔄 Refresh LoRA List",
                    elem_classes="secondary-btn",
                )
                clear_lora_btn = gr.Button(
                    "❌ Clear LoRA",
                    elem_classes="secondary-btn",
                )

            # -------- Generation Controls --------
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Row():
                width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            with gr.Row():
                steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance")
            seed = gr.Number(value=42, label="Seed", precision=0)

            run_btn = gr.Button("🚀 Generate Image", elem_classes="generate-btn")

            logs_box = gr.Textbox(
                label="Logs",
                lines=10,
                interactive=False,
            )

        # ================= RIGHT PANEL =================
        with gr.Column(scale=7):
            final_image = gr.Image(
                label="Final Image",
                height=520,
            )
            gallery = gr.Gallery(
                label="Generation Steps",
                columns=4,
                height=260,
            )

    # ======================================================
    # CALLBACKS
    # ======================================================
    def refresh_loras(repo):
        """Repopulate the LoRA dropdown from ``repo``, selecting the first file."""
        files = list_loras_from_repo(repo)
        return gr.update(
            choices=files,
            value=files[0] if files else None,
        )

    refresh_btn.click(
        refresh_loras,
        inputs=[lora_repo],
        outputs=[lora_dropdown],
    )

    def select_lora(lora_file, repo):
        """Record the dropdown selection in the module-level LoRA state.

        An empty selection (the dropdown was cleared or has no choices) is
        treated as "no LoRA" instead of being stored and reported as a
        selected file.
        """
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        if not lora_file:
            CURRENT_LORA_FILE = None
            CURRENT_LORA_REPO = None
            return "ℹ️ No LoRA selected — base model will be used."
        CURRENT_LORA_FILE = lora_file
        # The repo textbox is multi-line; drop stray whitespace/newlines.
        CURRENT_LORA_REPO = (repo or "").strip()
        return f"🧩 Selected LoRA: {lora_file}"

    lora_dropdown.change(
        select_lora,
        inputs=[lora_dropdown, lora_repo],
        outputs=[logs_box],
    )

    def clear_lora():
        """Forget the selected LoRA, unload it, and reset the dropdown."""
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = None
        CURRENT_LORA_REPO = None
        _unload_lora()
        return (
            gr.update(value=None),
            "🧹 LoRA cleared — base model will be used.",
        )

    clear_lora_btn.click(
        clear_lora,
        outputs=[lora_dropdown, logs_box],
    )

    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed, guidance],
        outputs=[final_image, gallery, logs_box],
    )

demo.launch()