Econogoat commited on
Commit
e4f3dd9
·
verified ·
1 Parent(s): 67954ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -54
app.py CHANGED
@@ -9,91 +9,71 @@ from diffusers.utils import load_image
9
  import pandas as pd
10
  import random
11
  import time
12
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
13
 
14
  # --- Configuration Principale ---
15
  KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
16
- # --- CORRECTION DÉFINITIVE : Utilisation du bon ID de modèle, comme vous l'avez demandé ---
17
- GEMMA_MODEL_ID = "google/gemma-1.1-2b-it"
18
  BASE_IMAGE_MODEL = "black-forest-labs/FLUX.1-dev"
19
 
20
- # --- Pré-chargement sur CPU des éléments légers UNIQUEMENT ---
21
- print("Pré-chargement des tokenizers sur CPU...")
22
- HF_TOKEN = os.getenv("HF_TOKEN")
23
- gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
24
- print("Pré-chargement terminé. Les modèles lourds seront chargés à la demande sur le GPU.")
25
 
26
- # --- Variables globales pour conserver les modèles en mémoire sur le GPU ---
 
27
  pipe = None
28
  good_vae = None
29
- gemma_model = None
30
 
31
  MAX_SEED = 2**32 - 1
32
 
33
- @spaces.GPU(duration=180) # Durée augmentée pour le premier chargement
34
- def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
35
- global pipe, good_vae, gemma_model
36
  if not prompt:
37
  raise gr.Error("Prompt cannot be empty.")
38
 
39
  device_gpu = "cuda"
40
  device_cpu = "cpu"
41
  dtype = torch.bfloat16
42
- final_prompt = prompt
43
 
44
- # --- Chargement à la demande ("Lazy Loading") des modèles sur le GPU ---
 
45
  if pipe is None:
46
  print("Premier appel : Chargement du pipeline d'image sur GPU...")
 
47
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_gpu)
48
  good_vae = AutoencoderKL.from_pretrained(BASE_IMAGE_MODEL, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_gpu)
49
  pipe = DiffusionPipeline.from_pretrained(BASE_IMAGE_MODEL, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_gpu)
50
  print("Chargement du LoRA sur le pipeline GPU...")
51
  pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=False, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
52
  print("Pipeline d'image prêt.")
 
53
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
54
 
55
- if enhance_prompt:
56
- if gemma_model is None:
57
- print(f"Premier appel d'amélioration : Chargement de {GEMMA_MODEL_ID} sur GPU...")
58
- # La quantization est moins critique pour ce petit modèle, mais reste une bonne pratique.
59
- bnb_config_gpu = BitsAndBytesConfig(
60
- load_in_4bit=True,
61
- bnb_4bit_quant_type="nf4",
62
- bnb_4bit_use_double_quant=True,
63
- bnb_4bit_compute_dtype=dtype
64
- )
65
- gemma_model = AutoModelForCausalLM.from_pretrained(
66
- GEMMA_MODEL_ID,
67
- quantization_config=bnb_config_gpu,
68
- token=HF_TOKEN,
69
- device_map="auto"
70
- )
71
- print("Modèle Gemma prêt.")
72
-
73
- print(f"Amélioration du prompt '{prompt}' avec Gemma...")
74
- system_prompt = "You are an expert prompt engineer for a text-to-image AI. Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. Focus on describing the scene, the subject, the environment, the lighting, the colors, and a potential artistic style. Do not add any conversational text or refuse the request. Only output the enhanced prompt."
75
- chat = [{"role": "user", "content": f"{system_prompt}\n\nUser idea: \"{user_prompt}\""}]
76
- prompt_for_gemma = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
77
- inputs = gemma_tokenizer(prompt_for_gemma, return_tensors="pt").to(device_gpu)
78
-
79
- outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
80
- input_length = inputs["input_ids"].shape[1]
81
- final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
82
- print(f"Prompt amélioré : {final_prompt}")
83
-
84
- prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
85
- print("Prompt final envoyé au modèle d'image:", prompt_mash)
86
 
 
87
  pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
88
 
89
- if randomize_seed: seed = random.randint(0, MAX_SEED)
 
 
90
  width, height = calculate_dimensions(aspect_ratio, base_resolution)
91
  print(f"Génération d'une image de {width}x{height} pixels.")
92
 
93
  generator = torch.Generator(device=device_gpu).manual_seed(seed)
 
 
94
  image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
95
- prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
96
- width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
 
 
 
 
 
 
97
  )
98
 
99
  final_image = None
@@ -102,6 +82,14 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomi
102
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
103
  yield image, seed, gr.update(value=progress_bar, visible=True)
104
 
 
 
 
 
 
 
 
 
105
  yield final_image, seed, gr.update(visible=False)
106
 
107
  def calculate_dimensions(aspect_ratio, resolution):
@@ -154,10 +142,6 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
154
  random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
155
  with gr.Column(scale=5):
156
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
157
- enhance_prompt_checkbox = gr.Checkbox(
158
- label="Improve prompt with AI", value=True,
159
- info="Uses Gemma to automatically enrich your prompt with more details before generation."
160
- )
161
  with gr.Group():
162
  aspect_ratio = gr.Radio(
163
  label="Aspect Ratio",
@@ -181,15 +165,20 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
181
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
182
  with gr.Accordion("History", open=False):
183
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
 
184
  def get_random_prompt():
185
  return random.choice(prompt_values)
 
186
  random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
 
 
187
  generation_event = gr.on(
188
  triggers=[generate_button.click, prompt.submit],
189
  fn=run_generation,
190
- inputs=[prompt, enhance_prompt_checkbox, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
191
  outputs=[result, seed, progress_bar]
192
  )
 
193
  generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
194
 
195
  app.queue(max_size=20)
 
9
  import pandas as pd
10
  import random
11
  import time
 
12
 
13
  # --- Configuration Principale ---
14
  KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
 
 
15
  BASE_IMAGE_MODEL = "black-forest-labs/FLUX.1-dev"
16
 
17
+ # Charger les prompts
18
+ df = pd.read_csv('prompts.csv', header=None)
19
+ prompt_values = df.values.flatten()
 
 
20
 
21
+ # --- Variables globales pour conserver les modèles en mémoire ---
22
+ # On utilise une stratégie de chargement à la demande ("lazy loading") pour être compatible avec ZeroGPU
23
  pipe = None
24
  good_vae = None
 
25
 
26
  MAX_SEED = 2**32 - 1
27
 
28
+ @spaces.GPU(duration=180) # Durée pour accommoder le premier chargement du modèle d'image
29
+ def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
30
+ global pipe, good_vae
31
  if not prompt:
32
  raise gr.Error("Prompt cannot be empty.")
33
 
34
  device_gpu = "cuda"
35
  device_cpu = "cpu"
36
  dtype = torch.bfloat16
 
37
 
38
+ # --- Chargement à la demande ("Lazy Loading") du modèle d'image sur le GPU ---
39
+ # Cette section ne s'exécute qu'une seule fois, lors du tout premier appel.
40
  if pipe is None:
41
  print("Premier appel : Chargement du pipeline d'image sur GPU...")
42
+ HF_TOKEN = os.getenv("HF_TOKEN")
43
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_gpu)
44
  good_vae = AutoencoderKL.from_pretrained(BASE_IMAGE_MODEL, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_gpu)
45
  pipe = DiffusionPipeline.from_pretrained(BASE_IMAGE_MODEL, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_gpu)
46
  print("Chargement du LoRA sur le pipeline GPU...")
47
  pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=False, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
48
  print("Pipeline d'image prêt.")
49
+ # Ajout de la méthode de prévisualisation
50
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
51
 
52
+ # Création du prompt final
53
+ prompt_mash = f"{KRYPTO_LORA['trigger']}, {prompt}"
54
+ print("Prompt final envoyé au modèle:", prompt_mash)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Activation du LoRA
57
  pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
58
 
59
+ if randomize_seed:
60
+ seed = random.randint(0, MAX_SEED)
61
+
62
  width, height = calculate_dimensions(aspect_ratio, base_resolution)
63
  print(f"Génération d'une image de {width}x{height} pixels.")
64
 
65
  generator = torch.Generator(device=device_gpu).manual_seed(seed)
66
+
67
+ # Appel du générateur d'image
68
  image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
69
+ prompt=prompt_mash,
70
+ num_inference_steps=steps,
71
+ guidance_scale=cfg_scale,
72
+ width=width,
73
+ height=height,
74
+ generator=generator,
75
+ output_type="pil",
76
+ good_vae=good_vae,
77
  )
78
 
79
  final_image = None
 
82
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
83
  yield image, seed, gr.update(value=progress_bar, visible=True)
84
 
85
+ # Une fois la génération terminée, on peut libérer de la VRAM en déplaçant le modèle sur le CPU
86
+ # C'est optionnel mais une bonne pratique dans les environnements managés
87
+ print("Génération terminée. Déplacement du pipeline vers le CPU pour libérer la VRAM.")
88
+ pipe.to(device_cpu)
89
+ good_vae.to(device_cpu)
90
+ torch.cuda.empty_cache()
91
+ pipe = None # Force le rechargement au prochain appel
92
+
93
  yield final_image, seed, gr.update(visible=False)
94
 
95
  def calculate_dimensions(aspect_ratio, resolution):
 
142
  random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
143
  with gr.Column(scale=5):
144
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
 
 
 
 
145
  with gr.Group():
146
  aspect_ratio = gr.Radio(
147
  label="Aspect Ratio",
 
165
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
166
  with gr.Accordion("History", open=False):
167
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
168
+
169
  def get_random_prompt():
170
  return random.choice(prompt_values)
171
+
172
  random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
173
+
174
+ # Les entrées de la fonction de génération sont simplifiées
175
  generation_event = gr.on(
176
  triggers=[generate_button.click, prompt.submit],
177
  fn=run_generation,
178
+ inputs=[prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
179
  outputs=[result, seed, progress_bar]
180
  )
181
+
182
  generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
183
 
184
  app.queue(max_size=20)