Econogoat committed
Commit a6ee92e · verified · 1 Parent(s): c09e4fb

Update app.py

Files changed (1)
  1. app.py +80 -36
app.py CHANGED
@@ -10,51 +10,93 @@ import pandas as pd
 import random
 import time
 
-# --- Main Configuration ---
+# --- NEW: Imports for the LLM (Gemma) ---
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+# --- Main Configuration ---
 KRYPTO_LORA = {
-    # FIX: The repo name was misspelled (a capital O instead of a zero).
-    "repo": "Econogoat/Krypt0_LORA",
+    "repo": "Econogoat/Krypt0_LORA",
     "trigger": "Krypt0",
     "adapter_name": "krypt0"
 }
 
-# Load prompts for the randomize button
+# Load the prompts
 df = pd.read_csv('prompts.csv', header=None)
 prompt_values = df.values.flatten()
 
-# Get access token from Space secrets
+# Get the access token
 HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    print("WARNING: HF_TOKEN secret is not set. Gated model downloads may fail.")
 
-# --- Model Initialization ---
+# --- Model Initialization ---
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 dtype = torch.bfloat16
 base_model = "black-forest-labs/FLUX.1-dev"
 
-# Load model components
+# --- NEW: Load the Gemma LLM for prompt enhancement ---
+gemma_model_id = "google/gemma-2-9b-it"
+print(f"Loading the prompt-enhancement LLM: {gemma_model_id}")
+
+# Quantization config to load the model in 4-bit.
+# This considerably reduces memory usage.
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_id, token=HF_TOKEN)
+gemma_model = AutoModelForCausalLM.from_pretrained(
+    gemma_model_id,
+    quantization_config=quantization_config,
+    token=HF_TOKEN,
+    device_map="auto"  # let accelerate handle GPU/CPU placement
+)
+print("Gemma model loaded successfully.")
+
+
+# --- Load the image models ---
 print("Loading model components...")
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device)
 print("Models loaded.")
 
-# Load the LoRA adapter once on startup
+# Load the LoRA
 print(f"Loading on-board LoRA: {KRYPTO_LORA['repo']}")
 pipe.load_lora_weights(
     KRYPTO_LORA['repo'],
     low_cpu_mem_usage=True,
     adapter_name=KRYPTO_LORA['adapter_name'],
-    token=HF_TOKEN  # also pass the token here for private/gated LoRAs
+    token=HF_TOKEN
 )
 print("LoRA loaded successfully.")
 
 MAX_SEED = 2**32 - 1
-
-# Monkey-patch the pipeline for live preview
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
+# --- NEW: Prompt enhancement function ---
+def improve_prompt_with_gemma(user_prompt):
+    system_prompt = (
+        "You are an expert prompt engineer for a text-to-image AI. "
+        "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
+        "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
+        "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
+    )
+
+    # Format for Gemma's chat template
+    chat = [
+        {"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{user_prompt}\""}
+    ]
+    prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+    inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device)
+
+    # Generate the response
+    outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
+
+    # Decode only the newly generated tokens
+    input_length = inputs["input_ids"].shape[1]
+    enhanced_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+
+    return enhanced_prompt.strip()
+
 
 def calculate_dimensions(aspect_ratio, resolution):
     resolution = int(resolution)
@@ -103,15 +145,21 @@ def update_history(new_image, history):
     return history
 
 @spaces.GPU(duration=75)
-def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
+def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
     if not prompt:
         raise gr.Error("Prompt cannot be empty.")
 
-    prompt_mash = f"{KRYPTO_LORA['trigger']}, {prompt}"
-    print("Final prompt:", prompt_mash)
+    # --- NEW: Prompt enhancement logic ---
+    final_prompt = prompt
+    if enhance_prompt:
+        print(f"Enhancing prompt '{prompt}' with Gemma...")
+        final_prompt = improve_prompt_with_gemma(prompt)
+        print(f"Enhanced prompt: {final_prompt}")
+
+    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
+    print("Final prompt sent to the image model:", prompt_mash)
 
     pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
-    print(f"Adapter '{KRYPTO_LORA['adapter_name']}' activated with weight {lora_scale}.")
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -124,7 +172,7 @@ def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
 
 run_generation.zerogpu = True
 
-# --- User Interface (Gradio) ---
+# --- User Interface (Gradio) ---
 css = '''
 #title_container { text-align: center; margin-bottom: 1em; }
 #title_line { display: flex; justify-content: center; align-items: center; }
@@ -142,7 +190,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
     """
     <div id="title_container">
         <div id="title_line">
-            <img src="/file=LogoKrypto.png" alt="Krypt0 Logo">
+            <img src="/file=logo.png" alt="Krypt0 Logo">
             <h1>Krypto Image Generator - beta v1</h1>
         </div>
         <div id="subtitle">
@@ -162,7 +210,14 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
                 random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
             with gr.Column(scale=5):
                 prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
-
+
+                # --- NEW: Checkbox for AI prompt enhancement ---
+                enhance_prompt_checkbox = gr.Checkbox(
+                    label="Improve prompt with AI",
+                    value=True,
+                    info="Uses Gemma to automatically enrich your prompt with more details before generation."
+                )
+
             # Image Shape and Style Controls
             with gr.Group():
                 aspect_ratio = gr.Radio(
@@ -172,15 +227,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
                 )
                 lora_scale = gr.Slider(
                     label="Krypt0 Style Strength",
-                    minimum=0,
-                    maximum=2,
-                    step=0.05,
-                    value=0.9,
+                    minimum=0, maximum=2, step=0.05, value=0.9,
                     info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                 )
 
             # Advanced Settings
-            # FIX: The accordion should be closed by default.
             with gr.Accordion("Advanced Settings", open=False):
                 base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                 steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
@@ -203,24 +254,17 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
     def get_random_prompt():
         return random.choice(prompt_values)
 
-    random_prompt_btn.click(
-        fn=get_random_prompt,
-        inputs=[],
-        outputs=[prompt]
-    )
+    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
 
+    # CHANGED: added `enhance_prompt_checkbox` to the inputs
     generation_event = gr.on(
         triggers=[generate_button.click, prompt.submit],
        fn=run_generation,
-        inputs=[prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
+        inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
         outputs=[result, seed, progress_bar]
     )
 
-    generation_event.then(
-        fn=update_history,
-        inputs=[result, history_gallery],
-        outputs=history_gallery,
-    )
+    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
 
 app.queue(max_size=20)
 app.launch()
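
For reviewers who want to exercise the new path on its own, below is a minimal standalone sketch of the enhance-then-generate flow this commit introduces. It is not part of the commit: it reuses the same google/gemma-2-9b-it checkpoint and generation settings as app.py, but stubs out the FLUX image pipeline, and it assumes transformers, accelerate, and bitsandbytes are installed and that HF_TOKEN grants access to the gated Gemma weights.

# Standalone sketch: enhance a short idea with Gemma, then print the string
# that app.py would hand to the FLUX pipeline (image generation stubbed out).
import os

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "google/gemma-2-9b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # 4-bit, as in app.py
    token=HF_TOKEN,
    device_map="auto",
)

def enhance(user_prompt: str) -> str:
    # Single-turn chat, same shape as improve_prompt_with_gemma in app.py
    # (the instruction wording here is a simplified stand-in).
    chat = [{"role": "user", "content": f'Expand this into a rich, detailed image prompt: "{user_prompt}"'}]
    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
    # Keep only the newly generated tokens, dropping the echoed input.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()

print(f"Krypt0, {enhance('a portrait of a warrior queen')}")  # stand-in for the FLUX call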