Econogoat committed on
Commit
d3b1795
·
verified ·
1 Parent(s): afbe924

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -107
app.py CHANGED
@@ -4,167 +4,208 @@ import torch
4
  from PIL import Image
5
  import spaces
6
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
7
- from transformers import AutoModelForImageTextToText, AutoProcessor
8
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
  from diffusers.utils import load_image
10
  import pandas as pd
11
  import random
12
  import time
13
 
14
- # --- Configuration (statique) ---
15
- KRYPTO_LORA = {"repo": "Econogoat/Krypt0_LORA", "trigger": "Krypt0", "adapter_name": "krypt0"}
16
- LLM_MODEL_ID = "google/gemma-3n-E4B-it"
17
- SYSTEM_PROMPT = """You are a creative assistant that enhances user prompts for an AI image generation model.
18
- Your task is to take a user's simple idea and expand it into a rich, detailed, and visually descriptive prompt.
19
- Focus on cinematic lighting, intricate details, atmosphere, and a strong artistic style.
20
- Do NOT add the trigger word 'Krypt0', it will be added automatically later.
21
- Reply ONLY with the enhanced prompt, without any introduction or explanation."""
22
 
 
23
  df = pd.read_csv('prompts.csv', header=None)
24
  prompt_values = df.values.flatten()
 
 
25
  HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  MAX_SEED = 2**32 - 1
27
 
 
 
 
 
28
  def calculate_dimensions(aspect_ratio, resolution):
 
29
  resolution = int(resolution)
30
- if aspect_ratio == "Square (1:1)": width, height = resolution, resolution
31
- elif aspect_ratio == "Portrait (9:16)": width, height = int(resolution * 9 / 16), resolution
32
- elif aspect_ratio == "Landscape (16:9)": width, height = resolution, int(resolution * 9 / 16)
33
- elif aspect_ratio == "Ultrawide (21:9)": width, height = resolution, int(resolution * 9 / 21)
34
- else: width, height = resolution, resolution
 
 
 
 
 
 
35
  width = (width // 64) * 64
36
  height = (height // 64) * 64
37
  return width, height
38
 
39
- def update_history(new_image, history):
40
- if new_image is None: return history
41
- if history is None: history = []
42
- history.insert(0, new_image)
43
- return history
44
-
45
- @spaces.GPU(duration=180)
46
- def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution,
47
- # On reçoit l'état actuel des modèles en entrée
48
- state_pipe, state_llm_model, state_llm_processor, state_good_vae,
49
- progress=gr.Progress(track_tqdm=True)):
50
-
51
- # --- CHARGEMENT À LA VOLÉE AU PREMIER CLIC, EN UTILISANT gr.State ---
52
- # La condition est maintenant basée sur l'état passé en argument, pas sur une variable globale fragile.
53
- if state_pipe is None:
54
- gr.Info("First run: Loading all models... This will take a moment.")
55
- print("First run: Loading all models inside GPU context...")
56
-
57
- device = "cuda"
58
- dtype = torch.bfloat16
59
-
60
- print("Loading LLM...")
61
- state_llm_processor = AutoProcessor.from_pretrained(LLM_MODEL_ID, token=HF_TOKEN)
62
- state_llm_model = AutoModelForImageTextToText.from_pretrained(LLM_MODEL_ID, torch_dtype=dtype, token=HF_TOKEN)
63
-
64
- print("Loading diffusion models...")
65
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
66
- state_good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", token=HF_TOKEN, torch_dtype=dtype)
67
- state_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", vae=taef1, token=HF_TOKEN, torch_dtype=dtype)
68
-
69
- print("Loading LoRA...")
70
- state_pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=False, adapter_name=KRYPTO_LORA['adapter_name'])
71
-
72
- state_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(state_pipe)
73
- print("All models loaded and stored in session state.")
74
-
75
- # --- DÉBUT DU PROCESSUS NORMAL ---
76
- if not prompt: raise gr.Error("Prompt cannot be empty.")
77
-
78
- device = "cuda"
79
- dtype = torch.bfloat16
80
-
81
- # --- 1. Amélioration du prompt avec le LLM ---
82
- gr.Info("Enhancing prompt with LLM...")
83
- state_llm_model.to(device)
84
- messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}]
85
- inputs = state_llm_processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt").to(device)
86
-
87
- with torch.inference_mode():
88
- outputs = state_llm_model.generate(**inputs, max_new_tokens=150)
89
-
90
- enhanced_prompt = state_llm_processor.batch_decode(outputs, skip_special_tokens=True)[0].split("assistant\n")[-1].strip()
91
- state_llm_model.to("cpu"); torch.cuda.empty_cache()
92
-
93
- # --- 2. Génération d'image ---
94
- gr.Info("Prompt enhanced. Starting image generation...")
95
- prompt_mash = f"{KRYPTO_LORA['trigger']}, {enhanced_prompt}"
96
- state_pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
97
- if randomize_seed: seed = random.randint(0, MAX_SEED)
98
- width, height = calculate_dimensions(aspect_ratio, base_resolution)
99
-
100
- state_pipe.to(device); state_good_vae.to(device)
101
  generator = torch.Generator(device=device).manual_seed(seed)
102
 
103
- image_generator = state_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
104
- prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
105
- width=width, height=height, generator=generator, joint_attention_kwargs={"scale": 1.0},
106
- output_type="pil", good_vae=state_good_vae
 
 
 
 
 
 
107
  )
 
108
  final_image = None
109
  for i, image in enumerate(image_generator):
110
  final_image = image
111
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
112
- # On retourne les états inchangés pendant la prévisualisation
113
- yield image, seed, gr.update(value=progress_bar, visible=True), state_pipe, state_llm_model, state_llm_processor, state_good_vae
 
 
 
 
 
 
 
 
 
114
 
115
- state_pipe.to("cpu"); state_good_vae.to("cpu")
 
 
 
 
 
 
 
 
 
 
116
 
117
- # On retourne l'image finale ET l'état mis à jour des modèles pour qu'ils soient conservés pour le prochain appel
118
- yield final_image, seed, gr.update(visible=False), state_pipe, state_llm_model, state_llm_processor, state_good_vae
 
 
 
 
 
 
 
119
 
120
  run_generation.zerogpu = True
121
 
122
- # --- UI ---
123
- css = ''' #title{text-align: center} #title h1{font-size: 3em; display:inline-flex; align-items:center} #title img{width: 80px; margin-right: 0.25em} .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px} .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out} #random_prompt_btn{max-width: 2.5em; min-width: 2.5em !important; height: 100% !important;} '''
124
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
125
- # Déclaration des états qui contiendront nos modèles
126
- state_pipe = gr.State(None)
127
- state_llm_model = gr.State(None)
128
- state_llm_processor = gr.State(None)
129
- state_good_vae = gr.State(None)
 
130
 
131
- gr.HTML("""<div id='title'><h1><img src="https://huggingface.co/Econogoat/KRYPTO_LORA/resolve/main/krypt0.png" alt="LoRA"> Krypt0 Image Generator</h1><br><span>Generate images with the Krypt0 artistic style</span></div>""")
 
 
 
 
 
 
132
  with gr.Row():
 
133
  with gr.Column(scale=3):
 
134
  with gr.Group():
135
  with gr.Row():
136
  random_prompt_btn = gr.Button("🎲", elem_id="random_prompt_btn")
137
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen", scale=8)
 
138
  lora_scale = gr.Slider(label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9)
139
- aspect_ratio = gr.Radio(label="Aspect Ratio", choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"], value="Square (1:1)")
 
 
 
 
 
 
 
 
140
  with gr.Accordion("Advanced Settings", open=True):
141
  base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
142
  steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
143
  cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
 
144
  with gr.Row():
145
  randomize_seed = gr.Checkbox(True, label="Random Seed")
146
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
147
  generate_button = gr.Button("Generate", variant="primary")
 
 
148
  with gr.Column(scale=2):
149
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
150
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
151
  with gr.Accordion("History", open=False):
152
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
153
 
154
- def get_random_prompt(): return random.choice(prompt_values)
155
- random_prompt_btn.click(fn=get_random_prompt, inputs=[], outputs=[prompt])
156
-
157
- # On ajoute les états aux inputs et outputs de l'événement de génération
158
- # C'est la boucle qui assure la persistence des modèles
159
- model_states = [state_pipe, state_llm_model, state_llm_processor, state_good_vae]
160
-
 
 
 
161
  generation_event = gr.on(
162
- triggers=[generate_button.click, prompt.submit],
163
- fn=run_generation,
164
- inputs=[prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution] + model_states,
165
- outputs=[result, seed, progress_bar] + model_states
 
 
 
 
 
 
166
  )
167
- generation_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
168
 
169
  app.queue(max_size=20)
170
  app.launch()
 
4
  from PIL import Image
5
  import spaces
6
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
 
7
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
8
  from diffusers.utils import load_image
9
  import pandas as pd
10
  import random
11
  import time
12
 
13
+ # --- Main Configuration ---
14
+ KRYPTO_LORA = {
15
+ "repo": "Econogoat/Krypt0_LORA",
16
+ "trigger": "Krypt0",
17
+ "adapter_name": "krypt0"
18
+ }
 
 
19
 
20
+ # Load prompts for the randomize button
21
  df = pd.read_csv('prompts.csv', header=None)
22
  prompt_values = df.values.flatten()
23
+
24
+ # Get access token from Space secrets
25
  HF_TOKEN = os.getenv("HF_TOKEN")
26
+ if not HF_TOKEN:
27
+ print("WARNING: HF_TOKEN secret is not set. Gated model downloads may fail.")
28
+
29
+ # --- Model Initialization ---
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ print(f"Using device: {device}")
32
+ dtype = torch.bfloat16
33
+ base_model = "black-forest-labs/FLUX.1-dev"
34
+
35
+ # Load model components
36
+ print("Loading model components...")
37
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
38
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device)
39
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device)
40
+ print("Models loaded.")
41
+
42
+ # Load the LoRA adapter once on startup
43
+ print(f"Loading on-board LoRA: {KRYPTO_LORA['repo']}")
44
+ pipe.load_lora_weights(
45
+ KRYPTO_LORA['repo'],
46
+ low_cpu_mem_usage=True,
47
+ adapter_name=KRYPTO_LORA['adapter_name']
48
+ )
49
+ print("LoRA loaded successfully.")
50
+
51
  MAX_SEED = 2**32 - 1
52
 
53
+ # Monkey-patch the pipeline for live preview
54
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
55
+
56
+
def calculate_dimensions(aspect_ratio, resolution):
    """Return the ``(width, height)`` to generate for an aspect-ratio label.

    Args:
        aspect_ratio: One of the UI labels ``"Square (1:1)"``,
            ``"Portrait (9:16)"``, ``"Landscape (16:9)"`` or
            ``"Ultrawide (21:9)"``. Any other value falls back to square.
        resolution: Length of the longest side in pixels; anything accepted
            by ``int()``.

    Returns:
        A ``(width, height)`` tuple where both values are multiples of 64
        and never smaller than 64.
    """
    long_side = int(resolution)

    if aspect_ratio == "Portrait (9:16)":
        width, height = int(long_side * 9 / 16), long_side
    elif aspect_ratio == "Landscape (16:9)":
        width, height = long_side, int(long_side * 9 / 16)
    elif aspect_ratio == "Ultrawide (21:9)":
        width, height = long_side, int(long_side * 9 / 21)
    else:
        # "Square (1:1)" and any unrecognised label.
        width, height = long_side, long_side

    def _snap_to_64(value):
        # Round down to a multiple of 64, but never all the way to zero:
        # a 0-pixel side cannot be rendered.
        return max(64, (value // 64) * 64)

    return _snap_to_64(width), _snap_to_64(height)
74
 
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
    """Run text-to-image generation, yielding live previews.

    Args:
        prompt_mash: Full prompt text passed to the pipeline.
        steps: Number of inference steps; coerced to ``int``.
        seed: Seed for the random generator; coerced to ``int``.
        cfg_scale: Guidance scale passed to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        progress: Progress tracker handed over by the caller; not used
            directly in this function.

    Yields:
        ``(image, progress_update)`` pairs: one per image produced by the
        pipeline with the HTML progress bar made visible, then a last pair
        that repeats the final image and hides the progress bar.
    """
    # UI numeric inputs may arrive as floats; torch.Generator.manual_seed
    # needs an int, and the step count is passed on as num_inference_steps.
    steps = int(steps)
    seed = int(seed)

    # The parent @spaces.GPU function has already allocated a GPU
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)

    image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt_mash,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": 1.0},
        output_type="pil",
        good_vae=good_vae,
    )

    # Yield previews and the final image
    final_image = None
    for i, image in enumerate(image_generator):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {i + 1}; --total: {steps};"></div></div>'
        yield image, gr.update(value=progress_bar, visible=True)
    yield final_image, gr.update(visible=False)
99
+
def update_history(new_image, history):
    """Return the history with ``new_image`` placed first.

    Args:
        new_image: The newly generated image, or ``None`` when there is
            nothing to add.
        history: Existing gallery items (newest first), or ``None``.

    Returns:
        ``history`` unchanged when ``new_image`` is ``None``; otherwise a new
        list starting with ``new_image`` followed by the previous items. The
        list passed in is never modified.
    """
    if new_image is None:  # Don't add empty images on error
        return history
    if history is None:
        return [new_image]
    # Build a fresh list instead of inserting in place so the caller's
    # object is left untouched.
    return [new_image, *history]
108
 
@spaces.GPU(duration=75)
def run_generation(prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
    """Validate the inputs, configure the LoRA adapter and stream the results.

    Args:
        prompt: User prompt; the LoRA trigger word is prepended to it.
        lora_scale: Weight applied to the LoRA adapter.
        cfg_scale: Guidance scale forwarded to ``generate_image``.
        steps: Number of inference steps forwarded to ``generate_image``.
        randomize_seed: When truthy, ``seed`` is replaced by a random value
            in ``[0, MAX_SEED]``.
        seed: Seed used when ``randomize_seed`` is falsy.
        aspect_ratio: Aspect-ratio label passed to ``calculate_dimensions``.
        base_resolution: Longest-side resolution passed to
            ``calculate_dimensions``.
        progress: Gradio progress tracker forwarded to ``generate_image``.

    Yields:
        ``(image, seed, progress_update)`` tuples, one for every pair yielded
        by ``generate_image``.

    Raises:
        gr.Error: If the prompt is empty or contains only whitespace.
    """
    # Strip first so a prompt made only of spaces/newlines is rejected too.
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    prompt_mash = f"{KRYPTO_LORA['trigger']}, {prompt}"
    print("Final prompt:", prompt_mash)

    # Activate the LoRA adapter with the slider's weight
    pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
    print(f"Adapter '{KRYPTO_LORA['adapter_name']}' activated with weight {lora_scale}.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        # A user-supplied seed may arrive as a float; seeding needs an int.
        seed = int(seed)

    width, height = calculate_dimensions(aspect_ratio, base_resolution)
    print(f"Generating a {width}x{height} image.")

    # The function now only handles text-to-image
    for image, progress_update in generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
        yield image, seed, progress_update
130
 
131
  run_generation.zerogpu = True
132
 
133
+ # --- User Interface (Gradio) ---
134
+ css = '''
135
+ #title{text-align: center}
136
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
137
+ #title img{width: 80px; margin-right: 0.25em}
138
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
139
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out}
140
+ #random_prompt_btn{max-width: 2.5em; min-width: 2.5em !important; height: 100% !important;}
141
+ '''
142
 
143
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
144
+ # --- Header ---
145
+ with gr.Row():
146
+ gr.HTML(
147
+ """<div id='title'><h1><img src="https://huggingface.co/Econogoat/KRYPTO_LORA/resolve/main/krypt0.png" alt="LoRA"> Krypt0 Image Generator</h1><br><span>Generate images with the Krypt0 artistic style</span></div>"""
148
+ )
149
+
150
  with gr.Row():
151
+ # --- LEFT COLUMN: CONTROLS ---
152
  with gr.Column(scale=3):
153
+ # Prompt and Style Controls
154
  with gr.Group():
155
  with gr.Row():
156
  random_prompt_btn = gr.Button("🎲", elem_id="random_prompt_btn")
157
  prompt = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., a portrait of a warrior queen", scale=8)
158
+
159
  lora_scale = gr.Slider(label="Krypt0 Style Strength", minimum=0, maximum=2, step=0.05, value=0.9)
160
+
161
+ # Image Shape Controls
162
+ aspect_ratio = gr.Radio(
163
+ label="Aspect Ratio",
164
+ choices=["Square (1:1)", "Portrait (9:16)", "Landscape (16:9)", "Ultrawide (21:9)"],
165
+ value="Square (1:1)"
166
+ )
167
+
168
+ # Advanced Settings
169
  with gr.Accordion("Advanced Settings", open=True):
170
  base_resolution = gr.Slider(label="Resolution (longest side)", minimum=768, maximum=1408, step=64, value=1024)
171
  steps = gr.Slider(label="Generation Steps", minimum=4, maximum=50, step=1, value=20)
172
  cfg_scale = gr.Slider(label="Guidance (CFG Scale)", minimum=1, maximum=10, step=0.5, value=3.5)
173
+
174
  with gr.Row():
175
  randomize_seed = gr.Checkbox(True, label="Random Seed")
176
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
177
+
178
  generate_button = gr.Button("Generate", variant="primary")
179
+
180
+ # --- RIGHT COLUMN: RESULTS ---
181
  with gr.Column(scale=2):
182
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
183
  result = gr.Image(label="Generated Image", interactive=False, show_share_button=True)
184
  with gr.Accordion("History", open=False):
185
  history_gallery = gr.Gallery(label="History", columns=4, object_fit="contain", interactive=False)
186
 
187
+ # --- Event Logic ---
188
+ def get_random_prompt():
189
+ return random.choice(prompt_values)
190
+
191
+ random_prompt_btn.click(
192
+ fn=get_random_prompt,
193
+ inputs=[],
194
+ outputs=[prompt]
195
+ )
196
+
197
  generation_event = gr.on(
198
+ triggers=[generate_button.click, prompt.submit],
199
+ fn=run_generation,
200
+ inputs=[prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution],
201
+ outputs=[result, seed, progress_bar]
202
+ )
203
+
# Use .success() rather than .then(): .then() also fires when run_generation
# raised, in which case `result` still holds the previous image and that image
# would be added to the history a second time.
generation_event.success(
    fn=update_history,
    inputs=[result, history_gallery],
    outputs=history_gallery,
)
 
209
 
210
  app.queue(max_size=20)
211
  app.launch()