Kontext is 17x slower per iteration than the normal dev model in diffusers with an input image

#25
by tao07 - opened

Trying out the new model on my Mac M1 in a small Python script using the diffusers library. It looks like an easy drop-in replacement for the normal/ControlNet dev model. I'm using the q2_k model by calcius with the hyper8 LoRA. Each iteration takes around 1700 seconds (nearly 30 minutes), compared to the q4 normal dev model with the hyper8 LoRA, which takes only around 100 seconds per iteration. If I don't pass any image to the pipeline, it runs at the same speed as the normal model. Here is the code for the Kontext pipeline:

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig, FluxKontextPipeline
from transformers import T5EncoderModel, CLIPTextModel
from PIL import Image, PngImagePlugin
import torch
from diffusers.utils import load_image
import random
import json
import gc
import os
import time

img_parameters = {
    "steps": 8,
    "guidance_scale": 2.5,
    "seed": None,
    "clip": None,
    "t5": "Add a hat to the cat",
}
seed = img_parameters.get('seed') or random.randint(1, 2147483647)
img_parameters['seed'] = seed

input_image = load_image("/Users/madstsk/Downloads/cat.png")

use_prev_text_encoder = True
save_text_encoder_for_reuse = False

# 0.0 lifts the per-process MPS memory limit
torch.mps.set_per_process_memory_fraction(0.0)

def flush():
    gc.collect()
    torch.mps.empty_cache()
    gc.collect()
    torch.mps.empty_cache()

main_folder = "/Users/madstsk/Python_Img/FluxDev_Kontext-Model"

if use_prev_text_encoder:
    # Reuse prompt embeddings saved on a previous run
    pooled_prompt_embeds = torch.load('pooled_prompt_embeds_kontext.pt')
    prompt_embeds = torch.load('prompt_embeds_kontext.pt')
    text_ids = torch.load('text_ids_kontext.pt')
else:
    print('Loading text encoders')
    start_time = time.time()
    text_encoder = CLIPTextModel.from_pretrained(
        '/Users/madstsk/Python_Img/Clip_l',
        torch_dtype=torch.bfloat16,
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        '/Users/madstsk/Python_Img/T5-XXL',
        torch_dtype=torch.bfloat16,
    )
    # Text-encoder-only pipeline: transformer and VAE are left out to save RAM
    pipeline = FluxKontextPipeline.from_pretrained(
        main_folder,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        transformer=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to("mps")
    pipeline.enable_attention_slicing()

    print('Encoding prompt')
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=img_parameters.get('clip') or img_parameters.get('t5'),
            prompt_2=img_parameters.get('t5'),
            max_sequence_length=256,
        )

    # Saving the tensors to disk
    if save_text_encoder_for_reuse:
        torch.save(pooled_prompt_embeds, 'pooled_prompt_embeds_kontext.pt')
        torch.save(prompt_embeds, 'prompt_embeds_kontext.pt')
        torch.save(text_ids, 'text_ids_kontext.pt')

    del pipeline
    del text_encoder
    del text_encoder_2
    flush()
    print(f"Text encoding took {time.time() - start_time:.4f} seconds.")

print('Load model')
ckpt_path = "/Volumes/T7/ML/ComfyUI/models/unet/flux1-kontext-dev-f32-q2_k.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = FluxKontextPipeline.from_pretrained(
    main_folder,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("mps")
pipeline.enable_attention_slicing()

pipeline.load_lora_weights('/Volumes/T7/ML/ComfyUI/models/loras/Hyper-FLUX.1-dev-8steps-lora.safetensors', adapter_name="hyper8")
pipeline.set_adapters(["hyper8"], adapter_weights=[0.125])
# pipeline.load_lora_weights('/Volumes/T7/ML/ComfyUI/models/loras/amateurphoto-v6-forcu.safetensors')
# pipeline.load_lora_weights('/Volumes/T7/ML/ComfyUI/models/loras/boreal-v2.safetensors')

# Rendering
print("Running denoising.")
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=img_parameters.get('steps'),
    guidance_scale=img_parameters.get('guidance_scale'),
    generator=torch.Generator("mps").manual_seed(seed),
    image=input_image,  # adding this makes each iteration ~17x slower
).images[0]

# Save with an auto-incrementing filename and the parameters embedded as PNG metadata
outputpath = './output'
prefix = 'flux_'
existing_files = [f for f in os.listdir(outputpath) if f.startswith(prefix) and f.endswith('.png')]
existing_numbers = []
for f in existing_files:
    number_part = f[len(prefix):-len('.png')]
    if number_part.isdigit():
        existing_numbers.append(int(number_part))
next_number = max(existing_numbers, default=0) + 1
filename = f"{prefix}{next_number:04d}.png"
filepath = os.path.join(outputpath, filename)

meta = PngImagePlugin.PngInfo()
meta.add_text("Source", json.dumps(img_parameters))
images.save(filepath, pnginfo=meta)

The code is basically a drop-in replacement of my normal script, with FluxPipeline swapped for FluxKontextPipeline.
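
For reference, this is roughly the call I compare against for the baseline timing: the same pipeline and embeddings, just without the image= argument, and it runs at normal speed.

# Baseline run for comparison: identical settings, but no reference image.
# Without `image=`, each iteration takes ~100 s, same as the regular dev model.
baseline_image = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=img_parameters.get('steps'),
    guidance_scale=img_parameters.get('guidance_scale'),
    generator=torch.Generator("mps").manual_seed(seed),
    # image=input_image  # leaving this out restores normal per-step speed
).images[0]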

I know my Mac only has 8 GB of total RAM and it slows down a lot due to swapping, but that isn't a big issue with the normal model, even with higher-quality quantized models. With Macmon, I have observed that when the image is added, physical memory usage constantly climbs and then drops. I know adding an image as an input requires more RAM, but q2 + image uses around 9-10 GB for the Python process, which is about the same as the normal model at q4.
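
To narrow down where the time and memory go, I could log per-step duration and MPS allocation with a step callback. This is only a rough sketch: it assumes FluxKontextPipeline already exposes the usual diffusers callback_on_step_end hook, and the numbers are what torch reports, not what Macmon shows.

import time
import torch

_last = {"t": time.time()}

def log_step(pipe, step, timestep, callback_kwargs):
    # Print per-step duration and current MPS allocation to spot swapping
    now = time.time()
    alloc_gb = torch.mps.current_allocated_memory() / 1024**3
    print(f"step {step}: {now - _last['t']:.1f} s, MPS allocated {alloc_gb:.2f} GB")
    _last["t"] = now
    return callback_kwargs

images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=img_parameters.get('steps'),
    guidance_scale=img_parameters.get('guidance_scale'),
    generator=torch.Generator("mps").manual_seed(seed),
    image=input_image,
    callback_on_step_end=log_step,  # assumed to be supported by FluxKontextPipeline
).images[0]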

Is there something fundamentally different about the new Kontext model, or is it just that it was recently released and the diffusers library doesn't fully support it yet?
