import os
import gradio as gr
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
import pandas as pd
import random
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# --- Main configuration ---
KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
GEMMA_MODEL_ID = "google/gemma-2-9b-it"

# Load the prompt suggestions
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()

# Read the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
# --- Model initialisation (CPU only, WITHOUT initial quantization) ---
device_cpu = "cpu"
dtype = torch.bfloat16
base_model = "black-forest-labs/FLUX.1-dev"

# --- CORRECTED STRATEGY ---
# Gemma is loaded on the CPU WITHOUT the quantizer at startup, to avoid a conflict with the Spaces environment.
# Quantization is applied later, on the GPU only.
print(f"Loading tokenizer for {GEMMA_MODEL_ID}...")
gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)

print("Loading the image model on CPU...")
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)

print("Loading the LoRA...")
pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)

print("All models are pre-loaded on CPU.")
MAX_SEED = 2**32 - 1

# Bind the live-preview helper to the pipeline instance so intermediate images can be yielded during generation.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

# Keep a global reference so the Gemma model is only loaded once.
gemma_model = None

# Increase the allotted GPU duration to accommodate the full loading of Gemma
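# Assumed ZeroGPU decorator: the `spaces` import and the duration comment above imply one wrapped
# this function; the 120 s value is a placeholder, not taken from the original.
@spaces.GPU(duration=120)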
def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
    global gemma_model
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    device_gpu = "cuda"
    final_prompt = prompt
    if enhance_prompt:
        # --- Dynamic loading on the GPU ---
        if gemma_model is None:
            print(f"First call: loading {GEMMA_MODEL_ID} on GPU with 4-bit quantization...")
            bnb_config_gpu = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=dtype
            )
            gemma_model = AutoModelForCausalLM.from_pretrained(
                GEMMA_MODEL_ID,
                quantization_config=bnb_config_gpu,
                token=HF_TOKEN,
                device_map="auto"  # "auto" works because we are ALREADY in a GPU environment at this point
            )
            print("Gemma model loaded on GPU.")

        print(f"Enhancing prompt '{prompt}' with Gemma...")
        system_prompt = (
            "You are an expert prompt engineer for a text-to-image AI. "
            "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
            "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
            "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
        )
        chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{prompt}\""}]
        prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)
        outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
        input_length = inputs["input_ids"].shape[1]
        final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
        print(f"Enhanced prompt: {final_prompt}")
    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
    print("Final prompt sent to the image model:", prompt_mash)

    print("Moving the image pipeline to the GPU...")
    pipe.to(device_gpu)
    good_vae.to(device_gpu)
    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    width, height = calculate_dimensions(aspect_ratio, base_resolution)
    print(f"Generating a {width}x{height} pixel image.")

    generator = torch.Generator(device=device_gpu).manual_seed(seed)
    image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
        width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
    )

    final_image = None
    for i, image in enumerate(image_generator):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
        yield image, seed, gr.update(value=progress_bar, visible=True)

    print("Freeing image pipeline memory (moving back to CPU)...")
    pipe.to(device_cpu)
    good_vae.to(device_cpu)
    torch.cuda.empty_cache()

    yield final_image, seed, gr.update(visible=False)
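# Map the aspect-ratio choice to pixel dimensions; both sides are rounded down to a
# multiple of 64 so the diffusion pipeline receives valid sizes.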
def calculate_dimensions(aspect_ratio, resolution):
    resolution = int(resolution)
    if aspect_ratio == "Square (1:1)":
        width, height = resolution, resolution
    elif aspect_ratio == "Portrait (9:16)":
        width, height = int(resolution * 9 / 16), resolution
    elif aspect_ratio == "Landscape (16:9)":
        width, height = resolution, int(resolution * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)":
        width, height = resolution, int(resolution * 9 / 21)
    else:
        width, height = resolution, resolution
    width = (width // 64) * 64
    height = (height // 64) * 64
    return width, height
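# Prepend each newly generated image to the history gallery (most recent first).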
def update_history(new_image, history):
    if new_image is None:
        return history
    if history is None:
        history = []
    history.insert(0, new_image)
    return history
css = '''
#title_container { text-align: center; margin-bottom: 1em; }
#title_line { display: flex; justify-content: center; align-items: center; }
#title_line img { width: 70px; margin-right: 0.5em; }
#title_line h1 { font-size: 2.5em; margin: 0; }
#subtitle { font-size: 1.1em; color: #57606a; margin-top: 0.3em; }
.progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
.progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.1s ease-in-out; }
#random_prompt_btn { height: 100% !important; }
'''
with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
    gr.HTML(
        """
        <div id="title_container">
            <div id="title_line">
                <img src="/file=logo.png" alt="Krypt0 Logo">
                <h1>Krypto Image Generator - beta v1</h1>
            </div>
            <div id="subtitle">
                Powered by $Krypto | @Kryptocoinonsol
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=1, min_width=150):
                        random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
                    with gr.Column(scale=5):
                        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
                enhance_prompt_checkbox = gr.Checkbox(
                    label="Improve prompt with AI", value=True,
                    info="Uses Gemma to automatically enrich your prompt with more details before generation."
                )
            with gr.Group():
                aspect_ratio = gr.Radio(
                    label="Aspect Ratio",
                    choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"],
                    value="Square (1:1)"
                )
                lora_scale = gr.Slider(
                    label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
                    info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                )
            with gr.Accordion("Advanced Settings", open=False):
                base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
                cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Random Seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
    def get_random_prompt():
        return random.choice(prompt_values)

    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])

    generation_event = gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_generation,
        inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
        outputs=[result, seed, progress_bar]
    )
    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)

app.queue(max_size=20)
app.launch()