# KryptoCreator / app.py
import os
import gradio as gr
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
import pandas as pd
import random
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# --- Main configuration ---
KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
GEMMA_MODEL_ID = "google/gemma-2-9b-it"
# Load the prompt suggestions
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()
# Get the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
# --- Model initialization (CPU only, and WITHOUT initial quantization) ---
device_cpu = "cpu"
dtype = torch.bfloat16
base_model = "black-forest-labs/FLUX.1-dev"
# --- CORRECTED STRATEGY ---
# Gemma is loaded on the CPU WITHOUT the quantizer at startup, to avoid conflicts with the Spaces environment.
# Quantization is applied later, only on the GPU.
print(f"Chargement du tokenizer pour {GEMMA_MODEL_ID}...")
gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
print("Chargement du modèle d'image sur CPU...")
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
print("Chargement du LoRA...")
pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
print("Tous les modèles sont pré-chargés sur CPU.")
MAX_SEED = 2**32 - 1
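# Bind the live-preview helper to the pipeline instance so it can be called as a bound method.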
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
# Keep a global reference so the Gemma model is only loaded once.
gemma_model = None
@spaces.GPU(duration=180)  # Increased duration to accommodate the full Gemma load on first use
def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
    global gemma_model
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    device_gpu = "cuda"
    final_prompt = prompt
    if enhance_prompt:
        # --- Dynamic loading on the GPU ---
        if gemma_model is None:
            print(f"First call: loading {GEMMA_MODEL_ID} on GPU with 4-bit quantization...")
            bnb_config_gpu = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=dtype
            )
            gemma_model = AutoModelForCausalLM.from_pretrained(
                GEMMA_MODEL_ID,
                quantization_config=bnb_config_gpu,
                token=HF_TOKEN,
                device_map="auto"  # "auto" works here because we are already in a GPU environment
            )
            print("Gemma model loaded on GPU.")
        print(f"Enhancing prompt '{prompt}' with Gemma...")
        system_prompt = (
            "You are an expert prompt engineer for a text-to-image AI. "
            "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
            "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
            "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
        )
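        # Wrap the instruction and the user's idea in Gemma's chat template so the model answers as a single assistant turn.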
chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{user_prompt}\""}]
prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)
outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
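        # Decode only the newly generated tokens, skipping the prompt that was fed in.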
        input_length = inputs["input_ids"].shape[1]
        final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
        print(f"Enhanced prompt: {final_prompt}")
    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
    print("Final prompt sent to the image model:", prompt_mash)
    print("Moving the image pipeline to the GPU...")
    pipe.to(device_gpu)
    good_vae.to(device_gpu)
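    # Activate the Krypt0 LoRA adapter at the user-selected strength.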
    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
    if randomize_seed: seed = random.randint(0, MAX_SEED)
    width, height = calculate_dimensions(aspect_ratio, base_resolution)
    print(f"Generating a {width}x{height} pixel image.")
    generator = torch.Generator(device=device_gpu).manual_seed(seed)
    image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
        width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
    )
    final_image = None
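    # Stream intermediate previews: each yield pushes the latest image and updates the CSS-driven progress bar.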
    for i, image in enumerate(image_generator):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
        yield image, seed, gr.update(value=progress_bar, visible=True)
    print("Releasing image pipeline memory (moving back to CPU)...")
    pipe.to(device_cpu)
    good_vae.to(device_cpu)
    torch.cuda.empty_cache()
    yield final_image, seed, gr.update(visible=False)
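# Map an aspect-ratio label and a base resolution (longest side) to concrete width/height values.
# Example: calculate_dimensions("Landscape (16:9)", 1024) -> (1024, 576).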
def calculate_dimensions(aspect_ratio, resolution):
    resolution = int(resolution)
    if aspect_ratio == "Square (1:1)": width, height = resolution, resolution
    elif aspect_ratio == "Portrait (9:16)": width, height = int(resolution * 9 / 16), resolution
    elif aspect_ratio == "Landscape (16:9)": width, height = resolution, int(resolution * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)": width, height = resolution, int(resolution * 9 / 21)
    else: width, height = resolution, resolution
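    # Snap both dimensions down to the nearest multiple of 64 before handing them to the pipeline.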
    width = (width // 64) * 64
    height = (height // 64) * 64
    return width, height
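# Prepend the newest image so the history gallery shows the most recent result first.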
def update_history(new_image, history):
    if new_image is None: return history
    if history is None: history = []
    history.insert(0, new_image)
    return history
css = '''
#title_container { text-align: center; margin-bottom: 1em; }
#title_line { display: flex; justify-content: center; align-items: center; }
#title_line img { width: 70px; margin-right: 0.5em; }
#title_line h1 { font-size: 2.5em; margin: 0; }
#subtitle { font-size: 1.1em; color: #57606a; margin-top: 0.3em; }
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out}
#random_prompt_btn { height: 100% !important; }
'''
with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
    gr.HTML(
        """
        <div id="title_container">
            <div id="title_line">
                <img src="/file=logo.png" alt="Krypt0 Logo">
                <h1>Krypto Image Generator - beta v1</h1>
            </div>
            <div id="subtitle">
                Powered by $Krypto | @Kryptocoinonsol
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=1, min_width=150):
                        random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
                    with gr.Column(scale=5):
                        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
                enhance_prompt_checkbox = gr.Checkbox(
                    label="Improve prompt with AI", value=True,
                    info="Uses Gemma to automatically enrich your prompt with more details before generation."
                )
            with gr.Group():
                aspect_ratio = gr.Radio(
                    label="Aspect Ratio",
                    choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"],
                    value="Square (1:1)"
                )
                lora_scale = gr.Slider(
                    label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
                    info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                )
            with gr.Accordion("Advanced Settings", open=False):
                base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
                cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Random Seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
    def get_random_prompt():
        return random.choice(prompt_values)
    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
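    # Run generation from either the Generate button or pressing Enter in the prompt box, then append the result to the history gallery.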
    generation_event = gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_generation,
        inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
        outputs=[result, seed, progress_bar]
    )
    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
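# Queue incoming requests (at most 20 waiting) before launching the app.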
app.queue(max_size=20)
app.launch()