Commit
·
20c6eca
1
Parent(s):
a7c9e18
Refactor progress tracking in generate_image function
Browse files- tabs/images/events.py +4 -2
- tabs/images/handlers.py +6 -3
tabs/images/events.py
CHANGED
|
@@ -413,9 +413,11 @@ def generate_image(
|
|
| 413 |
resize_mode,
|
| 414 |
scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
|
| 415 |
image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
|
| 416 |
-
refiner, vae
|
|
|
|
| 417 |
):
|
| 418 |
try:
|
|
|
|
| 419 |
base_args = {
|
| 420 |
"model": model,
|
| 421 |
"prompt": prompt,
|
|
@@ -507,7 +509,7 @@ def generate_image(
|
|
| 507 |
base_args = BaseReq(**base_args.__dict__)
|
| 508 |
|
| 509 |
return gr.update(
|
| 510 |
-
value=gen_img(base_args),
|
| 511 |
interactive=True
|
| 512 |
)
|
| 513 |
except Exception as e:
|
|
|
|
| 413 |
resize_mode,
|
| 414 |
scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
|
| 415 |
image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
|
| 416 |
+
refiner, vae,
|
| 417 |
+
progress=gr.Progress(track_tqdm=True)
|
| 418 |
):
|
| 419 |
try:
|
| 420 |
+
progress(0, "Configuring arguments...")
|
| 421 |
base_args = {
|
| 422 |
"model": model,
|
| 423 |
"prompt": prompt,
|
|
|
|
| 509 |
base_args = BaseReq(**base_args.__dict__)
|
| 510 |
|
| 511 |
return gr.update(
|
| 512 |
+
value=gen_img(base_args, progress),
|
| 513 |
interactive=True
|
| 514 |
)
|
| 515 |
except Exception as e:
|
tabs/images/handlers.py
CHANGED
|
@@ -205,10 +205,12 @@ def cleanup(pipeline, loras = None, embeddings = None):
|
|
| 205 |
|
| 206 |
|
| 207 |
# Gen Function
|
| 208 |
-
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
|
|
|
|
| 209 |
pipeline_args = get_pipe(request)
|
| 210 |
pipeline = pipeline_args["pipeline"]
|
| 211 |
try:
|
|
|
|
| 212 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
| 213 |
|
| 214 |
# Common Args
|
|
@@ -243,15 +245,16 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
|
|
| 243 |
args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
|
| 244 |
|
| 245 |
# Generate
|
|
|
|
| 246 |
images = pipeline(**args).images
|
| 247 |
|
| 248 |
# Refiner
|
| 249 |
if request.refiner:
|
| 250 |
images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
|
| 251 |
|
|
|
|
|
|
|
| 252 |
return images
|
| 253 |
except Exception as e:
|
| 254 |
cleanup(pipeline, request.loras, request.embeddings)
|
| 255 |
raise gr.Error(f"Error: {e}")
|
| 256 |
-
finally:
|
| 257 |
-
cleanup(pipeline, request.loras, request.embeddings)
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
# Gen Function
|
| 208 |
+
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
|
| 209 |
+
progress(0.1, "Loading Pipeline")
|
| 210 |
pipeline_args = get_pipe(request)
|
| 211 |
pipeline = pipeline_args["pipeline"]
|
| 212 |
try:
|
| 213 |
+
progress(0.5, "Configuring Pipeline")
|
| 214 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
| 215 |
|
| 216 |
# Common Args
|
|
|
|
| 245 |
args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
|
| 246 |
|
| 247 |
# Generate
|
| 248 |
+
progress(0.9, "Generating Images")
|
| 249 |
images = pipeline(**args).images
|
| 250 |
|
| 251 |
# Refiner
|
| 252 |
if request.refiner:
|
| 253 |
images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
|
| 254 |
|
| 255 |
+
progress(1.0, "Cleaning Up")
|
| 256 |
+
cleanup(pipeline, request.loras, request.embeddings)
|
| 257 |
return images
|
| 258 |
except Exception as e:
|
| 259 |
cleanup(pipeline, request.loras, request.embeddings)
|
| 260 |
raise gr.Error(f"Error: {e}")
|
|
|
|
|
|