# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1mk2JVb5P1d-OXS_A2BJndkO3Bf4Z7b4w
"""

import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import gc

# Reduce CUDA memory fragmentation; must be set before torch initializes CUDA
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Supported model styles
model_dict = {
    "anime": "cagliostrolab/animagine-xl-3.1",
    "photorealistic": "stabilityai/stable-diffusion-xl-base-1.0",
    "artstyle": "stablediffusionapi/dreamshaper-xl-1-0",
    "realistic": "SG161222/RealVisXL_V4.0",
    "anime2": "Linaqruf/animagine-xl-beta",
    "cinematic": "Lykon/dreamshaper-xl-turbo",
    "pixelart": "TheLastBen/pixel-art-xl"
}

# Default negative prompt
default_negative_prompt = (
    "blurry, lowres, bad anatomy, deformed, disfigured, extra limbs, fused fingers, "
    "watermark, text, signature, cropped, low quality, poorly drawn, jpeg artifacts"
)

# Global pipeline cache: remember which style is loaded so the same model
# isn't reloaded (or re-downloaded) on every generate() call
pipe = None
current_style = None

def load_pipeline(style):
    global pipe, current_style
    # Reuse the cached pipeline if this style is already loaded
    if pipe is not None and current_style == style:
        return pipe

    # Drop the old pipeline first so its GPU memory can be reclaimed
    if pipe is not None:
        pipe = None
        gc.collect()
        torch.cuda.empty_cache()

    model_id = model_dict[style]
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        # fp16 is only reliable on GPU; fall back to fp32 on CPU
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=True
    ).to(device)

    # Trade a little speed for lower peak VRAM usage
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    current_style = style
    return pipe

def generate(prompt, style, negative_prompt, num_images):
    gc.collect()
    torch.cuda.empty_cache()

    if style not in model_dict:
        # Raise gr.Error so the message shows in the UI; returning a plain
        # string to a Gallery component would fail to render
        raise gr.Error(f"Invalid style: {style}")

    pipeline = load_pipeline(style)

    if not negative_prompt:
        negative_prompt = default_negative_prompt

    outputs = []
    with torch.inference_mode():
        # Slider values arrive as numbers that may be floats; cast for range()
        for _ in range(int(num_images)):
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=25,
                guidance_scale=7.0,
            ).images[0]
            outputs.append(image)

    return outputs

# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion XL - Image Generator")

    with gr.Row():
        prompt = gr.Textbox(label="📝 Prompt", placeholder="Describe your image...")
        negative_prompt = gr.Textbox(label="🚫 Negative Prompt (optional)", placeholder="bad quality, blurry, ...")

    with gr.Row():
        style = gr.Dropdown(choices=list(model_dict.keys()), label="🎨 Style", value="photorealistic")
        num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="📸 Number of Images")

    generate_btn = gr.Button("🚀 Generate")
    gallery = gr.Gallery(label="🖼️ Output", columns=2, height="auto")

    generate_btn.click(
        fn=generate,
        inputs=[prompt, style, negative_prompt, num_images],
        outputs=gallery
    )

# Note: inside Colab, demo.launch(share=True) is typically needed to get a
# reachable public URL; a plain launch() is fine for local runs
demo.launch()
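
# Hypothetical usage sketch (not part of the app): generate() can also be
# called directly, e.g. from another notebook cell, bypassing the Gradio UI.
# The prompt, style, and filename below are illustrative only.
# images = generate("a castle at sunset, golden hour", "photorealistic", "", 1)
# images[0].save("sample.png")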