import os
import gradio as gr
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
import pandas as pd
import random
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# --- Main Configuration ---
KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
GEMMA_MODEL_ID = "google/gemma-2-9b-it"
# Load the prompt suggestions
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()
# Retrieve the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
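# Note: the token must grant access to the gated FLUX.1-dev and Gemma repositories; on Spaces it is
# typically provided as a repository secret named HF_TOKEN (assumption about the deployment setup).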
# --- Model Initialization (CPU only, and WITHOUT initial quantization) ---
device_cpu = "cpu"
dtype = torch.bfloat16
base_model = "black-forest-labs/FLUX.1-dev"
# --- CORRECTED STRATEGY ---
# Gemma is loaded on the CPU WITHOUT the quantizer at startup, to avoid conflicts with the Spaces environment.
# Quantization is applied later, only on the GPU.
print(f"Loading tokenizer for {GEMMA_MODEL_ID}...")
gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
print("Chargement du modèle d'image sur CPU...")
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
print("Chargement du LoRA...")
pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
print("Tous les modèles sont pré-chargés sur CPU.")
MAX_SEED = 2**32 - 1
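# Bind the imported live-preview helper to this pipeline instance (so it can be called as a method,
# with `self` set to `pipe`); it yields intermediate images during denoising for the UI preview.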
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
# Keep a global reference so the Gemma model is only loaded once.
gemma_model = None
@spaces.GPU(duration=180)  # Duration increased to accommodate the full Gemma load
def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
    global gemma_model
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    device_gpu = "cuda"
    final_prompt = prompt
    if enhance_prompt:
        # --- DYNAMIC LOADING ON THE GPU ---
        if gemma_model is None:
            print(f"First call: loading {GEMMA_MODEL_ID} on the GPU with 4-bit quantization...")
            bnb_config_gpu = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=dtype
            )
            gemma_model = AutoModelForCausalLM.from_pretrained(
                GEMMA_MODEL_ID,
                quantization_config=bnb_config_gpu,
                token=HF_TOKEN,
                device_map="auto"  # "auto" works because we are already inside a GPU environment
            )
            print("Gemma model loaded on the GPU.")
        print(f"Enhancing prompt '{prompt}' with Gemma...")
        system_prompt = (
            "You are an expert prompt engineer for a text-to-image AI. "
            "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
            "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
            "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
        )
        chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{prompt}\""}]
        prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)
        outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
        input_length = inputs["input_ids"].shape[1]
        # Decode only the newly generated tokens, skipping the echoed prompt.
        final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
        print(f"Enhanced prompt: {final_prompt}")
    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
    print("Final prompt sent to the image model:", prompt_mash)
    print("Moving the image pipeline to the GPU...")
    pipe.to(device_gpu)
    good_vae.to(device_gpu)
    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
    if randomize_seed: seed = random.randint(0, MAX_SEED)
    width, height = calculate_dimensions(aspect_ratio, base_resolution)
    print(f"Generating a {width}x{height} image.")
    generator = torch.Generator(device=device_gpu).manual_seed(seed)
    image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
        width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
    )
    final_image = None
    # Stream intermediate previews to the UI as the image denoises.
    for i, image in enumerate(image_generator):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
        yield image, seed, gr.update(value=progress_bar, visible=True)
    print("Freeing image pipeline memory (moving back to CPU)...")
    pipe.to(device_cpu)
    good_vae.to(device_cpu)
    torch.cuda.empty_cache()
    yield final_image, seed, gr.update(visible=False)
def calculate_dimensions(aspect_ratio, resolution):
    resolution = int(resolution)
    if aspect_ratio == "Square (1:1)": width, height = resolution, resolution
    elif aspect_ratio == "Portrait (9:16)": width, height = int(resolution * 9 / 16), resolution
    elif aspect_ratio == "Landscape (16:9)": width, height = resolution, int(resolution * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)": width, height = resolution, int(resolution * 9 / 21)
    else: width, height = resolution, resolution
    width = (width // 64) * 64
    height = (height // 64) * 64
    return width, height
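# Example: calculate_dimensions("Landscape (16:9)", 1024) returns (1024, 576); both values are
# snapped down to the nearest multiple of 64 (576 is already a multiple of 64, so it is unchanged).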
def update_history(new_image, history):
    if new_image is None: return history
    if history is None: history = []
    history.insert(0, new_image)
    return history
css = '''
#title_container { text-align: center; margin-bottom: 1em; }
#title_line { display: flex; justify-content: center; align-items: center; }
#title_line img { width: 70px; margin-right: 0.5em; }
#title_line h1 { font-size: 2.5em; margin: 0; }
#subtitle { font-size: 1.1em; color: #57606a; margin-top: 0.3em; }
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out}
#random_prompt_btn { height: 100% !important; }
'''
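# The .progress-bar width above is driven by the --current/--total CSS variables that
# run_generation injects into the HTML snippet it yields on every preview step.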
with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
    gr.HTML(
        """
        <div id="title_container">
            <div id="title_line">
                <img src="/file=logo.png" alt="Krypt0 Logo">
                <h1>Krypto Image Generator - beta v1</h1>
            </div>
            <div id="subtitle">
                Powered by $Krypto | @Kryptocoinonsol
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=1, min_width=150):
                        random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
                    with gr.Column(scale=5):
                        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
                enhance_prompt_checkbox = gr.Checkbox(
                    label="Improve prompt with AI", value=True,
                    info="Uses Gemma to automatically enrich your prompt with more details before generation."
                )
            with gr.Group():
                aspect_ratio = gr.Radio(
                    label="Aspect Ratio",
                    choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"],
                    value="Square (1:1)"
                )
                lora_scale = gr.Slider(
                    label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
                    info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                )
            with gr.Accordion("Advanced Settings", open=False):
                base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
                cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Random Seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)

    def get_random_prompt():
        return random.choice(prompt_values)

    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])

    generation_event = gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_generation,
        inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
        outputs=[result, seed, progress_bar]
    )
    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
app.queue(max_size=20)
app.launch() |