import os
import gradio as gr
import torch
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
import pandas as pd
import random
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# --- Main configuration ---
KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
GEMMA_MODEL_ID = "google/gemma-2-9b-it"

# Load the example prompts
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()

# Read the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")

# --- Model initialization (CPU only, WITHOUT initial quantization) ---
device_cpu = "cpu"
dtype = torch.bfloat16
base_model = "black-forest-labs/FLUX.1-dev"

# --- CORRECTED STRATEGY ---
# Load Gemma on the CPU WITHOUT the quantizer at startup, to avoid conflicts with the Spaces environment.
# Quantization is applied later, on the GPU only.
print(f"Loading tokenizer for {GEMMA_MODEL_ID}...")
gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
print("Chargement du modèle d'image sur CPU...")
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
print("Chargement du LoRA...")
pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
print("Tous les modèles sont pré-chargés sur CPU.")

MAX_SEED = 2**32 - 1
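# Bind the live-preview helper as a method on this pipeline instance.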
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

# Keep a global reference so the Gemma model is loaded only once.
gemma_model = None

@spaces.GPU(duration=180)  # Longer duration to accommodate the full Gemma load
def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
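    """Generate an image from `prompt`, optionally enriching it with Gemma first.

    Yields (image, seed, progress_update) so Gradio can stream intermediate previews.
    """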
    global gemma_model
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    device_gpu = "cuda"
    final_prompt = prompt
    
    if enhance_prompt:
        # --- DYNAMIC LOADING ON GPU ---
        if gemma_model is None:
            print(f"Premier appel : Chargement de {GEMMA_MODEL_ID} sur GPU avec quantization 4-bit...")
            bnb_config_gpu = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=dtype
            )
            gemma_model = AutoModelForCausalLM.from_pretrained(
                GEMMA_MODEL_ID,
                quantization_config=bnb_config_gpu,
                token=HF_TOKEN,
                device_map="auto" # "auto" fonctionnera car on est DÉJÀ sur un environnement GPU
            )
            print("Modèle Gemma chargé sur GPU.")
        
        print(f"Amélioration du prompt '{prompt}' avec Gemma...")
        system_prompt = (
            "You are an expert prompt engineer for a text-to-image AI. "
            "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
            "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
            "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
        )
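        # Gemma's chat template has no system role, so the instructions are folded into the user turn.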
        chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{prompt}\""}]
        prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)
        
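        # Generate, then decode only the newly produced tokens (drop the echoed prompt).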
        outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
        input_length = inputs["input_ids"].shape[1]
        final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
        print(f"Prompt amélioré : {final_prompt}")

    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
    print("Prompt final envoyé au modèle d'image:", prompt_mash)
    
    print("Déplacement du pipeline d'image sur le GPU...")
    pipe.to(device_gpu)
    good_vae.to(device_gpu)
    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
    
    if randomize_seed: seed = random.randint(0, MAX_SEED)
    width, height = calculate_dimensions(aspect_ratio, base_resolution)
    print(f"Génération d'une image de {width}x{height} pixels.")
    
    generator = torch.Generator(device=device_gpu).manual_seed(seed)
    image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
        width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
    )

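    # Stream each intermediate preview; the last yielded image is the final result.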
    final_image = None
    for i, image in enumerate(image_generator):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
        yield image, seed, gr.update(value=progress_bar, visible=True)

    print("Libération de la mémoire du pipeline d'image (déplacement vers CPU)...")
    pipe.to(device_cpu)
    good_vae.to(device_cpu)
    torch.cuda.empty_cache()
    
    yield final_image, seed, gr.update(visible=False)

def calculate_dimensions(aspect_ratio, resolution):
    resolution = int(resolution)
    if aspect_ratio == "Square (1:1)": width, height = resolution, resolution
    elif aspect_ratio == "Portrait (9:16)": width, height = int(resolution * 9 / 16), resolution
    elif aspect_ratio == "Landscape (16:9)": width, height = resolution, int(resolution * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)": width, height = resolution, int(resolution * 9 / 21)
    else: width, height = resolution, resolution
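    # Round both dimensions down to multiples of 64 so they stay compatible with the model's latent grid.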
    width = (width // 64) * 64
    height = (height // 64) * 64
    return width, height

def update_history(new_image, history):
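    # Keep the gallery in reverse-chronological order: newest image first.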
    if new_image is None: return history
    if history is None: history = []
    history.insert(0, new_image)
    return history

css = '''
#title_container { text-align: center; margin-bottom: 1em; }
#title_line { display: flex; justify-content: center; align-items: center; }
#title_line img { width: 70px; margin-right: 0.5em; }
#title_line h1 { font-size: 2.5em; margin: 0; }
#subtitle { font-size: 1.1em; color: #57606a; margin-top: 0.3em; }
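/* Progress bar width is driven by the --current/--total CSS variables injected from run_generation */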
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out}
#random_prompt_btn { height: 100% !important; }
'''

with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
    gr.HTML(
        """
        <div id="title_container">
            <div id="title_line">
                <img src="/file=logo.png" alt="Krypt0 Logo">
                <h1>Krypto Image Generator - beta v1</h1>
            </div>
            <div id="subtitle">
                Powered by $Krypto | @Kryptocoinonsol
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=1, min_width=150):
                        random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
                    with gr.Column(scale=5):
                        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
                enhance_prompt_checkbox = gr.Checkbox(
                    label="Improve prompt with AI", value=True, 
                    info="Uses Gemma to automatically enrich your prompt with more details before generation."
                )
            with gr.Group():
                aspect_ratio = gr.Radio(
                    label="Aspect Ratio",
                    choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"],
                    value="Square (1:1)"
                )
                lora_scale = gr.Slider(
                    label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
                    info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                )
            with gr.Accordion("Advanced Settings", open=False):
                base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
                cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Random Seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
    def get_random_prompt():
        return random.choice(prompt_values)
    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
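    # gr.on lets one handler serve both the Generate button and pressing Enter in the prompt box.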
    generation_event = gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_generation,
        inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
        outputs=[result, seed, progress_bar]
    )
    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)

app.queue(max_size=20)
app.launch()