multimodalart HF Staff committed on
Commit
d434b56
·
verified ·
1 Parent(s): 4e85690

Add optimizations (#1)

Browse files

- Add optimizations (f999ccda0036b0e1c33263a084fda2742d7f0fee)
- Create optimization.py (28a5a5501c48df1538fbedce90011acbf627b6b0)
- Update requirements.txt (38bb795e28951db7465262866a358315b317c2f1)

Files changed (3) hide show
  1. app.py +14 -1
  2. optimization.py +70 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -8,9 +8,15 @@ import json
8
 
9
  from PIL import Image
10
  from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
 
11
  from huggingface_hub import InferenceClient
12
  import math
13
 
 
 
 
 
 
14
  # --- Prompt Enhancement using Hugging Face InferenceClient ---
15
  def polish_prompt_hf(original_prompt, system_prompt):
16
  """
@@ -159,7 +165,7 @@ scheduler_config = {
159
  scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
160
 
161
  # Load the edit pipeline with Lightning scheduler
162
- pipe = QwenImageEditPipeline.from_pretrained(
163
  "Qwen/Qwen-Image-Edit",
164
  scheduler=scheduler,
165
  torch_dtype=dtype
@@ -177,6 +183,13 @@ except Exception as e:
177
  print(f"Warning: Could not load Lightning LoRA weights: {e}")
178
  print("Continuing with base model...")
179
 
 
 
 
 
 
 
 
180
  # --- UI Constants and Helpers ---
181
  MAX_SEED = np.iinfo(np.int32).max
182
 
 
8
 
9
  from PIL import Image
10
  from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
11
+
12
  from huggingface_hub import InferenceClient
13
  import math
14
 
15
+ from optimization import optimize_pipeline_
16
+ from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline as QwenImageEditPipelineCustom
17
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
18
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
19
+
20
  # --- Prompt Enhancement using Hugging Face InferenceClient ---
21
  def polish_prompt_hf(original_prompt, system_prompt):
22
  """
 
165
  scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
166
 
167
  # Load the edit pipeline with Lightning scheduler
168
+ pipe = QwenImageEditPipelineCustom.from_pretrained(
169
  "Qwen/Qwen-Image-Edit",
170
  scheduler=scheduler,
171
  torch_dtype=dtype
 
183
  print(f"Warning: Could not load Lightning LoRA weights: {e}")
184
  print("Continuing with base model...")
185
 
186
+ # Apply the same optimizations from the first version
187
+ pipe.transformer.__class__ = QwenImageTransformer2DModel
188
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
189
+
190
+ # --- Ahead-of-time compilation ---
191
+ optimize_pipeline_(pipe, image=Image.new("RGB", (1024, 1024)), prompt="prompt")
192
+
193
  # --- UI Constants and Helpers ---
194
  MAX_SEED = np.iinfo(np.int32).max
195
 
optimization.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+ from torchao.quantization import quantize_
8
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
+ import spaces
10
+ import torch
11
+ from torch.utils._pytree import tree_map
12
+
13
+
14
+ P = ParamSpec('P')
15
+
16
+
17
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
+
20
+ TRANSFORMER_DYNAMIC_SHAPES = {
21
+ 'hidden_states': {
22
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
23
+ },
24
+ 'encoder_hidden_states': {
25
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
+ },
27
+ 'encoder_hidden_states_mask': {
28
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
+ },
30
+ 'image_rotary_emb': ({
31
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
+ }, {
33
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
+ }),
35
+ }
36
+
37
+
38
+ INDUCTOR_CONFIGS = {
39
+ 'conv_1x1_as_mm': True,
40
+ 'epilogue_fusion': False,
41
+ 'coordinate_descent_tuning': True,
42
+ 'coordinate_descent_check_all_directions': True,
43
+ 'max_autotune': True,
44
+ 'triton.cudagraphs': True,
45
+ }
46
+
47
+
48
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
49
+
50
+ @spaces.GPU(duration=1500)
51
+ def compile_transformer():
52
+
53
+ with spaces.aoti_capture(pipeline.transformer) as call:
54
+ pipeline(*args, **kwargs)
55
+
56
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
57
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
58
+
59
+ # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
60
+
61
+ exported = torch.export.export(
62
+ mod=pipeline.transformer,
63
+ args=call.args,
64
+ kwargs=call.kwargs,
65
+ dynamic_shapes=dynamic_shapes,
66
+ )
67
+
68
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
69
+
70
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
- git+https://github.com/huggingface/diffusers.git@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
 
 
2
  transformers
3
  accelerate
4
  safetensors
 
1
+ git+https://github.com/huggingface/diffusers.git@qwenimage-lru-cache-bypass
2
+ kernels
3
+ torchao==0.11.0
4
  transformers
5
  accelerate
6
  safetensors