Spaces:
Running
on
Zero
Running
on
Zero
Fix user scale override
Browse files
app.py
CHANGED
|
@@ -55,7 +55,7 @@ def load_lora_weights(repo_id, weights_filename):
|
|
| 55 |
def update_selection(selected_state: gr.SelectData, flux_loras):
|
| 56 |
"""Update UI when a LoRA is selected"""
|
| 57 |
if selected_state.index >= len(flux_loras):
|
| 58 |
-
return "### No LoRA selected", gr.update(), None
|
| 59 |
|
| 60 |
lora_repo = flux_loras[selected_state.index]["repo"]
|
| 61 |
trigger_word = flux_loras[selected_state.index]["trigger_word"]
|
|
@@ -67,7 +67,10 @@ def update_selection(selected_state: gr.SelectData, flux_loras):
|
|
| 67 |
else:
|
| 68 |
new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def get_huggingface_lora(link):
|
| 73 |
"""Download LoRA from HuggingFace link"""
|
|
@@ -133,12 +136,12 @@ def classify_gallery(flux_loras):
|
|
| 133 |
sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
|
| 134 |
return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
|
| 135 |
|
| 136 |
-
def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 137 |
"""Wrapper function to handle state serialization"""
|
| 138 |
-
return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
|
| 139 |
|
| 140 |
@spaces.GPU
|
| 141 |
-
def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 142 |
"""Generate image with selected LoRA"""
|
| 143 |
global current_lora, pipe
|
| 144 |
|
|
@@ -155,14 +158,9 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
|
|
| 155 |
# Load LoRA if needed
|
| 156 |
if lora_to_use and lora_to_use != current_lora:
|
| 157 |
try:
|
| 158 |
-
# Unload current LoRA
|
| 159 |
if current_lora:
|
| 160 |
pipe.unload_lora_weights()
|
| 161 |
|
| 162 |
-
# Load new LoRA
|
| 163 |
-
if lora_to_use["lora_scale_config"]:
|
| 164 |
-
lora_scale = lora_to_use["lora_scale_config"]
|
| 165 |
-
print("lora scale loaded from config", lora_scale)
|
| 166 |
lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
|
| 167 |
if lora_path:
|
| 168 |
pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
|
|
@@ -173,8 +171,9 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
|
|
| 173 |
except Exception as e:
|
| 174 |
print(f"Error loading LoRA: {e}")
|
| 175 |
# Continue without LoRA
|
| 176 |
-
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
input_image = input_image.convert("RGB")
|
| 180 |
# Add trigger word to prompt
|
|
@@ -204,7 +203,7 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
|
|
| 204 |
height=input_image.size[1],
|
| 205 |
prompt=prompt,
|
| 206 |
guidance_scale=guidance_scale,
|
| 207 |
-
generator=torch.Generator().manual_seed(seed)
|
| 208 |
).images[0]
|
| 209 |
|
| 210 |
return image, seed, gr.update(visible=True), lora_scale
|
|
@@ -264,6 +263,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
|
|
| 264 |
|
| 265 |
selected_state = gr.State(value=None)
|
| 266 |
custom_loaded_lora = gr.State(value=None)
|
|
|
|
| 267 |
|
| 268 |
with gr.Row(elem_id="main_app"):
|
| 269 |
with gr.Column(scale=4, elem_id="box_column"):
|
|
@@ -348,15 +348,15 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
|
|
| 348 |
gallery.select(
|
| 349 |
fn=update_selection,
|
| 350 |
inputs=[gr_flux_loras],
|
| 351 |
-
outputs=[prompt_title, prompt, selected_state],
|
| 352 |
show_progress=False
|
| 353 |
)
|
| 354 |
|
| 355 |
gr.on(
|
| 356 |
triggers=[run_button.click, prompt.submit],
|
| 357 |
fn=infer_with_lora_wrapper,
|
| 358 |
-
inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
|
| 359 |
-
outputs=[result, seed, reuse_button,
|
| 360 |
)
|
| 361 |
|
| 362 |
reuse_button.click(
|
|
|
|
| 55 |
def update_selection(selected_state: gr.SelectData, flux_loras):
|
| 56 |
"""Update UI when a LoRA is selected"""
|
| 57 |
if selected_state.index >= len(flux_loras):
|
| 58 |
+
return "### No LoRA selected", gr.update(), None, gr.update()
|
| 59 |
|
| 60 |
lora_repo = flux_loras[selected_state.index]["repo"]
|
| 61 |
trigger_word = flux_loras[selected_state.index]["trigger_word"]
|
|
|
|
| 67 |
else:
|
| 68 |
new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
|
| 69 |
|
| 70 |
+
optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.0)
|
| 71 |
+
|
| 72 |
+
return updated_text, gr.update(placeholder=new_placeholder), selected_state.index, gr.update(value=optimal_scale)
|
| 73 |
+
|
| 74 |
|
| 75 |
def get_huggingface_lora(link):
|
| 76 |
"""Download LoRA from HuggingFace link"""
|
|
|
|
| 136 |
sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
|
| 137 |
return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
|
| 138 |
|
| 139 |
+
def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 140 |
"""Wrapper function to handle state serialization"""
|
| 141 |
+
return infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
|
| 142 |
|
| 143 |
@spaces.GPU
|
| 144 |
+
def infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 145 |
"""Generate image with selected LoRA"""
|
| 146 |
global current_lora, pipe
|
| 147 |
|
|
|
|
| 158 |
# Load LoRA if needed
|
| 159 |
if lora_to_use and lora_to_use != current_lora:
|
| 160 |
try:
|
|
|
|
| 161 |
if current_lora:
|
| 162 |
pipe.unload_lora_weights()
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
|
| 165 |
if lora_path:
|
| 166 |
pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
print(f"Error loading LoRA: {e}")
|
| 173 |
# Continue without LoRA
|
| 174 |
+
elif lora_scale != lora_state:
|
| 175 |
+
pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
|
| 176 |
+
print(f"using already loaded lora: {lora_to_use}, udpated {lora_scale} based on user preference")
|
| 177 |
|
| 178 |
input_image = input_image.convert("RGB")
|
| 179 |
# Add trigger word to prompt
|
|
|
|
| 203 |
height=input_image.size[1],
|
| 204 |
prompt=prompt,
|
| 205 |
guidance_scale=guidance_scale,
|
| 206 |
+
generator=torch.Generator().manual_seed(seed)
|
| 207 |
).images[0]
|
| 208 |
|
| 209 |
return image, seed, gr.update(visible=True), lora_scale
|
|
|
|
| 263 |
|
| 264 |
selected_state = gr.State(value=None)
|
| 265 |
custom_loaded_lora = gr.State(value=None)
|
| 266 |
+
lora_state = gr.State(value=1.0)
|
| 267 |
|
| 268 |
with gr.Row(elem_id="main_app"):
|
| 269 |
with gr.Column(scale=4, elem_id="box_column"):
|
|
|
|
| 348 |
gallery.select(
|
| 349 |
fn=update_selection,
|
| 350 |
inputs=[gr_flux_loras],
|
| 351 |
+
outputs=[prompt_title, prompt, selected_state, lora_scale],
|
| 352 |
show_progress=False
|
| 353 |
)
|
| 354 |
|
| 355 |
gr.on(
|
| 356 |
triggers=[run_button.click, prompt.submit],
|
| 357 |
fn=infer_with_lora_wrapper,
|
| 358 |
+
inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
|
| 359 |
+
outputs=[result, seed, reuse_button, lora_state]
|
| 360 |
)
|
| 361 |
|
| 362 |
reuse_button.click(
|