Econogoat committed
Commit 792f241 · verified · 1 Parent(s): 23fa75e

Update app.py

Files changed (1):
  1. app.py +118 -77
app.py CHANGED
@@ -10,61 +10,72 @@ import pandas as pd
 import random
 import time
 
-# --- Main configuration ---
-KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
-BASE_IMAGE_MODEL = "black-forest-labs/FLUX.1-dev"
-
-# Load the prompts
+# --- Main Configuration ---
+KRYPTO_LORA = {
+    # FIX: the repo name was misspelled (a capital O instead of a zero).
+    "repo": "Econogoat/Krypt0_LORA",
+    "trigger": "Krypt0",
+    "adapter_name": "krypt0"
+}
+
+# Load prompts for the randomize button
 df = pd.read_csv('prompts.csv', header=None)
 prompt_values = df.values.flatten()
 
-# --- Global variables that keep the models in memory ---
-# Lazy loading is used here to stay compatible with ZeroGPU
-pipe = None
-good_vae = None
+# Get the access token from the Space secrets
+HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    print("WARNING: HF_TOKEN secret is not set. Gated model downloads may fail.")
+
+# --- Model Initialization ---
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+dtype = torch.bfloat16
+base_model = "black-forest-labs/FLUX.1-dev"
+
+# Load model components
+print("Loading model components...")
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device)
+print("Models loaded.")
+
+# Load the LoRA adapter once on startup
+print(f"Loading on-board LoRA: {KRYPTO_LORA['repo']}")
+pipe.load_lora_weights(
+    KRYPTO_LORA['repo'],
+    low_cpu_mem_usage=True,
+    adapter_name=KRYPTO_LORA['adapter_name'],
+    token=HF_TOKEN  # pass the token here too, for private/gated LoRAs
+)
+print("LoRA loaded successfully.")
 
 MAX_SEED = 2**32 - 1
 
-@spaces.GPU(duration=180)  # long enough to cover the first image-model load
-def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
-    global pipe, good_vae
-    if not prompt:
-        raise gr.Error("Prompt cannot be empty.")
-
-    device_gpu = "cuda"
-    device_cpu = "cpu"
-    dtype = torch.bfloat16
-
-    # --- Lazy loading of the image model onto the GPU ---
-    # This section runs only once, on the very first call.
-    if pipe is None:
-        print("First call: loading the image pipeline onto the GPU...")
-        HF_TOKEN = os.getenv("HF_TOKEN")
-        taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_gpu)
-        good_vae = AutoencoderKL.from_pretrained(BASE_IMAGE_MODEL, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_gpu)
-        pipe = DiffusionPipeline.from_pretrained(BASE_IMAGE_MODEL, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_gpu)
-        print("Loading the LoRA onto the GPU pipeline...")
-        pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=False, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
-        print("Image pipeline ready.")
-        # Attach the live-preview method
-        pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
-
-    # Build the final prompt
-    prompt_mash = f"{KRYPTO_LORA['trigger']}, {prompt}"
-    print("Final prompt sent to the model:", prompt_mash)
-
-    # Activate the LoRA
-    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
-
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    width, height = calculate_dimensions(aspect_ratio, base_resolution)
-    print(f"Generating a {width}x{height} image.")
-
-    generator = torch.Generator(device=device_gpu).manual_seed(seed)
-
-    # Call the image generator
+# Monkey-patch the pipeline for live preview
+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
+def calculate_dimensions(aspect_ratio, resolution):
+    resolution = int(resolution)
+    if aspect_ratio == "Square (1:1)":
+        width, height = resolution, resolution
+    elif aspect_ratio == "Portrait (9:16)":
+        width, height = int(resolution * 9 / 16), resolution
+    elif aspect_ratio == "Landscape (16:9)":
+        width, height = resolution, int(resolution * 9 / 16)
+    elif aspect_ratio == "Ultrawide (21:9)":
+        width, height = resolution, int(resolution * 9 / 21)
+    else:
+        width, height = resolution, resolution
+    # Snap both sides down to multiples of 64
+    width = (width // 64) * 64
+    height = (height // 64) * 64
+    return width, height
+
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
+    pipe.to(device)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
     image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
         prompt=prompt_mash,
         num_inference_steps=steps,
@@ -72,43 +83,48 @@ def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, a
         width=width,
         height=height,
         generator=generator,
+        joint_attention_kwargs={"scale": 1.0},
         output_type="pil",
         good_vae=good_vae,
     )
-
     final_image = None
     for i, image in enumerate(image_generator):
         final_image = image
         progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
-        yield image, seed, gr.update(value=progress_bar, visible=True)
-
-    # Once generation is done, free VRAM by moving the model back to the CPU.
-    # Optional, but good practice in managed environments.
-    print("Generation finished. Moving the pipeline to the CPU to free VRAM.")
-    pipe.to(device_cpu)
-    good_vae.to(device_cpu)
-    torch.cuda.empty_cache()
-    pipe = None  # force a reload on the next call
-
-    yield final_image, seed, gr.update(visible=False)
-
-def calculate_dimensions(aspect_ratio, resolution):
-    resolution = int(resolution)
-    if aspect_ratio == "Square (1:1)": width, height = resolution, resolution
-    elif aspect_ratio == "Portrait (9:16)": width, height = int(resolution * 9 / 16), resolution
-    elif aspect_ratio == "Landscape (16:9)": width, height = resolution, int(resolution * 9 / 16)
-    elif aspect_ratio == "Ultrawide (21:9)": width, height = resolution, int(resolution * 9 / 21)
-    else: width, height = resolution, resolution
-    width = (width // 64) * 64
-    height = (height // 64) * 64
-    return width, height
+        yield image, gr.update(value=progress_bar, visible=True)
+    yield final_image, gr.update(visible=False)
 
 def update_history(new_image, history):
-    if new_image is None: return history
-    if history is None: history = []
+    if new_image is None:
+        return history
+    if history is None:
+        history = []
     history.insert(0, new_image)
     return history
+
+@spaces.GPU(duration=75)
+def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
+    if not prompt:
+        raise gr.Error("Prompt cannot be empty.")
+
+    prompt_mash = f"{KRYPTO_LORA['trigger']}, {prompt}"
+    print("Final prompt:", prompt_mash)
+
+    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
+    print(f"Adapter '{KRYPTO_LORA['adapter_name']}' activated with weight {lora_scale}.")
+
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    width, height = calculate_dimensions(aspect_ratio, base_resolution)
+    print(f"Generating a {width}x{height} image.")
+
+    for image, progress_update in generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
+        yield image, seed, progress_update
+
+run_generation.zerogpu = True
+
+# --- User Interface (Gradio) ---
 css = '''
 #title_container { text-align: center; margin-bottom: 1em; }
 #title_line { display: flex; justify-content: center; align-items: center; }
@@ -121,6 +137,7 @@ css = '''
 '''
 
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
+    # --- Header ---
     gr.HTML(
         """
         <div id="title_container">
@@ -134,14 +151,19 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
         </div>
         """
     )
+
     with gr.Row():
+        # --- LEFT COLUMN: CONTROLS ---
         with gr.Column(scale=3):
+            # Prompt Controls
             with gr.Group():
                 with gr.Row():
                     with gr.Column(scale=1, min_width=150):
                         random_prompt_btn = gr.Button("Random Prompt", elem_id="random_prompt_btn")
                     with gr.Column(scale=5):
                         prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen")
+
+            # Image Shape and Style Controls
             with gr.Group():
                 aspect_ratio = gr.Radio(
                     label="Aspect Ratio",
@@ -149,37 +171,56 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
                     value="Square (1:1)"
                 )
                 lora_scale = gr.Slider(
-                    label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9,
+                    label="Krypt0 Style Strength",
+                    minimum=0,
+                    maximum=2,
+                    step=0.05,
+                    value=0.9,
                     info="Controls how strongly the artistic style is applied. Higher values mean a more stylized image."
                 )
+
+            # Advanced Settings
+            # FIX: the accordion should be closed by default.
             with gr.Accordion("Advanced Settings", open=False):
                 base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
                 steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
                 cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
+
             with gr.Row():
                 randomize_seed = gr.Checkbox(True, label="Random Seed")
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+
             generate_button = gr.Button("Generate", variant="primary")
+
+        # --- RIGHT COLUMN: RESULTS ---
         with gr.Column(scale=2):
             progress_bar = gr.Markdown(elem_id="progress", visible=False)
             result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
             with gr.Accordion("History", open=False):
                 history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
 
+    # --- Event Logic ---
     def get_random_prompt():
         return random.choice(prompt_values)
-
-    random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
-
-    # The generation function's inputs are simplified
+
+    random_prompt_btn.click(
+        fn=get_random_prompt,
+        inputs=[],
+        outputs=[prompt]
+    )
+
     generation_event = gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=run_generation,
         inputs=[prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
         outputs=[result, seed, progress_bar]
     )
-
-    generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
+
+    generation_event.then(
+        fn=update_history,
+        inputs=[result, history_gallery],
+        outputs=history_gallery,
+    )
 
 app.queue(max_size=20)
 app.launch()
 
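A note on the monkey-patch in the diff: function.__get__(instance) uses
Python's descriptor protocol to produce a bound method whose self is that
instance, which is how the Space attaches the live-preview call to its
pipeline object. A minimal standalone sketch of the pattern (DemoPipe and
preview_call are illustrative stand-ins, not the diffusers API):

    class DemoPipe:
        def __init__(self, name):
            self.name = name

    def preview_call(self, prompt, steps=3):
        # self is whatever instance the function was bound to via __get__
        for i in range(steps):
            yield f"{self.name}: step {i + 1}/{steps} of '{prompt}'"

    pipe = DemoPipe("flux-demo")
    pipe.preview_call = preview_call.__get__(pipe)  # bind to this instance only
    for frame in pipe.preview_call("a warrior queen"):
        print(frame)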
 
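The rewritten run_generation is a generator, which is what makes the per-step
preview work: when a Gradio event handler yields instead of returning, each
yield is streamed to the output components. A minimal sketch of that pattern
(fake_generation is a stand-in for the real handler, which yields image, seed,
and progress updates):

    import time
    import gradio as gr

    def fake_generation(prompt, steps=5):
        for i in range(steps):
            time.sleep(0.3)  # stand-in for one denoising step
            yield f"{prompt}: step {i + 1}/{steps}"
        yield f"{prompt}: done"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        status = gr.Textbox(label="Status")
        gr.Button("Generate").click(fake_generation, inputs=prompt, outputs=status)

    demo.launch()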
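
For reference, calculate_dimensions is a pure function, so its
rounding-to-64 behavior is easy to check in isolation. The snippet below
repeats two of its branches with hand-computed expected values:

    def calculate_dimensions(aspect_ratio, resolution):
        resolution = int(resolution)
        if aspect_ratio == "Portrait (9:16)":
            width, height = int(resolution * 9 / 16), resolution
        elif aspect_ratio == "Ultrawide (21:9)":
            width, height = resolution, int(resolution * 9 / 21)
        else:  # Square and the remaining branches behave analogously
            width, height = resolution, resolution
        # Snap both sides down to multiples of 64
        return (width // 64) * 64, (height // 64) * 64

    print(calculate_dimensions("Portrait (9:16)", 1024))   # (576, 1024)
    print(calculate_dimensions("Ultrawide (21:9)", 1024))  # (1024, 384): 438 snaps to 384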