# Uploaded by LorD276 — "Upload 2 files", commit 7956c7a (verified).
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1mk2JVb5P1d-OXS_A2BJndkO3Bf4Z7b4w
"""
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import gc
# Ask the CUDA caching allocator to use expandable segments, which reduces
# fragmentation when large models are swapped in and out.
# NOTE(review): this runs after ``import torch``; it relies on PyTorch reading
# PYTORCH_CUDA_ALLOC_CONF lazily when the allocator is first used (before any
# GPU allocation happens here) — confirm for the installed torch version.
# ``setdefault`` keeps any value the operator has already exported instead of
# silently overwriting it.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Mapping from the style name shown in the UI dropdown to the model
# repository id passed to ``StableDiffusionXLPipeline.from_pretrained``.
# Insertion order is the order the dropdown lists the styles in.
model_dict = dict(
    anime="cagliostrolab/animagine-xl-3.1",
    photorealistic="stabilityai/stable-diffusion-xl-base-1.0",
    artstyle="stablediffusionapi/dreamshaper-xl-1-0",
    realistic="SG161222/RealVisXL_V4.0",
    anime2="Linaqruf/animagine-xl-beta",
    cinematic="Lykon/dreamshaper-xl-turbo",
    pixelart="TheLastBen/pixel-art-xl",
)

# Negative prompt used whenever the user leaves that field empty.
default_negative_prompt = (
    "blurry, lowres, bad anatomy, deformed, disfigured, extra limbs, fused fingers, "
    + "watermark, text, signature, cropped, low quality, poorly drawn, jpeg artifacts"
)

# The currently loaded pipeline; None until the first generation request.
pipe = None
# Style whose weights are currently held in the global ``pipe``
# (None until a load has completed successfully).
_loaded_style = None


def load_pipeline(style):
    """Return an SDXL pipeline for ``style``, loading weights only when needed.

    If the requested style is already loaded, the cached pipeline is returned
    as-is. Otherwise the previous pipeline is released before the new one is
    loaded, so two models never sit in GPU memory at the same time.

    Args:
        style: A key of ``model_dict``.

    Returns:
        The loaded ``StableDiffusionXLPipeline``, also stored in the global
        ``pipe``.

    Raises:
        KeyError: If ``style`` is not a key of ``model_dict``. This is raised
            before the currently loaded model is discarded.
    """
    global pipe, _loaded_style

    # Reuse the already-loaded model instead of re-reading gigabytes of
    # weights from disk on every generation.
    if pipe is not None and _loaded_style == style:
        return pipe

    model_id = model_dict[style]

    if pipe is not None:
        # Rebind to None rather than ``del`` the global. That way ``pipe``
        # stays defined even if the load below raises, and later calls do not
        # fail with NameError.
        pipe = None
        _loaded_style = None
        gc.collect()
        torch.cuda.empty_cache()

    # Half precision is meant for the GPU. Many CPU kernels are slow in fp16
    # or lack it entirely, so fall back to float32 on the CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    # Compute attention and VAE decoding in slices to lower peak memory use.
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()

    _loaded_style = style
    return pipe
def generate(prompt, style, negative_prompt, num_images,
             num_inference_steps=25, guidance_scale=7.0):
    """Generate ``num_images`` images for ``prompt`` using the model for ``style``.

    This is the Gradio click handler. The first four arguments arrive
    positionally from the UI components.

    Args:
        prompt: Text describing the desired image. Must not be blank.
        style: A key of ``model_dict`` selecting which model to use.
        negative_prompt: Things to avoid. When empty or whitespace-only,
            ``default_negative_prompt`` is used instead.
        num_images: How many images to generate. Coerced to ``int``, because
            a Gradio slider may deliver a float.
        num_inference_steps: Denoising steps per image (default 25).
        guidance_scale: Classifier-free guidance scale (default 7.0).

    Returns:
        A list holding the first image of each pipeline call, one per
        requested image.

    Raises:
        gr.Error: If ``style`` is unknown or ``prompt`` is blank. Gradio shows
            the message to the user instead of sending a string into the
            Gallery output.
    """
    if style not in model_dict:
        raise gr.Error(f"❌ Invalid style: {style}")
    if not (prompt or "").strip():
        raise gr.Error("Please enter a prompt.")

    gc.collect()
    torch.cuda.empty_cache()

    pipeline = load_pipeline(style)

    if not (negative_prompt or "").strip():
        negative_prompt = default_negative_prompt

    count = int(num_images)
    outputs = []
    with torch.inference_mode():
        # Generate one image per call rather than batching, which keeps peak
        # memory at the single-image level.
        for _ in range(count):
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
            outputs.append(image)
    return outputs
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion XL - Image Generator")

    # Row 1: the prompt and the optional negative prompt.
    with gr.Row():
        prompt = gr.Textbox(label="📝 Prompt", placeholder="Describe your image...")
        negative_prompt = gr.Textbox(
            label="🚫 Negative Prompt (optional)",
            placeholder="bad quality, blurry, ...",
        )

    # Row 2: the model style and how many images to produce.
    with gr.Row():
        style = gr.Dropdown(
            choices=list(model_dict.keys()),
            label="🎨 Style",
            value="photorealistic",
        )
        num_images = gr.Slider(
            minimum=1, maximum=4, value=1, step=1, label="📸 Number of Images"
        )

    generate_btn = gr.Button("🚀 Generate")
    gallery = gr.Gallery(label="🖼️ Output", columns=2, height="auto")

    # Inputs are passed positionally, so this order must match
    # generate(prompt, style, negative_prompt, num_images).
    generate_btn.click(
        fn=generate,
        inputs=[prompt, style, negative_prompt, num_images],
        outputs=gallery,
    )

# Start the (blocking) server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()