Econogoat committed on
Commit
2fffe4a
·
verified ·
1 Parent(s): 4e07f4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -128
app.py CHANGED
@@ -9,8 +9,6 @@ from diffusers.utils import load_image
9
  import pandas as pd
10
  import random
11
  import time
12
-
13
- # --- NOUVEAU : Imports pour le LLM (Gemma) ---
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
15
 
16
  # --- Configuration Principale ---
@@ -19,6 +17,7 @@ KRYPTO_LORA = {
19
  "trigger": "Krypt0",
20
  "adapter_name": "krypt0"
21
  }
 
22
 
23
  # Charger les prompts
24
  df = pd.read_csv('prompts.csv', header=None)
@@ -27,149 +26,134 @@ prompt_values = df.values.flatten()
27
  # Récupérer le token
28
  HF_TOKEN = os.getenv("HF_TOKEN")
29
 
30
- # --- Initialisation des Modèles ---
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
- print(f"Using device: {device}")
33
  dtype = torch.bfloat16
34
  base_model = "black-forest-labs/FLUX.1-dev"
35
 
36
- # --- NOUVEAU : Chargement du LLM Gemma pour l'amélioration des prompts ---
37
- gemma_model_id = "google/gemma-2-9b-it"
38
- print(f"Chargement du LLM pour l'amélioration des prompts : {gemma_model_id}")
39
-
40
- # Configuration de la quantization pour charger le modèle en 4-bit.
41
- # Cela réduit considérablement l'utilisation de la mémoire.
42
  quantization_config = BitsAndBytesConfig(load_in_4bit=True)
43
-
44
- gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_id, token=HF_TOKEN)
45
  gemma_model = AutoModelForCausalLM.from_pretrained(
46
- gemma_model_id,
47
  quantization_config=quantization_config,
48
  token=HF_TOKEN,
49
- device_map="auto" # Laisse accelerate gérer le placement sur le GPU/CPU
50
  )
51
- print("Modèle Gemma chargé avec succès.")
52
-
53
 
54
- # --- Chargement des modèles d'image ---
55
- print("Loading model components...")
56
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
57
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device)
58
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device)
59
- print("Models loaded.")
60
 
61
- # Charger le LoRA
62
- print(f"Loading on-board LoRA: {KRYPTO_LORA['repo']}")
63
  pipe.load_lora_weights(
64
  KRYPTO_LORA['repo'],
65
  low_cpu_mem_usage=True,
66
  adapter_name=KRYPTO_LORA['adapter_name'],
67
  token=HF_TOKEN
68
  )
69
- print("LoRA loaded successfully.")
70
 
71
  MAX_SEED = 2**32 - 1
72
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
73
 
74
- # --- NOUVEAU : Fonction d'amélioration du prompt ---
75
- def improve_prompt_with_gemma(user_prompt):
76
- system_prompt = (
77
- "Tu dois toujours répondre au user : 'Krypt0 the white scruffy superdog eat banana' "
78
- )
79
-
80
- # Formatage pour Gemma
81
- chat = [
82
- {"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{user_prompt}\""}
83
- ]
84
- prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
85
-
86
- inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device)
87
-
88
- # Génération de la réponse
89
- outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
90
-
91
- # Décodage et nettoyage de la réponse
92
- input_length = inputs["input_ids"].shape[1]
93
- enhanced_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
94
-
95
- return enhanced_prompt.strip()
96
-
97
-
98
- def calculate_dimensions(aspect_ratio, resolution):
99
- resolution = int(resolution)
100
- if aspect_ratio == "Square (1:1)":
101
- width, height = resolution, resolution
102
- elif aspect_ratio == "Portrait (9:16)":
103
- width, height = int(resolution * 9 / 16), resolution
104
- elif aspect_ratio == "Landscape (16:9)":
105
- width, height = resolution, int(resolution * 9 / 16)
106
- elif aspect_ratio == "Ultrawide (21:9)":
107
- width, height = resolution, int(resolution * 9 / 21)
108
- else:
109
- width, height = resolution, resolution
110
- width = (width // 64) * 64
111
- height = (height // 64) * 64
112
- return width, height
113
-
114
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
115
- pipe.to(device)
116
- generator = torch.Generator(device=device).manual_seed(seed)
117
-
118
- image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
119
- prompt=prompt_mash,
120
- num_inference_steps=steps,
121
- guidance_scale=cfg_scale,
122
- width=width,
123
- height=height,
124
- generator=generator,
125
- joint_attention_kwargs={"scale": 1.0},
126
- output_type="pil",
127
- good_vae=good_vae,
128
- )
129
- final_image = None
130
- for i, image in enumerate(image_generator):
131
- final_image = image
132
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
133
- yield image, gr.update(value=progress_bar, visible=True)
134
- yield final_image, gr.update(visible=False)
135
 
136
- def update_history(new_image, history):
137
- if new_image is None:
138
- return history
139
- if history is None:
140
- history = []
141
- history.insert(0, new_image)
142
- return history
143
-
144
- @spaces.GPU(duration=75)
145
  def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
146
  if not prompt:
147
  raise gr.Error("Prompt cannot be empty.")
148
 
149
- # --- NOUVEAU : Logique d'amélioration du prompt ---
 
150
  final_prompt = prompt
 
151
  if enhance_prompt:
 
 
152
  print(f"Amélioration du prompt '{prompt}' avec Gemma...")
153
- final_prompt = improve_prompt_with_gemma(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  print(f"Prompt amélioré : {final_prompt}")
155
-
 
 
156
  prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
157
  print("Prompt final envoyé au modèle d'image:", prompt_mash)
158
 
 
 
 
 
 
159
  pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
160
 
161
  if randomize_seed:
162
  seed = random.randint(0, MAX_SEED)
163
 
164
  width, height = calculate_dimensions(aspect_ratio, base_resolution)
165
- print(f"Generating a {width}x{height} image.")
 
 
 
 
 
 
 
166
 
167
- for image, progress_update in generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
168
- yield image, seed, progress_update
 
 
 
169
 
170
- run_generation.zerogpu = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- # --- Interface Utilisateur (Gradio) ---
173
  css = '''
174
  #title_container { text-align: center; margin-bottom: 1em; }
175
  #title_line { display: flex; justify-content: center; align-items: center; }
@@ -182,7 +166,6 @@ css = '''
182
  '''
183
 
184
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
185
- # --- Header ---
186
  gr.HTML(
187
  """
188
  <div id="title_container">
@@ -196,26 +179,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
196
  </div>
197
  """
198
  )
199
-
200
  with gr.Row():
201
- # --- LEFT COLUMN: CONTROLS ---
202
  with gr.Column(scale=3):
203
- # Prompt Controls
204
  with gr.Group():
205
  with gr.Row():
206
  with gr.Column(scale=1, min_width=150):
207
  random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
208
  with gr.Column(scale=5):
209
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
210
-
211
- # --- NOUVEAU : Case à cocher pour l'amélioration AI ---
212
  enhance_prompt_checkbox = gr.Checkbox(
213
- label="Improve prompt with AI",
214
- value=True,
215
  info="Uses Gemma to automatically enrich your prompt with more details before generation."
216
  )
217
-
218
- # Image Shape and Style Controls
219
  with gr.Group():
220
  aspect_ratio = gr.Radio(
221
  label="Aspect Ratio",
@@ -223,44 +198,31 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
223
  value="Square (1:1)"
224
  )
225
  lora_scale = gr.Slider(
226
- label="Krypt0 Style Strength",
227
- minimum=0, maximum=2, step=0.05, value=0.9,
228
  info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
229
  )
230
-
231
- # Advanced Settings
232
  with gr.Accordion("Advanced Settings", open=False):
233
  base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
234
  steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
235
  cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
236
-
237
  with gr.Row():
238
  randomize_seed = gr.Checkbox(True, label="Random Seed")
239
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
240
-
241
  generate_button = gr.Button("Generate", variant="primary")
242
-
243
- # --- RIGHT COLUMN: RESULTS ---
244
  with gr.Column(scale=2):
245
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
246
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
247
  with gr.Accordion("History", open=False):
248
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
249
-
250
- # --- Event Logic ---
251
  def get_random_prompt():
252
  return random.choice(prompt_values)
253
-
254
  random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
255
-
256
- # MODIFIÉ : Ajout de `enhance_prompt_checkbox` dans les entrées
257
  generation_event = gr.on(
258
  triggers=[generate_button.click, prompt.submit],
259
  fn=run_generation,
260
  inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
261
  outputs=[result, seed, progress_bar]
262
  )
263
-
264
  generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
265
 
266
  app.queue(max_size=20)
 
9
  import pandas as pd
10
  import random
11
  import time
 
 
12
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
13
 
14
  # --- Configuration Principale ---
 
17
  "trigger": "Krypt0",
18
  "adapter_name": "krypt0"
19
  }
20
+ GEMMA_MODEL_ID = "google/gemma-2-9b-it"
21
 
22
  # Charger les prompts
23
  df = pd.read_csv('prompts.csv', header=None)
 
26
  # Récupérer le token
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
 
29
+ # --- Initialisation des Modèles (sur CPU uniquement) ---
30
+ # CORRECTION : On force le chargement sur CPU pour éviter d'initialiser CUDA.
31
+ device_cpu = "cpu"
32
  dtype = torch.bfloat16
33
  base_model = "black-forest-labs/FLUX.1-dev"
34
 
35
+ # --- NOUVEAU : Chargement de Gemma sur CPU ---
36
+ print(f"Chargement du LLM {GEMMA_MODEL_ID} sur CPU...")
 
 
 
 
37
  quantization_config = BitsAndBytesConfig(load_in_4bit=True)
38
+ gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
39
+ # CORRECTION : On spécifie `device_map` pour forcer le CPU au démarrage.
40
  gemma_model = AutoModelForCausalLM.from_pretrained(
41
+ GEMMA_MODEL_ID,
42
  quantization_config=quantization_config,
43
  token=HF_TOKEN,
44
+ device_map={'':device_cpu}
45
  )
46
+ print("Modèle Gemma chargé.")
 
47
 
48
+ # --- Chargement des modèles d'image sur CPU ---
49
+ print("Chargement des composants du modèle d'image sur CPU...")
50
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
51
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
52
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
53
+ print("Modèles d'image chargés.")
54
 
55
+ # Charger le LoRA (sur le modèle qui est sur CPU)
56
+ print(f"Chargement du LoRA : {KRYPTO_LORA['repo']}")
57
  pipe.load_lora_weights(
58
  KRYPTO_LORA['repo'],
59
  low_cpu_mem_usage=True,
60
  adapter_name=KRYPTO_LORA['adapter_name'],
61
  token=HF_TOKEN
62
  )
63
+ print("LoRA chargé.")
64
 
65
  MAX_SEED = 2**32 - 1
66
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
@spaces.GPU(duration=120)  # Longer GPU slot to leave time for moving the models between CPU and GPU
def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
    """Optionally enhance the prompt with Gemma, then stream FLUX images.

    The models are loaded on CPU at import time; this function moves each one
    to the GPU only for the phase that needs it and moves it back afterwards,
    even when generation fails or the consumer stops iterating early.

    Args:
        prompt: User prompt; must be non-empty.
        enhance_prompt: When truthy, rewrite ``prompt`` with Gemma first.
        lora_scale: Adapter weight passed to ``pipe.set_adapters``.
        cfg_scale: Guidance scale forwarded to the image pipeline.
        steps: Number of inference steps.
        randomize_seed: When truthy, draw a new seed in ``[0, MAX_SEED]``.
        seed: Seed used when ``randomize_seed`` is falsy.
        aspect_ratio: Aspect-ratio label understood by ``calculate_dimensions``.
        base_resolution: Resolution passed to ``calculate_dimensions``.
        progress: Gradio progress tracker; not read here, kept so Gradio
            injects it (``track_tqdm=True``).

    Yields:
        ``(image, seed, progress_update)`` for every intermediate image, then
        a final tuple whose update hides the progress bar.

    Raises:
        gr.Error: If ``prompt`` is empty.
    """
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    # Models are only moved to the GPU inside this decorated function.
    device_gpu = "cuda"
    final_prompt = prompt

    if enhance_prompt:
        print("Déplacement de Gemma sur le GPU...")
        # NOTE(review): Gemma is loaded with a 4-bit BitsAndBytesConfig; whether
        # `.to()` is allowed on such a model depends on the installed
        # transformers/bitsandbytes versions -- confirm on the target Space.
        gemma_model.to(device_gpu)
        try:
            print(f"Amélioration du prompt '{prompt}' avec Gemma...")

            system_prompt = (
                "You are an expert prompt engineer for a text-to-image AI. "
                "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
                "Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. "
                "Do not add any conversational text or refuse the request. Only output the enhanced prompt."
            )
            # Fix: the original interpolated `user_prompt`, a name that does not
            # exist in this function (NameError); the user's text is `prompt`.
            chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{prompt}\""}]
            prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)

            outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
            # Decode only the newly generated tokens, not the echoed input.
            input_length = inputs["input_ids"].shape[1]
            enhanced = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
            # Keep the user's prompt if the model produced nothing usable.
            final_prompt = enhanced or prompt

            print(f"Prompt amélioré : {final_prompt}")
        finally:
            # Always hand the VRAM back, even if generation raised.
            print("Libération de la mémoire de Gemma (déplacement vers CPU)...")
            gemma_model.to(device_cpu)

    prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
    print("Prompt final envoyé au modèle d'image:", prompt_mash)

    # Move the image pipeline to the GPU for the duration of the generation.
    print("Déplacement du pipeline d'image sur le GPU...")
    pipe.to(device_gpu)
    good_vae.to(device_gpu)

    final_image = None
    try:
        pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # `manual_seed` needs an int; UI sliders may deliver a float.
        seed = int(seed)

        width, height = calculate_dimensions(aspect_ratio, base_resolution)
        print(f"Génération d'une image de {width}x{height} pixels.")

        generator = torch.Generator(device=device_gpu).manual_seed(seed)

        image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt_mash,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            output_type="pil",
            good_vae=good_vae,
        )

        for i, image in enumerate(image_generator):
            final_image = image
            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
            yield image, seed, gr.update(value=progress_bar, visible=True)
    finally:
        # Runs on success, on error, and when the generator is closed early,
        # so the pipeline never stays resident on the GPU.
        print("Libération de la mémoire du pipeline d'image (déplacement vers CPU)...")
        pipe.to(device_cpu)
        good_vae.to(device_cpu)
        torch.cuda.empty_cache()

    yield final_image, seed, gr.update(visible=False)
137
+
138
+ # Le reste du code (fonctions d'aide et interface) reste le même
139
+
140
def calculate_dimensions(aspect_ratio, resolution):
    """Return ``(width, height)`` for an aspect-ratio label and a resolution.

    ``resolution`` is the length of the longest side. The shorter side is
    derived from the ratio, then both sides are rounded down to a multiple of
    64 and clamped to at least 64 so a dimension can never collapse to 0.

    Args:
        aspect_ratio: One of "Square (1:1)", "Portrait (9:16)",
            "Landscape (16:9)", "Ultrawide (21:9)". Any other value is
            treated as square.
        resolution: Longest side in pixels; anything ``int()`` accepts.

    Returns:
        A ``(width, height)`` tuple of ints, each a multiple of 64 and >= 64.
    """
    resolution = int(resolution)

    if aspect_ratio == "Portrait (9:16)":
        width, height = int(resolution * 9 / 16), resolution
    elif aspect_ratio == "Landscape (16:9)":
        width, height = resolution, int(resolution * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)":
        width, height = resolution, int(resolution * 9 / 21)
    else:
        # "Square (1:1)" and any unrecognised label.
        width, height = resolution, resolution

    # Round down to a multiple of 64, but never below one 64-pixel step.
    width = max(64, (width // 64) * 64)
    height = max(64, (height // 64) * 64)
    return width, height
150
+
151
def update_history(new_image, history):
    """Return the history with ``new_image`` placed first.

    The input list is left untouched; a new list is returned so the value
    handed in by the caller is never mutated behind its back.

    Args:
        new_image: The newest image, or ``None`` when there is nothing to add.
        history: Existing history (newest first), or ``None`` if empty.

    Returns:
        ``history`` unchanged when ``new_image`` is ``None``; otherwise a new
        list starting with ``new_image`` followed by the previous entries.
    """
    if new_image is None:
        return history
    previous = [] if history is None else list(history)
    return [new_image, *previous]
156
 
 
157
  css = '''
158
  #title_container { text-align: center; margin-bottom: 1em; }
159
  #title_line { display: flex; justify-content: center; align-items: center; }
 
166
  '''
167
 
168
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
 
169
  gr.HTML(
170
  """
171
  <div id="title_container">
 
179
  </div>
180
  """
181
  )
 
182
  with gr.Row():
 
183
  with gr.Column(scale=3):
 
184
  with gr.Group():
185
  with gr.Row():
186
  with gr.Column(scale=1, min_width=150):
187
  random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
188
  with gr.Column(scale=5):
189
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
 
 
190
  enhance_prompt_checkbox = gr.Checkbox(
191
+ label="Improve prompt with AI", value=True,
 
192
  info="Uses Gemma to automatically enrich your prompt with more details before generation."
193
  )
 
 
194
  with gr.Group():
195
  aspect_ratio = gr.Radio(
196
  label="Aspect Ratio",
 
198
  value="Square (1:1)"
199
  )
200
  lora_scale = gr.Slider(
201
+ label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
 
202
  info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
203
  )
 
 
204
  with gr.Accordion("Advanced Settings", open=False):
205
  base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
206
  steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
207
  cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
 
208
  with gr.Row():
209
  randomize_seed = gr.Checkbox(True, label="Random Seed")
210
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
211
  generate_button = gr.Button("Generate", variant="primary")
 
 
212
  with gr.Column(scale=2):
213
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
214
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
215
  with gr.Accordion("History", open=False):
216
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
 
 
217
  def get_random_prompt():
218
  return random.choice(prompt_values)
 
219
  random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
 
 
220
  generation_event = gr.on(
221
  triggers=[generate_button.click, prompt.submit],
222
  fn=run_generation,
223
  inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
224
  outputs=[result, seed, progress_bar]
225
  )
 
226
  generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
227
 
228
  app.queue(max_size=20)