SemaSci commited on
Commit
ea95c39
·
verified ·
1 Parent(s): f895199

Update app.py

Browse files

Добавлена возможность использовать LoRA и модель по умолчанию заменена на CompVis/stable-diffusion-v1-5

Files changed (1) hide show
  1. app.py +125 -16
app.py CHANGED
@@ -6,10 +6,63 @@ import random
6
  from diffusers import DiffusionPipeline
7
  import torch
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  #model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
- model_repo_id = "CompVis/stable-diffusion-v1-4"
12
- model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4' ]
 
 
 
13
 
14
  if torch.cuda.is_available():
15
  torch_dtype = torch.float16
@@ -30,38 +83,75 @@ def infer(
30
  randomize_seed,
31
  width,
32
  height,
33
- model_repo_id=model_repo_id,
34
  seed=42,
35
  guidance_scale=7,
36
  num_inference_steps=20,
 
 
37
  progress=gr.Progress(track_tqdm=True),
38
- ):
 
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
41
 
42
  generator = torch.Generator().manual_seed(seed)
43
 
44
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
45
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- image = pipe(
48
- prompt=prompt,
49
- negative_prompt=negative_prompt,
50
- guidance_scale=guidance_scale,
51
- num_inference_steps=num_inference_steps,
52
- width=width,
53
- height=height,
54
- generator=generator,
55
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  return image, seed
58
 
59
 
60
  examples = [
61
- "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
62
  "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
63
  "A cat is playing a song called ""About the Cat"" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.",
64
  "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
 
65
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
66
  "An astronaut riding a green horse",
67
  "A delicious ceviche cheesecake slice",
@@ -158,6 +248,23 @@ with gr.Blocks(css=css) as demo:
158
  step=1,
159
  value=20, # Replace with defaults that work for your model
160
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  gr.Examples(examples=examples, inputs=[prompt])
163
  gr.on(
@@ -173,6 +280,8 @@ with gr.Blocks(css=css) as demo:
173
  seed,
174
  guidance_scale,
175
  num_inference_steps,
 
 
176
  ],
177
  outputs=[result, seed],
178
  )
 
6
  from diffusers import DiffusionPipeline
7
  import torch
8
 
9
+ from peft import PeftModel, LoraConfig
10
+ import os
11
+
12
+ def get_lora_sd_pipeline(
13
+ ckpt_dir='./lora_logos',
14
+ base_model_name_or_path=None,
15
+ dtype=torch.float16,
16
+ adapter_name="default"
17
+ ):
18
+
19
+ unet_sub_dir = os.path.join(ckpt_dir, "unet")
20
+ text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
21
+
22
+ if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
23
+ config = LoraConfig.from_pretrained(text_encoder_sub_dir)
24
+ base_model_name_or_path = config.base_model_name_or_path
25
+
26
+ if base_model_name_or_path is None:
27
+ raise ValueError("Please specify the base model name or path")
28
+
29
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
30
+ before_params = pipe.unet.parameters()
31
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
32
+ pipe.unet.set_adapter(adapter_name)
33
+ after_params = pipe.unet.parameters()
34
+ print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
35
+
36
+ if os.path.exists(text_encoder_sub_dir):
37
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
38
+
39
+ if dtype in (torch.float16, torch.bfloat16):
40
+ pipe.unet.half()
41
+ pipe.text_encoder.half()
42
+
43
+ return pipe
44
+
45
+ def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
46
+ tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
47
+ chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
48
+
49
+ with torch.no_grad():
50
+ embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
51
+
52
+ return torch.cat(embeds, dim=1)
53
+
54
+ def align_embeddings(prompt_embeds, negative_prompt_embeds):
55
+ max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
56
+ return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
57
+ torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
58
+
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  #model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
61
+ model_id_default = "CompVis/stable-diffusion-v1-5"
62
+ model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'CompVis/stable-diffusion-v1-5' ]
63
+
64
+ model_lora_default = "lora_lady_and_cats_logos"
65
+ model_lora_dropdown = ['lora_lady_and_cats_logos', 'lora_pussinboots_logos' ]
66
 
67
  if torch.cuda.is_available():
68
  torch_dtype = torch.float16
 
83
  randomize_seed,
84
  width,
85
  height,
86
+ model_repo_id=model_id_default,
87
  seed=42,
88
  guidance_scale=7,
89
  num_inference_steps=20,
90
+ model_lora_id=model_lora_default,
91
+ lora_scale=0.5,
92
  progress=gr.Progress(track_tqdm=True),
93
+ ):
94
+
95
  if randomize_seed:
96
  seed = random.randint(0, MAX_SEED)
97
 
98
  generator = torch.Generator().manual_seed(seed)
99
 
100
+ # убираем обновление pipe всегда
101
+ #pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
102
+ #pipe = pipe.to(device)
103
+
104
+ # добавляем обновление pipe по условию
105
+ if model_repo_id != model_id_default:
106
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
107
+ prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
108
+ negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
109
+ prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
110
+ else:
111
+ # добавляем lora
112
+ pipe = get_lora_sd_pipeline(ckpt_dir='./lora_lady_and_cats_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
113
+ #pipe = get_lora_sd_pipeline(ckpt_dir='./'+model_lora_id, base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
114
+ prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
115
+ negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
116
+ prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
117
+ print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
118
+ print(f"LoRA scale applied: {lora_scale}")
119
+ pipe.fuse_lora(lora_scale=lora_scale)
120
 
121
+
122
+ # заменяем просто вызов pipe с промптом
123
+ #image = pipe(
124
+ # prompt=prompt,
125
+ # negative_prompt=negative_prompt,
126
+ # guidance_scale=guidance_scale,
127
+ # num_inference_steps=num_inference_steps,
128
+ # width=width,
129
+ # height=height,
130
+ # generator=generator,
131
+ #).images[0]
132
+
133
+
134
+ # на вызов pipe с эмбеддингами
135
+ params = {
136
+ 'prompt_embeds': prompt_embeds,
137
+ 'negative_prompt_embeds': negative_prompt_embeds,
138
+ 'guidance_scale': guidance_scale,
139
+ 'num_inference_steps': num_inference_steps,
140
+ 'width': width,
141
+ 'height': height,
142
+ 'generator': generator,
143
+ }
144
+
145
+ return pipe(**params).images[0]
146
 
147
  return image, seed
148
 
149
 
150
  examples = [
 
151
  "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
152
  "A cat is playing a song called ""About the Cat"" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.",
153
  "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
154
+ "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
155
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
156
  "An astronaut riding a green horse",
157
  "A delicious ceviche cheesecake slice",
 
248
  step=1,
249
  value=20, # Replace with defaults that work for your model
250
  )
251
+
252
+ model_lora_id = gr.Dropdown(
253
+ label="Lora Id",
254
+ choices=model_dropdown,
255
+ info="Choose LoRA model",
256
+ visible=True,
257
+ allow_custom_value=True,
258
+ value=model_lora_id,
259
+ )
260
+
261
+ lora_scale = gr.Slider(
262
+ label="LoRA scale",
263
+ minimum=0.0,
264
+ maximum=1.0,
265
+ step=0.1,
266
+ value=0.5,
267
+ )
268
 
269
  gr.Examples(examples=examples, inputs=[prompt])
270
  gr.on(
 
280
  seed,
281
  guidance_scale,
282
  num_inference_steps,
283
+ model_lora_id,
284
+ lora_scale,
285
  ],
286
  outputs=[result, seed],
287
  )