recoilme committed
Commit fcb2271 · 1 Parent(s): 74bfd8a
config.json DELETED
@@ -1,38 +0,0 @@
- {
-   "_class_name": "AutoencoderKL",
-   "_diffusers_version": "0.34.0",
-   "_name_or_path": "sdxl_vae",
-   "act_fn": "silu",
-   "block_out_channels": [
-     128,
-     256,
-     512,
-     512
-   ],
-   "down_block_types": [
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D"
-   ],
-   "force_upcast": false,
-   "in_channels": 3,
-   "latent_channels": 4,
-   "latents_mean": null,
-   "latents_std": null,
-   "layers_per_block": 2,
-   "mid_block_add_attention": true,
-   "norm_num_groups": 32,
-   "out_channels": 3,
-   "sample_size": 512,
-   "scaling_factor": 0.13025,
-   "shift_factor": null,
-   "up_block_types": [
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D"
-   ],
-   "use_post_quant_conv": true,
-   "use_quant_conv": true
- }
diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:03f2412467f6bedce9efeddba5860b5ec0d3267931d14c500d4bd7a878e14cbd
- size 334643268
eval_alchemist.py CHANGED
@@ -1,330 +1,520 @@
  import os
  import torch
  import torch.nn.functional as F
- import lpips
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
  from torch.utils.data import Dataset, DataLoader
- from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, ToPILImage
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL, AutoencoderKLWan, AutoencoderKLLTXVideo
- import random

- # --------------------------- Parameters ---------------------------
  DEVICE = "cuda"
  DTYPE = torch.float16
- IMAGE_FOLDER = "/workspace/alchemist"  # wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
  MIN_SIZE = 1280
  CROP_SIZE = 512
  BATCH_SIZE = 10
  MAX_IMAGES = 0
  NUM_WORKERS = 4
- NUM_SAMPLES_TO_SAVE = 2  # how many sample images to save (0 = do not save)
- SAMPLES_FOLDER = "vaetest"

- # VAEs to benchmark
  VAE_LIST = [
-     # ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
-     # ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None),
-     # ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
-     # ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
-     # ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
-     # ("AiArtLab/sdxlvae_nightly", AutoencoderKL, "AiArtLab/sdxl_vae", "vae_nightly"),
-     # ("Lightricks/LTX-Video", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
-     # ("Wan2.2-TI2V-5B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
-     # ("Wan2.2-T2V-A14B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
-     # ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
-     ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
-     ("simple_vae", AutoencoderKL, "AiArtLab/simplevae", "vae"),
-     ("simple_vae2", AutoencoderKL, "AiArtLab/simplevae", None),
-     ("simple_vae_nightly", AutoencoderKL, "/workspace/sdxl_vae/simple_vae_nightly", None),
-
  ]

- # --------------------------- Sobel Edge Detection ---------------------------
- # Define the Sobel filters globally
- _sobel_kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
- _sobel_ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)

- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
-     """
-     Computes an edge map with the Sobel operator.
-     x: [B,C,H,W] in [-1,1]
-     Returns: [B,C,H,W], the gradient magnitude
-     """
-     C = x.shape[1]
-     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     gx = F.conv2d(x, kx, padding=1, groups=C)
-     gy = F.conv2d(x, ky, padding=1, groups=C)
-     return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
- def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float:
      """
-     Computes the edge loss between the real and generated images.
-     real, fake: [B,C,H,W] in [0,1]
-     Returns: a scalar loss value
      """
-     # Convert to [-1,1] for sobel_edges
-     real_norm = real * 2 - 1
-     fake_norm = fake * 2 - 1
-
-     # Edge maps
-     edges_real = sobel_edges(real_norm)
-     edges_fake = sobel_edges(fake_norm)

-     # L1 loss between the edge maps
-     return F.l1_loss(edges_fake, edges_real).item()

- # --------------------------- Dataset ---------------------------
  class ImageFolderDataset(Dataset):
-     def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
-         self.root_dir = root_dir
-         self.min_size = min_size
-         self.crop_size = crop_size
-         self.paths = []
-
-         print("Scanning folder...")
          for root, _, files in os.walk(root_dir):
              for fname in files:
                  if fname.lower().endswith(extensions):
-                     self.paths.append(os.path.join(root, fname))
-
          if limit:
-             self.paths = self.paths[:limit]
-
-         print("Validating images...")
          valid = []
-         for p in tqdm(self.paths, desc="Validating"):
              try:
                  with Image.open(p) as im:
                      im.verify()
                  valid.append(p)
-             except:
-                 continue
          self.paths = valid
-
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid images found in {root_dir}")
-
-         random.shuffle(self.paths)
          print(f"Found {len(self.paths)} images")
-
          self.transform = Compose([
-             Resize(min_size, interpolation=Image.LANCZOS),
              CenterCrop(crop_size),
-             ToTensor(),
          ])
-
      def __len__(self):
          return len(self.paths)
-
      def __getitem__(self, idx):
-         path = self.paths[idx]
-         with Image.open(path) as img:
              img = img.convert("RGB")
              return self.transform(img)

- # --------------------------- Helpers ---------------------------
- def process(x):
-     return x * 2 - 1

- def deprocess(x):
-     return x * 0.5 + 0.5

- def _sanitize_name(name: str) -> str:
-     return name.replace('/', '_').replace('-', '_')

- # --------------------------- VAE analysis ---------------------------
- @torch.no_grad()
- def tensor_stats(name, x: torch.Tensor):
-     finite = torch.isfinite(x)
-     fin_ratio = finite.float().mean().item()
-     x_f = x[finite]
-     minv = x_f.min().item() if x_f.numel() else float('nan')
-     maxv = x_f.max().item() if x_f.numel() else float('nan')
-     mean = x_f.mean().item() if x_f.numel() else float('nan')
-     std = x_f.std().item() if x_f.numel() else float('nan')
-     big = (x_f.abs() > 20).float().mean().item() if x_f.numel() else float('nan')
-     print(f"[{name}] shape={tuple(x.shape)} dtype={x.dtype} "
-           f"finite={fin_ratio:.6f} min={minv:.4g} max={maxv:.4g} mean={mean:.4g} std={std:.4g} |x|>20={big:.6f}")

- @torch.no_grad()
- def analyze_vae_latents(vae, name, images):
-     """
-     images: [B,3,H,W] in [-1,1]
-     """
-     try:
-         enc = vae.encode(images)
-         if hasattr(enc, "latent_dist"):
-             mu, logvar = enc.latent_dist.mean, enc.latent_dist.logvar
-             z = enc.latent_dist.sample()
-         else:
-             mu, logvar = enc[0], enc[1]
-             z = mu
-         tensor_stats(f"{name}.mu", mu)
-         tensor_stats(f"{name}.logvar", logvar)
-         tensor_stats(f"{name}.z_raw", z)

-         sf = getattr(vae.config, "scaling_factor", 1.0)
-         z_scaled = z * sf
-         tensor_stats(f"{name}.z_scaled(x{sf})", z_scaled)
-     except Exception as e:
-         print(f"⚠️ Failed to analyze VAE {name}: {e}")

- # --------------------------- Main ---------------------------
- if __name__ == "__main__":
-     if NUM_SAMPLES_TO_SAVE > 0:
-         os.makedirs(SAMPLES_FOLDER, exist_ok=True)
-
-     dataset = ImageFolderDataset(
-         IMAGE_FOLDER,
-         extensions=('.png',),
-         min_size=MIN_SIZE,
-         crop_size=CROP_SIZE,
-         limit=MAX_IMAGES
-     )
-
-     dataloader = DataLoader(
-         dataset,
-         batch_size=BATCH_SIZE,
-         shuffle=False,
-         num_workers=NUM_WORKERS,
-         pin_memory=True,
-         drop_last=False
-     )
-
-     lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
-
-     print("\nLoading VAE models...")
-     vaes = []
-     names = []
-
-     for name, vae_class, model_path, subfolder in VAE_LIST:
-         try:
-             print(f"  Loading {name}...")
-             # Fixed loading for the fp16 variant
-             if "sdxs" in model_path:
-                 vae = vae_class.from_pretrained(model_path, subfolder=subfolder, variant="fp16")
-             else:
-                 vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
-             vae = vae.to(DEVICE, DTYPE).eval()
-             vaes.append(vae)
-             names.append(name)
-         except Exception as e:
-             print(f"  ❌ Failed to load {name}: {e}")
-
-     print("\nEvaluating metrics...")
-     results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "count": 0} for name in names}
-
-     to_pil = ToPILImage()
-
-     # >>>>>>>> MAIN CHANGES HERE (KISS) <<<<<<<<
-     with torch.no_grad():
-         images_saved = 0  # counts IMAGES, not saved files
-         for batch in tqdm(dataloader, desc="Processing batches"):
-             batch = batch.to(DEVICE)  # [B,3,H,W] in [0,1]
-             test_inp = process(batch).to(DTYPE)  # [-1,1] for the encoder
-             # >>> Analyze each VAE's latents on the first iteration
-             if images_saved == 0:  # first batch only, to keep the log clean
-                 for vae, name in zip(vaes, names):
-                     analyze_vae_latents(vae, name, test_inp)
-
-             # 1) compute reconstructions for every VAE on the whole batch
-             recon_list = []
-             for vae, name in zip(vaes, names):
-                 test_inp_vae = test_inp  # local copy
-                 # if name == "Wan2.2-T2V-A14B-Diffusers" and test_inp_vae.ndim == 4:
-                 if (isinstance(vae, AutoencoderKLWan) or isinstance(vae, AutoencoderKLLTXVideo)) and test_inp_vae.ndim == 4:
-                     test_inp_vae = test_inp_vae.unsqueeze(2)  # Wan only
-                 latent = vae.encode(test_inp_vae).latent_dist.mode()
-                 dec = vae.decode(latent).sample.float()
-                 if dec.ndim == 5:
-                     dec = dec.squeeze(2)
-                 recon = deprocess(dec).clamp(0.0, 1.0)
-                 recon_list.append(recon)
-
-             # 2) update the metrics (per VAE)
-             for recon, name in zip(recon_list, names):
-                 for i in range(batch.shape[0]):
-                     img_orig = batch[i:i+1]
-                     img_recon = recon[i:i+1]
-                     mse = F.mse_loss(img_orig, img_recon).item()
-                     psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
-                     lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
-                     edge_loss = compute_edge_loss(img_orig, img_recon)
-                     results[name]["mse"] += mse
-                     results[name]["psnr"] += psnr
-                     results[name]["lpips"] += lpips_val
-                     results[name]["edge"] += edge_loss
-                     results[name]["count"] += 1
-
-             # 3) save exactly NUM_SAMPLES_TO_SAVE images (orig + every VAE + a combined collage)
-             if NUM_SAMPLES_TO_SAVE > 0:
-                 for i in range(batch.shape[0]):
-                     if images_saved >= NUM_SAMPLES_TO_SAVE:
-                         break
-                     idx_str = f"{images_saved + 1:03d}"
-
-                     # original
-                     orig_pil = to_pil(batch[i].detach().float().cpu())
-                     orig_pil.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_orig.png"))
-
-                     # per-VAE decodes
-                     tiles = [orig_pil]
-                     for recon, name in zip(recon_list, names):
-                         recon_pil = to_pil(recon[i].detach().cpu())
-                         recon_pil.save(os.path.join(
-                             SAMPLES_FOLDER, f"{idx_str}_decoded_{_sanitize_name(name)}.png"
-                         ))
-                         tiles.append(recon_pil)
-
-                     # combined collage: [orig | vae1 | vae2 | ...]
-                     collage_w = CROP_SIZE * len(tiles)
-                     collage_h = CROP_SIZE
-                     collage = Image.new("RGB", (collage_w, collage_h))
-                     x = 0
-                     for tile in tiles:
-                         collage.paste(tile, (x, 0))
-                         x += CROP_SIZE
-                     collage.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_all.png"))
-
-                     images_saved += 1
-
-     # Average the results
-     for name in names:
-         count = results[name]["count"]
-         results[name]["mse"] /= count
-         results[name]["psnr"] /= count
-         results[name]["lpips"] /= count
-         results[name]["edge"] /= count
-
-     # Print the absolute values
-     print("\n=== Absolute values ===")
-     for name in names:
-         print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, "
-               f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}")
-
-     # Print the percentage table
-     print("\n=== Comparison with the first model (%) ===")
-     print(f"| {'Model':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |")
-     print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|")
-
-     baseline = names[0]
-     for name in names:
-         # MSE, LPIPS and Edge: lower is better, so invert the ratio
-         mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
-         # PSNR: higher is better
-         psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
-         # LPIPS and Edge: lower is better
-         lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
-         edge_pct = (results[baseline]["edge"] / results[name]["edge"]) * 100
-
          if name == baseline:
-             print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
          else:
-             print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | "
-                   f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |")
-
-     print("\n✅ Done!")

  import os
+ import json
+ import random
+ from typing import Dict, List, Tuple, Optional, Any
+
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
  import torch
  import torch.nn.functional as F
  from torch.utils.data import Dataset, DataLoader
+ from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
+ from torchvision.utils import save_image
+ import lpips

+ from diffusers import (
+     AutoencoderKL,
+     AutoencoderKLWan,
+     AutoencoderKLLTXVideo,
+     AutoencoderKLQwenImage
+ )
+
+ from scipy.stats import skew, kurtosis
+
+
+ # ========================== Config ==========================
  DEVICE = "cuda"
  DTYPE = torch.float16
+ IMAGE_FOLDER = "/home/recoilme/dataset/alchemist"
  MIN_SIZE = 1280
  CROP_SIZE = 512
  BATCH_SIZE = 10
  MAX_IMAGES = 0
  NUM_WORKERS = 4
+ SAMPLES_DIR = "vaetest"

  VAE_LIST = [
+     # ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
+     ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
+     # ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
+     # ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
+     # ("SimpleVAE1", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly"),
+     # ("SimpleVAE2", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly2"),
+     # ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+     # ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
+     # ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
+     # ("wan16x_vae_nightly", AutoencoderKLWan, "AiArtLab/simplevae", "wan16x_vae_nightly"),
+     # ("wan16x_vae_nightly2", AutoencoderKLWan, "AiArtLab/simplevae", "wan16x_vae_nightly2"),
+     # ("SimpleVAE ", AutoencoderKL, "AiArtLab/simplevae", None),
+     # ("AuraDiffusion/16ch-vae", AutoencoderKL, "AuraDiffusion/16ch-vae", None),
+     # ("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
+     # ("SimpleVAE nightly2", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly2"),
+     ("aiartlab/SDXLVAE", AutoencoderKL, "/home/recoilme/vae", "sdxlvae"),
  ]


+ # ========================== Utilities ==========================
+ def to_neg1_1(x: torch.Tensor) -> torch.Tensor:
+     return x * 2 - 1
+
+
+ def to_0_1(x: torch.Tensor) -> torch.Tensor:
+     return (x + 1) * 0.5
+
+
+ def safe_psnr(mse: float) -> float:
+     if mse <= 1e-12:
+         return float("inf")
+     return 10.0 * float(np.log10(1.0 / mse))
+
+
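+ # For images in [0, 1], PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1,
+ # so e.g. MSE = 1e-3 gives 30 dB; the 1e-12 guard avoids log10(0) on a
+ # perfect reconstruction.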
+ def is_video_like_vae(vae) -> bool:
+     # Wan and LTX-Video expect [B, C, T, H, W]
+     return isinstance(vae, (AutoencoderKLWan, AutoencoderKLLTXVideo, AutoencoderKLQwenImage))
+
+
+ def add_time_dim_if_needed(x: torch.Tensor, vae) -> torch.Tensor:
+     if is_video_like_vae(vae) and x.ndim == 4:
+         return x.unsqueeze(2)  # -> [B, C, 1, H, W]
+     return x
+
+
+ def strip_time_dim_if_possible(x: torch.Tensor, vae) -> torch.Tensor:
+     if is_video_like_vae(vae) and x.ndim == 5 and x.shape[2] == 1:
+         return x.squeeze(2)  # -> [B, C, H, W]
+     return x
+
+
+ @torch.no_grad()
+ def sobel_edge_l1(real_0_1: torch.Tensor, fake_0_1: torch.Tensor) -> float:
+     real = to_neg1_1(real_0_1)
+     fake = to_neg1_1(fake_0_1)
+     kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
+     ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
+     C = real.shape[1]
+     kx = kx.to(real.dtype).repeat(C, 1, 1, 1)
+     ky = ky.to(real.dtype).repeat(C, 1, 1, 1)
+
+     def grad_mag(x):
+         gx = F.conv2d(x, kx, padding=1, groups=C)
+         gy = F.conv2d(x, ky, padding=1, groups=C)
+         return torch.sqrt(gx * gx + gy * gy + 1e-12)
+
+     return F.l1_loss(grad_mag(fake), grad_mag(real)).item()
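+
+ # L1 distance between Sobel gradient-magnitude maps; lower values mean the
+ # reconstruction preserves edges (high-frequency detail) better.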
+
+
+ def flatten_channels(x: torch.Tensor) -> torch.Tensor:
+     # -> [C, N*H*W] or [C, N*T*H*W]
+     if x.ndim == 4:
+         return x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)
+     elif x.ndim == 5:
+         return x.permute(1, 0, 2, 3, 4).reshape(x.shape[1], -1)
+     else:
+         raise ValueError(f"Unexpected tensor ndim={x.ndim}")
+
+
+ def _to_numpy_1d(x: Any) -> Optional[np.ndarray]:
+     if x is None:
+         return None
+     if isinstance(x, (int, float)):
+         return None
+     if isinstance(x, torch.Tensor):
+         x = x.detach().cpu().float().numpy()
+     elif isinstance(x, (list, tuple)):
+         x = np.array(x, dtype=np.float32)
+     elif isinstance(x, np.ndarray):
+         x = x.astype(np.float32, copy=False)
+     else:
+         return None
+     x = x.reshape(-1)
+     return x
+
+
+ def _to_float(x: Any) -> Optional[float]:
+     if x is None:
+         return None
+     if isinstance(x, (int, float)):
+         return float(x)
+     if isinstance(x, np.ndarray) and x.size == 1:
+         return float(x.item())
+     if isinstance(x, torch.Tensor) and x.numel() == 1:
+         return float(x.item())
+     return None
+
+
+ def get_norm_tensors_and_summary(vae, latent_like: torch.Tensor):
      """
+     Latent normalization: global and per-channel.
+     Application order: the global (scalar) part first, then the per-channel (vector) part.
+     If the config provides several keys, they are accumulated.
      """
+     cfg = getattr(vae, "config", vae)
+
+     scale_keys = [
+         "latents_std"
+     ]
+     shift_keys = [
+         "latents_mean"
+     ]
+
+     C = latent_like.shape[1]
+     nd = latent_like.ndim  # 4 or 5
+     dev = latent_like.device
+     dt = latent_like.dtype
+
+     scale_global = getattr(vae.config, "scaling_factor", 1.0)
+     shift_global = getattr(vae.config, "shift_factor", 0.0)
+     if scale_global is None:
+         scale_global = 1.0
+     if shift_global is None:
+         shift_global = 0.0

+     scale_channel = np.ones(C, dtype=np.float32)
+     shift_channel = np.zeros(C, dtype=np.float32)

+     for k in scale_keys:
+         v = getattr(cfg, k, None)
+         if v is None:
+             continue
+         vec = _to_numpy_1d(v)
+         if vec is not None and vec.size == C:
+             scale_channel *= vec
+         else:
+             s = _to_float(v)
+             if s is not None:
+                 scale_global *= s
+
+     for k in shift_keys:
+         v = getattr(cfg, k, None)
+         if v is None:
+             continue
+         vec = _to_numpy_1d(v)
+         if vec is not None and vec.size == C:
+             shift_channel += vec
+         else:
+             s = _to_float(v)
+             if s is not None:
+                 shift_global += s
+
+     g_shape = [1] * nd
+     c_shape = [1] * nd
+     c_shape[1] = C
+
+     t_scale_g = torch.tensor(scale_global, dtype=dt, device=dev).view(*g_shape)
+     t_shift_g = torch.tensor(shift_global, dtype=dt, device=dev).view(*g_shape)
+     t_scale_c = torch.from_numpy(scale_channel).to(device=dev, dtype=dt).view(*c_shape)
+     t_shift_c = torch.from_numpy(shift_channel).to(device=dev, dtype=dt).view(*c_shape)
+
+     summary = {
+         "scale_global": float(scale_global),
+         "shift_global": float(shift_global),
+         "scale_channel_min": float(scale_channel.min()),
+         "scale_channel_mean": float(scale_channel.mean()),
+         "scale_channel_max": float(scale_channel.max()),
+         "shift_channel_min": float(shift_channel.min()),
+         "shift_channel_mean": float(shift_channel.mean()),
+         "shift_channel_max": float(shift_channel.max()),
+     }
+     return t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary
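+
+ # The tensors returned above are applied in main() as
+ #     z_model = ((z_raw - shift_g) * scale_g - shift_c) * scale_c
+ # i.e. the global scaling_factor/shift_factor first, then any per-channel
+ # latents_mean/latents_std vectors found in the model config.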
+
+
+ @torch.no_grad()
+ def kl_divergence_per_image(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+     kl_map = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # [B, ...]
+     return kl_map.float().view(kl_map.shape[0], -1).mean(dim=1)  # [B]
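+
+ # Closed-form KL( N(mu, sigma^2) || N(0, 1) ) per element:
+ # 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1), averaged over all latent
+ # elements of each image so every sample yields one scalar.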
+
+
+ def sanitize_filename(name: str) -> str:
+     name = name.replace("/", "_").replace("\\", "_").replace(" ", "_")
+     return "".join(ch if (ch.isalnum() or ch in "._-") else "_" for ch in name)
+
+
+ # ========================== Dataset ==========================
  class ImageFolderDataset(Dataset):
+     def __init__(self, root_dir: str, extensions=(".png", ".jpg", ".jpeg", ".webp"), min_size=1024, crop_size=512, limit=None):
+         paths = []
          for root, _, files in os.walk(root_dir):
              for fname in files:
                  if fname.lower().endswith(extensions):
+                     paths.append(os.path.join(root, fname))
          if limit:
+             paths = paths[:limit]
+
          valid = []
+         for p in tqdm(paths, desc="Checking files"):
              try:
                  with Image.open(p) as im:
                      im.verify()
                  valid.append(p)
+             except Exception:
+                 pass
+         if not valid:
+             raise RuntimeError(f"No valid images in {root_dir}")
+         random.shuffle(valid)
          self.paths = valid
          print(f"Found {len(self.paths)} images")
+
          self.transform = Compose([
+             Resize(min_size),
              CenterCrop(crop_size),
+             ToTensor(),  # 0..1, float32
          ])
+
      def __len__(self):
          return len(self.paths)
+
      def __getitem__(self, idx):
+         with Image.open(self.paths[idx]) as img:
              img = img.convert("RGB")
              return self.transform(img)


+ # ========================== Main ==========================
+ def main():
+     torch.set_grad_enabled(False)
+     os.makedirs(SAMPLES_DIR, exist_ok=True)

+     dataset = ImageFolderDataset(IMAGE_FOLDER, min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES)
+     loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

+     lpips_net = lpips.LPIPS(net="vgg").to(DEVICE).eval()

+     # Load the VAEs
+     vaes: List[Tuple[str, object]] = []
+     print("\nLoading VAEs...")
+     for human_name, vae_class, model_path, subfolder in VAE_LIST:
+         try:
+             vae = vae_class.from_pretrained(model_path, subfolder=subfolder, torch_dtype=DTYPE)
+             vae = vae.to(DEVICE).eval()
+             vaes.append((human_name, vae))
+             print(f"  ✅ {human_name}")
+         except Exception as e:
+             print(f"  ❌ {human_name}: {e}")

+     if not vaes:
+         print("No VAEs loaded successfully. Exiting.")
+         return

+     # Accumulators
+     per_model_metrics: Dict[str, Dict[str, float]] = {
+         name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "kl": 0.0, "count": 0.0}
+         for name, _ in vaes
+     }
+
+     buffers_zmodel: Dict[str, List[torch.Tensor]] = {name: [] for name, _ in vaes}
+     norm_summaries: Dict[str, Dict[str, float]] = {}
+
+     # Flag for saving the first image per model
+     saved_first_for: Dict[str, bool] = {name: False for name, _ in vaes}
+
+     for batch_0_1 in tqdm(loader, desc="Batches"):
+         batch_0_1 = batch_0_1.to(DEVICE, torch.float32)
+         batch_neg1_1 = to_neg1_1(batch_0_1).to(DTYPE)
+
+         for model_name, vae in vaes:
+             x_in = add_time_dim_if_needed(batch_neg1_1, vae)
+
+             posterior = vae.encode(x_in).latent_dist
+             mu, logvar = posterior.mean, posterior.logvar
+
+             # Reconstruction (deterministic)
+             z_raw_mode = posterior.mode()
+             x_dec = vae.decode(z_raw_mode).sample  # [-1, 1]
+             x_dec = strip_time_dim_if_possible(x_dec, vae)
+             x_rec_0_1 = to_0_1(x_dec.float()).clamp(0, 1)
+
+             # Latents as the UNet would see them: global -> channelwise
+             z_raw_sample = posterior.sample()
+             t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary = get_norm_tensors_and_summary(vae, z_raw_sample)
+
+             if model_name not in norm_summaries:
+                 norm_summaries[model_name] = summary
+
+             z_tmp = (z_raw_sample - t_shift_g) * t_scale_g
+             z_model = (z_tmp - t_shift_c) * t_scale_c
+             z_model = strip_time_dim_if_possible(z_model, vae)
+
+             buffers_zmodel[model_name].append(z_model.detach().to("cpu", torch.float32))
+
+             # Save the first image (original and reconstruction) for each VAE
+             if not saved_first_for[model_name]:
+                 safe = sanitize_filename(model_name)
+                 orig_path = os.path.join(SAMPLES_DIR, f"{safe}_original.png")
+                 dec_path = os.path.join(SAMPLES_DIR, f"{safe}_decoded.png")
+                 save_image(batch_0_1[0:1].cpu(), orig_path)
+                 save_image(x_rec_0_1[0:1].cpu(), dec_path)
+                 saved_first_for[model_name] = True
+
+             # Per-image metrics
+             B = batch_0_1.shape[0]
+             for i in range(B):
+                 gt = batch_0_1[i:i+1]
+                 rec = x_rec_0_1[i:i+1]
+
+                 mse = F.mse_loss(gt, rec).item()
+                 psnr = safe_psnr(mse)
+                 lp = float(lpips_net(gt, rec, normalize=True).mean().item())
+                 edge = sobel_edge_l1(gt, rec)
+
+                 per_model_metrics[model_name]["mse"] += mse
+                 per_model_metrics[model_name]["psnr"] += psnr
+                 per_model_metrics[model_name]["lpips"] += lp
+                 per_model_metrics[model_name]["edge"] += edge
+
+             # KL per image
+             kl_pi = kl_divergence_per_image(mu, logvar)  # [B]
+             per_model_metrics[model_name]["kl"] += float(kl_pi.sum().item())
+             per_model_metrics[model_name]["count"] += B
+
+     # Average the metrics
+     for name in per_model_metrics:
+         c = max(1.0, per_model_metrics[name]["count"])
+         for k in ["mse", "psnr", "lpips", "edge", "kl"]:
+             per_model_metrics[name][k] /= c
+
+     # Latent statistics and normality measures
+     per_model_latent_stats = {}
+     for name, _ in vaes:
+         if not buffers_zmodel[name]:
+             continue
+         Z = torch.cat(buffers_zmodel[name], dim=0)  # [N, C, H, W]
+
+         # Global
+         z_min = float(Z.min().item())
+         z_mean = float(Z.mean().item())
+         z_max = float(Z.max().item())
+         z_std = float(Z.std(unbiased=True).item())
+
+         # Per channel: skew/kurtosis
+         Z_ch = flatten_channels(Z).numpy()  # [C, *]
+         C = Z_ch.shape[0]
+         sk = np.zeros(C, dtype=np.float64)
+         ku = np.zeros(C, dtype=np.float64)
+         for c in range(C):
+             v = Z_ch[c]
+             sk[c] = float(skew(v, bias=False))
+             ku[c] = float(kurtosis(v, fisher=True, bias=False))
+
+         skew_min, skew_mean, skew_max = float(sk.min()), float(sk.mean()), float(sk.max())
+         kurt_min, kurt_mean, kurt_max = float(ku.min()), float(ku.mean()), float(ku.max())
+         mean_abs_skew = float(np.mean(np.abs(sk)))
+         mean_abs_kurt = float(np.mean(np.abs(ku)))
+
+         per_model_latent_stats[name] = {
+             "Z_min": z_min, "Z_mean": z_mean, "Z_max": z_max, "Z_std": z_std,
+             "skew_min": skew_min, "skew_mean": skew_mean, "skew_max": skew_max,
+             "kurt_min": kurt_min, "kurt_mean": kurt_mean, "kurt_max": kurt_max,
+             "mean_abs_skew": mean_abs_skew, "mean_abs_kurt": mean_abs_kurt,
+         }
+
+     # Print the normalization parameters (shift/scale)
+     print("\n=== Latent normalization parameters (as applied) ===")
+     for name, _ in vaes:
+         if name not in norm_summaries:
+             continue
+         s = norm_summaries[name]
+         print(
+             f"{name:26s} | "
+             f"shift_g={s['shift_global']:.6g} scale_g={s['scale_global']:.6g} | "
+             f"shift_c[min/mean/max]=[{s['shift_channel_min']:.6g}, {s['shift_channel_mean']:.6g}, {s['shift_channel_max']:.6g}] | "
+             f"scale_c[min/mean/max]=[{s['scale_channel_min']:.6g}, {s['scale_channel_mean']:.6g}, {s['scale_channel_max']:.6g}]"
+         )
+
+     # Absolute metrics
+     print("\n=== Absolute reconstruction and latent metrics ===")
+     for name, _ in vaes:
+         if name not in per_model_latent_stats:
+             continue
+         m = per_model_metrics[name]
+         s = per_model_latent_stats[name]
+         print(
+             f"{name:26s} | "
+             f"MSE={m['mse']:.3e} PSNR={m['psnr']:.2f} LPIPS={m['lpips']:.3f} Edge={m['edge']:.3f} KL={m['kl']:.3f} | "
+             f"Z[min/mean/max/std]=[{s['Z_min']:.3f}, {s['Z_mean']:.3f}, {s['Z_max']:.3f}, {s['Z_std']:.3f}] | "
+             f"Skew[min/mean/max]=[{s['skew_min']:.3f}, {s['skew_mean']:.3f}, {s['skew_max']:.3f}] | "
+             f"Kurt[min/mean/max]=[{s['kurt_min']:.3f}, {s['kurt_mean']:.3f}, {s['kurt_max']:.3f}]"
+         )
+
+     # Comparison with the first model
+     baseline = vaes[0][0]
+     print("\n=== Comparison with the first model (percent) ===")
+     print(f"| {'Model':26s} | {'MSE':>9s} | {'PSNR':>9s} | {'LPIPS':>9s} | {'Edge':>9s} | {'Skew|0':>9s} | {'Kurt|0':>9s} |")
+     print(f"|{'-'*28}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|")
+
+     b_m = per_model_metrics[baseline]
+     b_s = per_model_latent_stats[baseline]
+
+     for name, _ in vaes:
+         m = per_model_metrics[name]
+         s = per_model_latent_stats[name]
+
+         mse_pct = (b_m["mse"] / max(1e-12, m["mse"])) * 100.0        # lower is better
+         psnr_pct = (m["psnr"] / max(1e-12, b_m["psnr"])) * 100.0     # higher is better
+         lpips_pct = (b_m["lpips"] / max(1e-12, m["lpips"])) * 100.0  # lower is better
+         edge_pct = (b_m["edge"] / max(1e-12, m["edge"])) * 100.0     # lower is better
+
+         skew0_pct = (b_s["mean_abs_skew"] / max(1e-12, s["mean_abs_skew"])) * 100.0
+         kurt0_pct = (b_s["mean_abs_kurt"] / max(1e-12, s["mean_abs_kurt"])) * 100.0

          if name == baseline:
+             print(f"| {name:26s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} |")
          else:
+             print(f"| {name:26s} | {mse_pct:8.1f}% | {psnr_pct:8.1f}% | {lpips_pct:8.1f}% | {edge_pct:8.1f}% | {skew0_pct:8.1f}% | {kurt0_pct:8.1f}% |")
+
+     # ========================== Corrections for the last VAE + save to JSON ==========================
+     last_name = vaes[-1][0]
+     if buffers_zmodel[last_name]:
+         Z = torch.cat(buffers_zmodel[last_name], dim=0)  # [N, C, H, W]
+
+         # Global correction (over all channels/pixels)
+         z_mean = float(Z.mean().item())
+         z_std = float(Z.std(unbiased=True).item())
+         correction_global = {
+             "shift": -z_mean,
+             "scale": (1.0 / z_std) if z_std > 1e-12 else 1.0
+         }
+
+         # Per-channel correction
+         Z_ch = flatten_channels(Z)  # [C, M]
+         ch_means_t = Z_ch.mean(dim=1)  # [C]
+         ch_stds_t = Z_ch.std(dim=1, unbiased=True) + 1e-12  # [C]
+         ch_means = [float(x) for x in ch_means_t.tolist()]
+         ch_stds = [float(x) for x in ch_stds_t.tolist()]
+
+         correction_per_channel = [
+             {"shift": float(-m), "scale": float(1.0 / s)}
+             for m, s in zip(ch_means, ch_stds)
+         ]
+
+         print(f"\n=== Extra correction for {last_name} (on top of the VAE normalization) ===")
+         print(f"global_correction = {correction_global}")
+         print(f"channelwise_means = {ch_means}")
+         print(f"channelwise_stds = {ch_stds}")
+         print(f"channelwise_correction = {correction_per_channel}")
+
+         # Save to JSON
+         json_path = os.path.join(SAMPLES_DIR, f"{sanitize_filename(last_name)}_correction.json")
+         to_save = {
+             "model_name": last_name,
+             "vae_normalization_summary": norm_summaries.get(last_name, {}),
+             "global_correction": correction_global,
+             "per_channel_means": ch_means,
+             "per_channel_stds": ch_stds,
+             "per_channel_correction": correction_per_channel,
+             "apply_order": {
+                 "forward": "z_model -> (z - global_shift)*global_scale -> (per-channel: (z - mean_c)/std_c)",
+                 "inverse": "z_corr -> (per-channel: z*std_c + mean_c) -> (z/global_scale + global_shift)"
+             },
+             "note": "These coefficients are computed on z_model (after the VAE's built-in shift/scale) to bring the distribution to N(0, 1)."
+         }
+         with open(json_path, "w", encoding="utf-8") as f:
+             json.dump(to_save, f, ensure_ascii=False, indent=2)
+         print("Corrections JSON saved to:", os.path.abspath(json_path))
+
+     print("\n✅ Done. Samples saved to:", os.path.abspath(SAMPLES_DIR))
+
+
+ if __name__ == "__main__":
+     main()
eval_alchemist2.py DELETED
@@ -1,516 +0,0 @@
- import os
- import json
- import random
- from typing import Dict, List, Tuple, Optional, Any
-
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
- import torch
- import torch.nn.functional as F
- from torch.utils.data import Dataset, DataLoader
- from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
- from torchvision.utils import save_image
- import lpips
-
- from diffusers import (
-     AutoencoderKL,
-     AutoencoderKLWan,
-     AutoencoderKLLTXVideo,
-     AutoencoderKLQwenImage
- )
-
- from scipy.stats import skew, kurtosis
-
-
- # ========================== Config ==========================
- DEVICE = "cuda"
- DTYPE = torch.float16
- IMAGE_FOLDER = "/workspace/alchemist"
- MIN_SIZE = 1280
- CROP_SIZE = 512
- BATCH_SIZE = 10
- MAX_IMAGES = 0
- NUM_WORKERS = 4
- SAMPLES_DIR = "vaetest"
-
- VAE_LIST = [
-     # ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
-     # ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
-     # ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
-     # ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
-     # ("SimpleVAE1", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly"),
-     # ("SimpleVAE2", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly2"),
-     # ("SimpleVAE ", AutoencoderKL, "AiArtLab/simplevae", None),
-     # ("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
-     ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
-     ("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
-     # ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
-     # ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
- ]
-
-
- # ========================== Utilities ==========================
- def to_neg1_1(x: torch.Tensor) -> torch.Tensor:
-     return x * 2 - 1
-
-
- def to_0_1(x: torch.Tensor) -> torch.Tensor:
-     return (x + 1) * 0.5
-
-
- def safe_psnr(mse: float) -> float:
-     if mse <= 1e-12:
-         return float("inf")
-     return 10.0 * float(np.log10(1.0 / mse))
-
-
- def is_video_like_vae(vae) -> bool:
-     # Wan and LTX-Video expect [B, C, T, H, W]
-     return isinstance(vae, (AutoencoderKLWan, AutoencoderKLLTXVideo, AutoencoderKLQwenImage))
-
-
- def add_time_dim_if_needed(x: torch.Tensor, vae) -> torch.Tensor:
-     if is_video_like_vae(vae) and x.ndim == 4:
-         return x.unsqueeze(2)  # -> [B, C, 1, H, W]
-     return x
-
-
- def strip_time_dim_if_possible(x: torch.Tensor, vae) -> torch.Tensor:
-     if is_video_like_vae(vae) and x.ndim == 5 and x.shape[2] == 1:
-         return x.squeeze(2)  # -> [B, C, H, W]
-     return x
-
-
- @torch.no_grad()
- def sobel_edge_l1(real_0_1: torch.Tensor, fake_0_1: torch.Tensor) -> float:
-     real = to_neg1_1(real_0_1)
-     fake = to_neg1_1(fake_0_1)
-     kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
-     ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
-     C = real.shape[1]
-     kx = kx.to(real.dtype).repeat(C, 1, 1, 1)
-     ky = ky.to(real.dtype).repeat(C, 1, 1, 1)
-
-     def grad_mag(x):
-         gx = F.conv2d(x, kx, padding=1, groups=C)
-         gy = F.conv2d(x, ky, padding=1, groups=C)
-         return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
-     return F.l1_loss(grad_mag(fake), grad_mag(real)).item()
-
-
- def flatten_channels(x: torch.Tensor) -> torch.Tensor:
-     # -> [C, N*H*W] or [C, N*T*H*W]
-     if x.ndim == 4:
-         return x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)
-     elif x.ndim == 5:
-         return x.permute(1, 0, 2, 3, 4).reshape(x.shape[1], -1)
-     else:
-         raise ValueError(f"Unexpected tensor ndim={x.ndim}")
-
-
- def _to_numpy_1d(x: Any) -> Optional[np.ndarray]:
-     if x is None:
-         return None
-     if isinstance(x, (int, float)):
-         return None
-     if isinstance(x, torch.Tensor):
-         x = x.detach().cpu().float().numpy()
-     elif isinstance(x, (list, tuple)):
-         x = np.array(x, dtype=np.float32)
-     elif isinstance(x, np.ndarray):
-         x = x.astype(np.float32, copy=False)
-     else:
-         return None
-     x = x.reshape(-1)
-     return x
-
-
- def _to_float(x: Any) -> Optional[float]:
-     if x is None:
-         return None
-     if isinstance(x, (int, float)):
-         return float(x)
-     if isinstance(x, np.ndarray) and x.size == 1:
-         return float(x.item())
-     if isinstance(x, torch.Tensor) and x.numel() == 1:
-         return float(x.item())
-     return None
-
-
- def get_norm_tensors_and_summary(vae, latent_like: torch.Tensor):
-     """
-     Latent normalization: global and per-channel.
-     Application order: the global (scalar) part first, then the per-channel (vector) part.
-     If the config provides several keys, they are accumulated.
-     """
-     cfg = getattr(vae, "config", vae)
-
-     scale_keys = [
-         "latents_std"
-     ]
-     shift_keys = [
-         "latents_mean"
-     ]
-
-     C = latent_like.shape[1]
-     nd = latent_like.ndim  # 4 or 5
-     dev = latent_like.device
-     dt = latent_like.dtype
-
-     scale_global = getattr(vae.config, "scaling_factor", 1.0)
-     shift_global = getattr(vae.config, "shift_factor", 0.0)
-     if scale_global is None:
-         scale_global = 1.0
-     if shift_global is None:
-         shift_global = 0.0
-
-     scale_channel = np.ones(C, dtype=np.float32)
-     shift_channel = np.zeros(C, dtype=np.float32)
-
-     for k in scale_keys:
-         v = getattr(cfg, k, None)
-         if v is None:
-             continue
-         vec = _to_numpy_1d(v)
-         if vec is not None and vec.size == C:
-             scale_channel *= vec
-         else:
-             s = _to_float(v)
-             if s is not None:
-                 scale_global *= s
-
-     for k in shift_keys:
-         v = getattr(cfg, k, None)
-         if v is None:
-             continue
-         vec = _to_numpy_1d(v)
-         if vec is not None and vec.size == C:
-             shift_channel += vec
-         else:
-             s = _to_float(v)
-             if s is not None:
-                 shift_global += s
-
-     g_shape = [1] * nd
-     c_shape = [1] * nd
-     c_shape[1] = C
-
-     t_scale_g = torch.tensor(scale_global, dtype=dt, device=dev).view(*g_shape)
-     t_shift_g = torch.tensor(shift_global, dtype=dt, device=dev).view(*g_shape)
-     t_scale_c = torch.from_numpy(scale_channel).to(device=dev, dtype=dt).view(*c_shape)
-     t_shift_c = torch.from_numpy(shift_channel).to(device=dev, dtype=dt).view(*c_shape)
-
-     summary = {
-         "scale_global": float(scale_global),
-         "shift_global": float(shift_global),
-         "scale_channel_min": float(scale_channel.min()),
-         "scale_channel_mean": float(scale_channel.mean()),
-         "scale_channel_max": float(scale_channel.max()),
-         "shift_channel_min": float(shift_channel.min()),
-         "shift_channel_mean": float(shift_channel.mean()),
-         "shift_channel_max": float(shift_channel.max()),
-     }
-     return t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary
-
-
- @torch.no_grad()
- def kl_divergence_per_image(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
-     kl_map = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # [B, ...]
-     return kl_map.float().view(kl_map.shape[0], -1).mean(dim=1)  # [B]
-
-
- def sanitize_filename(name: str) -> str:
-     name = name.replace("/", "_").replace("\\", "_").replace(" ", "_")
-     return "".join(ch if (ch.isalnum() or ch in "._-") else "_" for ch in name)
-
-
- # ========================== Dataset ==========================
- class ImageFolderDataset(Dataset):
-     def __init__(self, root_dir: str, extensions=(".png", ".jpg", ".jpeg", ".webp"), min_size=1024, crop_size=512, limit=None):
-         paths = []
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(extensions):
-                     paths.append(os.path.join(root, fname))
-         if limit:
-             paths = paths[:limit]
-
-         valid = []
-         for p in tqdm(paths, desc="Checking files"):
-             try:
-                 with Image.open(p) as im:
-                     im.verify()
-                 valid.append(p)
-             except Exception:
-                 pass
-         if not valid:
-             raise RuntimeError(f"No valid images in {root_dir}")
-         random.shuffle(valid)
-         self.paths = valid
-         print(f"Found {len(self.paths)} images")
-
-         self.transform = Compose([
-             Resize(min_size),
-             CenterCrop(crop_size),
-             ToTensor(),  # 0..1, float32
-         ])
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         with Image.open(self.paths[idx]) as img:
-             img = img.convert("RGB")
-             return self.transform(img)
-
-
- # ========================== Main ==========================
- def main():
-     torch.set_grad_enabled(False)
-     os.makedirs(SAMPLES_DIR, exist_ok=True)
-
-     dataset = ImageFolderDataset(IMAGE_FOLDER, min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES)
-     loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
-
-     lpips_net = lpips.LPIPS(net="vgg").to(DEVICE).eval()
-
-     # Load the VAEs
-     vaes: List[Tuple[str, object]] = []
-     print("\nLoading VAEs...")
-     for human_name, vae_class, model_path, subfolder in VAE_LIST:
-         try:
-             vae = vae_class.from_pretrained(model_path, subfolder=subfolder, torch_dtype=DTYPE)
-             vae = vae.to(DEVICE).eval()
-             vaes.append((human_name, vae))
-             print(f"  ✅ {human_name}")
-         except Exception as e:
-             print(f"  ❌ {human_name}: {e}")
-
-     if not vaes:
-         print("No VAEs loaded successfully. Exiting.")
-         return
-
-     # Accumulators
-     per_model_metrics: Dict[str, Dict[str, float]] = {
-         name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "kl": 0.0, "count": 0.0}
-         for name, _ in vaes
-     }
-
-     buffers_zmodel: Dict[str, List[torch.Tensor]] = {name: [] for name, _ in vaes}
-     norm_summaries: Dict[str, Dict[str, float]] = {}
-
-     # Flag for saving the first image per model
-     saved_first_for: Dict[str, bool] = {name: False for name, _ in vaes}
-
-     for batch_0_1 in tqdm(loader, desc="Batches"):
-         batch_0_1 = batch_0_1.to(DEVICE, torch.float32)
-         batch_neg1_1 = to_neg1_1(batch_0_1).to(DTYPE)
-
-         for model_name, vae in vaes:
-             x_in = add_time_dim_if_needed(batch_neg1_1, vae)
-
-             posterior = vae.encode(x_in).latent_dist
-             mu, logvar = posterior.mean, posterior.logvar
-
-             # Reconstruction (deterministic)
-             z_raw_mode = posterior.mode()
-             x_dec = vae.decode(z_raw_mode).sample  # [-1, 1]
-             x_dec = strip_time_dim_if_possible(x_dec, vae)
-             x_rec_0_1 = to_0_1(x_dec.float()).clamp(0, 1)
-
-             # Latents as the UNet would see them: global -> channelwise
-             z_raw_sample = posterior.sample()
-             t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary = get_norm_tensors_and_summary(vae, z_raw_sample)
-
-             if model_name not in norm_summaries:
-                 norm_summaries[model_name] = summary
-
-             z_tmp = (z_raw_sample - t_shift_g) * t_scale_g
-             z_model = (z_tmp - t_shift_c) * t_scale_c
-             z_model = strip_time_dim_if_possible(z_model, vae)
-
-             buffers_zmodel[model_name].append(z_model.detach().to("cpu", torch.float32))
-
-             # Save the first image (original and reconstruction) for each VAE
-             if not saved_first_for[model_name]:
-                 safe = sanitize_filename(model_name)
-                 orig_path = os.path.join(SAMPLES_DIR, f"{safe}_original.png")
-                 dec_path = os.path.join(SAMPLES_DIR, f"{safe}_decoded.png")
-                 save_image(batch_0_1[0:1].cpu(), orig_path)
-                 save_image(x_rec_0_1[0:1].cpu(), dec_path)
-                 saved_first_for[model_name] = True
-
-             # Per-image metrics
-             B = batch_0_1.shape[0]
-             for i in range(B):
-                 gt = batch_0_1[i:i+1]
-                 rec = x_rec_0_1[i:i+1]
-
-                 mse = F.mse_loss(gt, rec).item()
-                 psnr = safe_psnr(mse)
-                 lp = float(lpips_net(gt, rec, normalize=True).mean().item())
-                 edge = sobel_edge_l1(gt, rec)
-
-                 per_model_metrics[model_name]["mse"] += mse
-                 per_model_metrics[model_name]["psnr"] += psnr
-                 per_model_metrics[model_name]["lpips"] += lp
-                 per_model_metrics[model_name]["edge"] += edge
-
-             # KL per image
-             kl_pi = kl_divergence_per_image(mu, logvar)  # [B]
-             per_model_metrics[model_name]["kl"] += float(kl_pi.sum().item())
-             per_model_metrics[model_name]["count"] += B
-
-     # Average the metrics
-     for name in per_model_metrics:
-         c = max(1.0, per_model_metrics[name]["count"])
-         for k in ["mse", "psnr", "lpips", "edge", "kl"]:
-             per_model_metrics[name][k] /= c
-
-     # Latent statistics and normality measures
-     per_model_latent_stats = {}
-     for name, _ in vaes:
-         if not buffers_zmodel[name]:
-             continue
-         Z = torch.cat(buffers_zmodel[name], dim=0)  # [N, C, H, W]
-
-         # Global
-         z_min = float(Z.min().item())
-         z_mean = float(Z.mean().item())
-         z_max = float(Z.max().item())
-         z_std = float(Z.std(unbiased=True).item())
-
-         # Per channel: skew/kurtosis
-         Z_ch = flatten_channels(Z).numpy()  # [C, *]
-         C = Z_ch.shape[0]
-         sk = np.zeros(C, dtype=np.float64)
-         ku = np.zeros(C, dtype=np.float64)
-         for c in range(C):
-             v = Z_ch[c]
-             sk[c] = float(skew(v, bias=False))
-             ku[c] = float(kurtosis(v, fisher=True, bias=False))
-
-         skew_min, skew_mean, skew_max = float(sk.min()), float(sk.mean()), float(sk.max())
-         kurt_min, kurt_mean, kurt_max = float(ku.min()), float(ku.mean()), float(ku.max())
-         mean_abs_skew = float(np.mean(np.abs(sk)))
-         mean_abs_kurt = float(np.mean(np.abs(ku)))
-
-         per_model_latent_stats[name] = {
-             "Z_min": z_min, "Z_mean": z_mean, "Z_max": z_max, "Z_std": z_std,
-             "skew_min": skew_min, "skew_mean": skew_mean, "skew_max": skew_max,
-             "kurt_min": kurt_min, "kurt_mean": kurt_mean, "kurt_max": kurt_max,
-             "mean_abs_skew": mean_abs_skew, "mean_abs_kurt": mean_abs_kurt,
-         }
-
-     # Print the normalization parameters (shift/scale)
-     print("\n=== Latent normalization parameters (as applied) ===")
-     for name, _ in vaes:
-         if name not in norm_summaries:
-             continue
-         s = norm_summaries[name]
-         print(
-             f"{name:26s} | "
-             f"shift_g={s['shift_global']:.6g} scale_g={s['scale_global']:.6g} | "
-             f"shift_c[min/mean/max]=[{s['shift_channel_min']:.6g}, {s['shift_channel_mean']:.6g}, {s['shift_channel_max']:.6g}] | "
-             f"scale_c[min/mean/max]=[{s['scale_channel_min']:.6g}, {s['scale_channel_mean']:.6g}, {s['scale_channel_max']:.6g}]"
-         )
-
-     # Absolute metrics
-     print("\n=== Absolute reconstruction and latent metrics ===")
-     for name, _ in vaes:
-         if name not in per_model_latent_stats:
-             continue
-         m = per_model_metrics[name]
-         s = per_model_latent_stats[name]
-         print(
-             f"{name:26s} | "
-             f"MSE={m['mse']:.3e} PSNR={m['psnr']:.2f} LPIPS={m['lpips']:.3f} Edge={m['edge']:.3f} KL={m['kl']:.3f} | "
-             f"Z[min/mean/max/std]=[{s['Z_min']:.3f}, {s['Z_mean']:.3f}, {s['Z_max']:.3f}, {s['Z_std']:.3f}] | "
-             f"Skew[min/mean/max]=[{s['skew_min']:.3f}, {s['skew_mean']:.3f}, {s['skew_max']:.3f}] | "
-             f"Kurt[min/mean/max]=[{s['kurt_min']:.3f}, {s['kurt_mean']:.3f}, {s['kurt_max']:.3f}]"
-         )
-
-     # Comparison with the first model
-     baseline = vaes[0][0]
-     print("\n=== Comparison with the first model (percent) ===")
-     print(f"| {'Model':26s} | {'MSE':>9s} | {'PSNR':>9s} | {'LPIPS':>9s} | {'Edge':>9s} | {'Skew|0':>9s} | {'Kurt|0':>9s} |")
-     print(f"|{'-'*28}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|")
-
-     b_m = per_model_metrics[baseline]
-     b_s = per_model_latent_stats[baseline]
-
-     for name, _ in vaes:
-         m = per_model_metrics[name]
-         s = per_model_latent_stats[name]
-
-         mse_pct = (b_m["mse"] / max(1e-12, m["mse"])) * 100.0        # lower is better
-         psnr_pct = (m["psnr"] / max(1e-12, b_m["psnr"])) * 100.0     # higher is better
-         lpips_pct = (b_m["lpips"] / max(1e-12, m["lpips"])) * 100.0  # lower is better
-         edge_pct = (b_m["edge"] / max(1e-12, m["edge"])) * 100.0     # lower is better
-
-         skew0_pct = (b_s["mean_abs_skew"] / max(1e-12, s["mean_abs_skew"])) * 100.0
-         kurt0_pct = (b_s["mean_abs_kurt"] / max(1e-12, s["mean_abs_kurt"])) * 100.0
-
-         if name == baseline:
-             print(f"| {name:26s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} |")
-         else:
-             print(f"| {name:26s} | {mse_pct:8.1f}% | {psnr_pct:8.1f}% | {lpips_pct:8.1f}% | {edge_pct:8.1f}% | {skew0_pct:8.1f}% | {kurt0_pct:8.1f}% |")
-
-     # ========================== Corrections for the last VAE + save to JSON ==========================
-     last_name = vaes[-1][0]
-     if buffers_zmodel[last_name]:
-         Z = torch.cat(buffers_zmodel[last_name], dim=0)  # [N, C, H, W]
-
-         # Global correction (over all channels/pixels)
-         z_mean = float(Z.mean().item())
-         z_std = float(Z.std(unbiased=True).item())
-         correction_global = {
-             "shift": -z_mean,
-             "scale": (1.0 / z_std) if z_std > 1e-12 else 1.0
-         }
-
-         # Per-channel correction
-         Z_ch = flatten_channels(Z)  # [C, M]
-         ch_means_t = Z_ch.mean(dim=1)  # [C]
-         ch_stds_t = Z_ch.std(dim=1, unbiased=True) + 1e-12  # [C]
-         ch_means = [float(x) for x in ch_means_t.tolist()]
-         ch_stds = [float(x) for x in ch_stds_t.tolist()]
-
-         correction_per_channel = [
-             {"shift": float(-m), "scale": float(1.0 / s)}
-             for m, s in zip(ch_means, ch_stds)
-         ]
-
-         print(f"\n=== Extra correction for {last_name} (on top of the VAE normalization) ===")
-         print(f"global_correction = {correction_global}")
-         print(f"channelwise_means = {ch_means}")
-         print(f"channelwise_stds = {ch_stds}")
-         print(f"channelwise_correction = {correction_per_channel}")
-
-         # Save to JSON
-         json_path = os.path.join(SAMPLES_DIR, f"{sanitize_filename(last_name)}_correction.json")
-         to_save = {
-             "model_name": last_name,
-             "vae_normalization_summary": norm_summaries.get(last_name, {}),
-             "global_correction": correction_global,
-             "per_channel_means": ch_means,
-             "per_channel_stds": ch_stds,
-             "per_channel_correction": correction_per_channel,
-             "apply_order": {
-                 "forward": "z_model -> (z - global_shift)*global_scale -> (per-channel: (z - mean_c)/std_c)",
-                 "inverse": "z_corr -> (per-channel: z*std_c + mean_c) -> (z/global_scale + global_shift)"
-             },
-             "note": "These coefficients are computed on z_model (after the VAE's built-in shift/scale) to bring the distribution to N(0, 1)."
-         }
-         with open(json_path, "w", encoding="utf-8") as f:
-             json.dump(to_save, f, ensure_ascii=False, indent=2)
-         print("Corrections JSON saved to:", os.path.abspath(json_path))
-
-     print("\n✅ Done. Samples saved to:", os.path.abspath(SAMPLES_DIR))
-
-
- if __name__ == "__main__":
-     main()
sdxl_vae_a1111.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ebe642d26e14851e98eb3d06575009e0d1a669704a1c9c8dcf06573d82233a21
+ oid sha256:8f8e696f579f70d185f4b944a0d821ab5578a0915ac079fe44c148ce5102cc5b
  size 334640988
simple_vae/config.json DELETED
@@ -1,38 +0,0 @@
- {
-   "_class_name": "AutoencoderKL",
-   "_diffusers_version": "0.35.0.dev0",
-   "_name_or_path": "simple_vae",
-   "act_fn": "silu",
-   "block_out_channels": [
-     128,
-     256,
-     512,
-     512
-   ],
-   "down_block_types": [
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D"
-   ],
-   "force_upcast": false,
-   "in_channels": 3,
-   "latent_channels": 16,
-   "latents_mean": null,
-   "latents_std": null,
-   "layers_per_block": 2,
-   "mid_block_add_attention": true,
-   "norm_num_groups": 32,
-   "out_channels": 3,
-   "sample_size": 1024,
-   "scaling_factor": 1.0,
-   "shift_factor": 0,
-   "up_block_types": [
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D"
-   ],
-   "use_post_quant_conv": true,
-   "use_quant_conv": true
- }
simple_vae/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8ba1d500c4bd376a7c8662a35fa1857c7e577da0635414b524180852143ef2f6
- size 335311892
simple_vae_nightly/config.json DELETED
@@ -1,38 +0,0 @@
- {
-   "_class_name": "AutoencoderKL",
-   "_diffusers_version": "0.35.0.dev0",
-   "_name_or_path": "simple_vae",
-   "act_fn": "silu",
-   "block_out_channels": [
-     128,
-     256,
-     512,
-     512
-   ],
-   "down_block_types": [
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D"
-   ],
-   "force_upcast": false,
-   "in_channels": 3,
-   "latent_channels": 16,
-   "latents_mean": null,
-   "latents_std": null,
-   "layers_per_block": 2,
-   "mid_block_add_attention": true,
-   "norm_num_groups": 32,
-   "out_channels": 3,
-   "sample_size": 1024,
-   "scaling_factor": 1.0,
-   "shift_factor": 0,
-   "up_block_types": [
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D"
-   ],
-   "use_post_quant_conv": true,
-   "use_quant_conv": true
- }
simple_vae_nightly/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b39620d0953839362425f03674e6c3e37f03d20be3fbd7f281baea4dfc336a40
- size 335311892
train_sdxl_vae_wan.py → src/train_sdxl_vae.py RENAMED
File without changes
train_sdxl_vae.py DELETED
@@ -1,547 +0,0 @@
- # -*- coding: utf-8 -*-
- import os
- import math
- import re
- import torch
- import numpy as np
- import random
- import gc
- from datetime import datetime
- from pathlib import Path
-
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
- from accelerate import Accelerator
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
- import bitsandbytes as bnb
- import wandb
- import lpips  # pip install lpips
- from collections import deque
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/png"
- project = "simple_vae"
- batch_size = 3
- base_learning_rate = 5e-5
- min_learning_rate = 9e-7
- num_epochs = 16
- sample_interval_share = 10
- use_wandb = True
- save_model = True
- use_decay = True
- asymmetric = False
- optimizer_type = "adam8bit"
- dtype = torch.float32
- # model_resolution: what gets fed into the VAE (the low resolution)
- model_resolution = 512  # formerly `resolution`
- # high_resolution: the actual high-res crop used for metrics and saved samples
- high_resolution = 512
- limit = 0
- save_barrier = 1.03
- warmup_percent = 0.01
- percentile_clipping = 95
- beta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- mixed_precision = "no"  # or "fp16"/"bf16" where supported
- gradient_accumulation_steps = 5
- generated_folder = "samples"
- save_as = "simple_vae_nightly"
- num_workers = 0
- device = None  # the accelerator will assign the device
-
- # --- Loss shares and the median-normalization window (COEFFICIENTS, not values) ---
- # Final shares in the total loss (sum = 1.0)
- loss_ratios = {
-     "lpips": 0.85,
-     "edge": 0.05,
-     "mse": 0.05,
-     "mae": 0.05,
- }
- median_coeff_steps = 256  # window (in steps) over which median coefficients are computed
-
- # --------------------------- Preprocessing parameters ---------------------------
- resize_long_side = 1280  # if None or 0, no resizing; 1280 recommended
-
- Path(generated_folder).mkdir(parents=True, exist_ok=True)
-
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # reproducibility
- seed = int(datetime.now().strftime("%Y%m%d"))
- torch.manual_seed(seed)
- np.random.seed(seed)
- random.seed(seed)
-
- torch.backends.cudnn.benchmark = False
-
- # --------------------------- WandB ---------------------------
- if use_wandb and accelerator.is_main_process:
-     wandb.init(project=project, config={
-         "batch_size": batch_size,
-         "base_learning_rate": base_learning_rate,
-         "num_epochs": num_epochs,
-         "optimizer_type": optimizer_type,
-         "model_resolution": model_resolution,
-         "high_resolution": high_resolution,
-         "gradient_accumulation_steps": gradient_accumulation_steps,
-     })
-
- # --------------------------- VAE ---------------------------
- if model_resolution == high_resolution and not asymmetric:
-     vae = AutoencoderKL.from_pretrained(project).to(dtype)
- else:
-     vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
-
- # torch.compile (when available), kept simple with no extra logic
- if hasattr(torch, "compile"):
-     try:
-         vae = torch.compile(vae)
-     except Exception as e:
-         print(f"[WARN] torch.compile failed: {e}")
-
- # >>> Freeze all parameters, then selectively unfreeze
- for p in vae.parameters():
-     p.requires_grad = False
-
- decoder = getattr(vae, "decoder", None)
- if decoder is None:
-     raise RuntimeError("vae.decoder not found; cannot apply the unfreezing strategy. Check the model structure.")
-
- unfrozen_param_names = []
-
- if not hasattr(decoder, "up_blocks"):
-     raise RuntimeError("decoder.up_blocks not found; a list of decoder blocks is expected.")
-
- # >>> Unfreeze all up_blocks and the mid_block (as in the start_idx=0 variant)
- n_up = len(decoder.up_blocks)
- start_idx = 0
- for idx in range(start_idx, n_up):
-     block = decoder.up_blocks[idx]
-     for name, p in block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
-
- if hasattr(decoder, "mid_block"):
-     for name, p in decoder.mid_block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.mid_block.{name}")
- else:
-     print("[WARN] decoder.mid_block not found; mid_block stays frozen.")
-
- print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
- for nm in unfrozen_param_names[:200]:
-     print(" ", nm)
-
- # keep trainable_module (get_param_groups will respect p.requires_grad)
- trainable_module = vae.decoder
-
- # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
- class PngFolderDataset(Dataset):
-     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
-         self.root_dir = root_dir
-         self.resolution = resolution
-         self.paths = []
-         # collect png files recursively
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
-                     self.paths.append(os.path.join(root, fname))
-         # optional limit
-         if limit:
-             self.paths = self.paths[:limit]
-         # verify images and keep only valid ones
-         valid = []
-         for p in self.paths:
-             try:
-                 with Image.open(p) as im:
-                     im.verify()  # fast check for truncated/corrupted images
-                 valid.append(p)
-             except (OSError, UnidentifiedImageError):
-                 # skip corrupted image
-                 continue
-         self.paths = valid
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid PNG images found under {root_dir}")
-         # final shuffle for randomness
-         random.shuffle(self.paths)
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         p = self.paths[idx % len(self.paths)]
-         # open and convert to RGB; ensure file is closed promptly
-         with Image.open(p) as img:
-             img = img.convert("RGB")
-         # shrink the long side down to resize_long_side (Lanczos)
-         if not resize_long_side or resize_long_side <= 0:
-             return img
-         w, h = img.size
-         long = max(w, h)
-         if long <= resize_long_side:
-             return img
-         scale = resize_long_side / float(long)
-         new_w = int(round(w * scale))
-         new_h = int(round(h * scale))
-         return img.resize((new_w, new_h), Image.LANCZOS)
-
- # --------------------------- Dataset and transforms ---------------------------
-
- def random_crop(img, sz):
-     w, h = img.size
-     if w < sz or h < sz:
-         img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
-     x = random.randint(0, max(1, img.width - sz))
-     y = random.randint(0, max(1, img.height - sz))
-     return img.crop((x, y, x + sz, y + sz))
-
- tfm = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- # build dataset using high_resolution crops
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
- if len(dataset) < batch_size:
-     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
-
- # collate_fn crops down to high_resolution
-
- def collate_fn(batch):
-     imgs = []
-     for img in batch:  # img is PIL.Image
-         img = random_crop(img, high_resolution)  # crop at high resolution
-         imgs.append(tfm(img))
-     return torch.stack(imgs)
-
- dataloader = DataLoader(
-     dataset,
-     batch_size=batch_size,
-     shuffle=True,
-     collate_fn=collate_fn,
-     num_workers=num_workers,
-     pin_memory=True,
-     drop_last=True
- )
-
- # --------------------------- Optimizer ---------------------------
-
- def get_param_groups(module, weight_decay=0.001):
-     no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
-     decay_params = []
-     no_decay_params = []
-     for n, p in module.named_parameters():
-         if not p.requires_grad:
-             continue
-         if any(nd in n for nd in no_decay):
-             no_decay_params.append(p)
-         else:
-             decay_params.append(p)
-     return [
-         {"params": decay_params, "weight_decay": weight_decay},
-         {"params": no_decay_params, "weight_decay": 0.0},
-     ]
-
- def create_optimizer(name, param_groups):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
-         )
-     raise ValueError(name)
-
- param_groups = get_param_groups(trainable_module, weight_decay=0.001)
- optimizer = create_optimizer(optimizer_type, param_groups)
-
- # --------------------------- Accelerate preparation (all together) ---------------------------
-
- batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
- steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
- total_steps = steps_per_epoch * num_epochs
-
-
- def lr_lambda(step):
-     if not use_decay:
-         return 1.0
-     x = float(step) / float(max(1, total_steps))
-     warmup = float(warmup_percent)
-     min_ratio = float(min_learning_rate) / float(base_learning_rate)
-     if x < warmup:
-         return min_ratio + (1.0 - min_ratio) * (x / warmup)
-     decay_ratio = (x - warmup) / (1.0 - warmup)
-     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
-
- scheduler = LambdaLR(optimizer, lr_lambda)
-
- # Preparation
- dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
-
- trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
-
- # --------------------------- LPIPS and helper functions ---------------------------
- _lpips_net = None
-
- def _get_lpips():
-     global _lpips_net
-     if _lpips_net is None:
-         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
-     return _lpips_net
-
- # Sobel kernels for the edge loss
- _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
- _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
-
- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
-     # x: [B,C,H,W] in [-1,1]
-     C = x.shape[1]
-     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     gx = F.conv2d(x, kx, padding=1, groups=C)
-     gy = F.conv2d(x, ky, padding=1, groups=C)
-     return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
- # Median-based loss normalization: we compute COEFFICIENTS
- class MedianLossNormalizer:
-     def __init__(self, desired_ratios: dict, window_steps: int):
-         # normalize the shares in case they do not sum to 1
-         s = sum(desired_ratios.values())
-         self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
-         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
-         self.window = window_steps
-
-     def update_and_total(self, abs_losses: dict):
-         # Fill the buffers with the actual ABSOLUTE loss values
-         for k, v in abs_losses.items():
-             if k in self.buffers:
-                 self.buffers[k].append(float(v.detach().cpu()))
-         # Medians (robust to outliers)
-         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
-         # Compute COEFFICIENTS as ratio_k / median_k, i.e. genuine coefficients, not values
-         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
-         # Key point: with these coefficients, sum(coeff_k * median_k) = sum(ratio_k) = 1, so the scale stays stable
-         total = sum(coeffs[k] * abs_losses[k] for k in coeffs)
-         return total, coeffs, meds
-
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
-
- # --------------------------- Samples ---------------------------
- @torch.no_grad()
- def get_fixed_samples(n=3):
-     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
-     pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
-     tensors = []
-     for img in pil_imgs:
-         img = random_crop(img, high_resolution)  # high-res fixed samples
-         tensors.append(tfm(img))
-     return torch.stack(tensors).to(accelerator.device, dtype)
-
- fixed_samples = get_fixed_samples()
-
- @torch.no_grad()
- def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
-     # img_tensor: [C,H,W] in [-1,1]
-     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
-     return Image.fromarray(arr)
-
- @torch.no_grad()
- def generate_and_save_samples(step=None):
-     try:
-         temp_vae = accelerator.unwrap_model(vae).eval()
-         lpips_net = _get_lpips()
-         with torch.no_grad():
-             # ALWAYS prepare the low-res encoder input at model_resolution
-             orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
-             orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             # match the model dtype
-             model_dtype = next(temp_vae.parameters()).dtype
-             orig_low = orig_low.to(dtype=model_dtype)
-             # encode/decode
-             latents = temp_vae.encode(orig_low).latent_dist.mean
-             rec = temp_vae.decode(latents).sample
-
-             # Bring the reconstruction's spatial size to high-res (downsample for asymmetric VAEs)
-             if rec.shape[-2:] != orig_high.shape[-2:]:
-                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
-
-             # Save the FIRST sample: real and decoded, with no step number in the filename
-             first_real = _to_pil_uint8(orig_high[0])
-             first_dec = _to_pil_uint8(rec[0])
-             first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
-             first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
-
-             # Also save the current reconstructions without a step number (files are overwritten instead of piling up)
-             for i in range(rec.shape[0]):
-                 _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
-
-             # LPIPS on the full (high-res) image, for logging
-             lpips_scores = []
-             for i in range(rec.shape[0]):
-                 orig_full = orig_high[i:i+1].to(torch.float32)
-                 rec_full = rec[i:i+1].to(torch.float32)
-                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
-                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
-                 lpips_val = lpips_net(orig_full, rec_full).item()
-                 lpips_scores.append(lpips_val)
-             avg_lpips = float(np.mean(lpips_scores))
-
-             if use_wandb and accelerator.is_main_process:
-                 wandb.log({
-                     "lpips_mean": avg_lpips,
-                 }, step=step)
-     finally:
-         gc.collect()
-         torch.cuda.empty_cache()
-
- if accelerator.is_main_process and save_model:
-     print("Generating samples before training starts...")
-     generate_and_save_samples(0)
-
- accelerator.wait_for_everyone()
-
- # --------------------------- Training ---------------------------
-
- progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
- global_step = 0
- min_loss = float("inf")
- sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
-
- for epoch in range(num_epochs):
-     vae.train()
-     batch_losses = []
-     batch_grads = []
-     # Extra tracking for the individual losses
-     track_losses = {k: [] for k in loss_ratios.keys()}
-     for imgs in dataloader:
-         with accelerator.accumulate(vae):
-             # imgs: high-res tensor from dataloader ([-1,1]), move to device
-             imgs = imgs.to(accelerator.device)
-
-             # ALWAYS downsample the encoder input to model_resolution
-             # The stubborn hardware keeps trying to do things its own way
-             if high_resolution != model_resolution:
-                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             else:
-                 imgs_low = imgs
-
-             # ensure dtype matches model params to avoid float/half mismatch
-             model_dtype = next(vae.parameters()).dtype
-             if imgs_low.dtype != model_dtype:
-                 imgs_low_model = imgs_low.to(dtype=model_dtype)
-             else:
-                 imgs_low_model = imgs_low
-
-             # Encode/decode
-             latents = vae.encode(imgs_low_model).latent_dist.mean
-             rec = vae.decode(latents).sample  # rec may be upscaled (asymmetric VAE)
-
-             # Bring the size to high-res
-             if rec.shape[-2:] != imgs.shape[-2:]:
-                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
-
-             # Compute the losses at high-res
-             rec_f32 = rec.to(torch.float32)
-             imgs_f32 = imgs.to(torch.float32)
-
-             # Individual losses
-             abs_losses = {
-                 "mae": F.l1_loss(rec_f32, imgs_f32),
-                 "mse": F.mse_loss(rec_f32, imgs_f32),
-                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
-                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
-             }
-
-             # Total with the median COEFFICIENTS
-             # No need to shout once you've managed to grasp my idea
-             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
-
-             if torch.isnan(total_loss) or torch.isinf(total_loss):
-                 print("NaN/Inf loss – stopping")
-                 raise RuntimeError("NaN/Inf loss")
-
-             accelerator.backward(total_loss)
-
-             grad_norm = torch.tensor(0.0, device=accelerator.device)
-             if accelerator.sync_gradients:
-                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
-             optimizer.step()
-             scheduler.step()
-             optimizer.zero_grad(set_to_none=True)
-
-             global_step += 1
-             progress.update(1)
-
-             # --- Logging ---
-             if accelerator.is_main_process:
-                 try:
-                     current_lr = optimizer.param_groups[0]["lr"]
-                 except Exception:
-                     current_lr = scheduler.get_last_lr()[0]
-
-                 batch_losses.append(total_loss.detach().item())
-                 batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
-                 for k, v in abs_losses.items():
-                     track_losses[k].append(float(v.detach().item()))
-
-                 if use_wandb and accelerator.sync_gradients:
-                     log_dict = {
-                         "total_loss": float(total_loss.detach().item()),
-                         "learning_rate": current_lr,
-                         "epoch": epoch,
-                         "grad_norm": batch_grads[-1],
-                     }
-                     # add the individual losses
-                     for k, v in abs_losses.items():
-                         log_dict[f"loss_{k}"] = float(v.detach().item())
-                     # log the coefficients and medians
-                     for k in coeffs:
-                         log_dict[f"coeff_{k}"] = float(coeffs[k])
-                         log_dict[f"median_{k}"] = float(meds[k])
-                     wandb.log(log_dict, step=global_step)
-
-             # periodic samples and checkpoints
-             if global_step > 0 and global_step % sample_interval == 0:
-                 if accelerator.is_main_process:
-                     generate_and_save_samples(global_step)
-                 accelerator.wait_for_everyone()
-
-                 # Averages over the latest iterations
-                 n_micro = sample_interval * gradient_accumulation_steps
-                 if len(batch_losses) >= n_micro:
-                     avg_loss = float(np.mean(batch_losses[-n_micro:]))
-                 else:
-                     avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
-
-                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
-
-                 if accelerator.is_main_process:
-                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
-                     if save_model and avg_loss < min_loss * save_barrier:
-                         min_loss = avg_loss
-                         accelerator.unwrap_model(vae).save_pretrained(save_as)
-                     if use_wandb:
-                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
-
-     if accelerator.is_main_process:
-         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
-         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
-         if use_wandb:
-             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
-
- # --------------------------- Final save ---------------------------
- if accelerator.is_main_process:
-     print("Training finished – saving final model")
-     if save_model:
-         accelerator.unwrap_model(vae).save_pretrained(save_as)
-
- accelerator.free_memory()
- if torch.distributed.is_initialized():
-     torch.distributed.destroy_process_group()
- print("Done!")
train_sdxl_vae_full.py DELETED
@@ -1,594 +0,0 @@
- # -*- coding: utf-8 -*-
- import os
- import math
- import re
- import torch
- import numpy as np
- import random
- import gc
- from datetime import datetime
- from pathlib import Path
-
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
- from accelerate import Accelerator
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
- import bitsandbytes as bnb
- import wandb
- import lpips  # pip install lpips
- from collections import deque
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/png"
- project = "simple_vae"
- batch_size = 3
- base_learning_rate = 2e-6
- min_learning_rate = 8e-7
- num_epochs = 8
- sample_interval_share = 10
- use_wandb = True
- save_model = True
- use_decay = True
- asymmetric = False
- optimizer_type = "adam8bit"
- dtype = torch.float32
- # model_resolution: what gets fed into the VAE (the low resolution)
- model_resolution = 512  # formerly `resolution`
- # high_resolution: the actual high-res crop used for metrics and saved samples
- high_resolution = 512
- limit = 0
- save_barrier = 1.03
- warmup_percent = 0.01
- percentile_clipping = 95
- beta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- mixed_precision = "no"  # or "fp16"/"bf16" where supported
- gradient_accumulation_steps = 5
- generated_folder = "samples"
- save_as = "simple_vae_nightly"
- num_workers = 0
- device = None  # the accelerator will assign the device
-
- # --------------------------- Training modes ---------------------------
- # CHANGED: added a flag for training the full VAE (not just the decoder).
- # If False, behavior is unchanged: only decoder.* (up_blocks + mid_block) is trained.
- # If True, the WHOLE model is unfrozen and a KL loss is added for the encoder.
- full_training = False
-
- # CHANGED: added a weight (as a share in the normalizer) for KL; used only when full_training=True.
- kl_ratio = 0.00  # a simple KL share in the overall mix (KISS). Ignored if full_training=False.
-
- # --- Loss shares and the median-normalization window (COEFFICIENTS, not values) ---
- # Final shares in the total loss (sum = 1.0 after normalization).
- loss_ratios = {
-     "lpips": 0.60,
-     "edge": 0.10,
-     "mse": 0.15,
-     "mae": 0.15,
-     # CHANGED: the "kl" key is pre-added (default 0.0); when full_training is enabled it is activated below.
-     "kl": 0.00,
- }
- median_coeff_steps = 256  # window (in steps) over which median coefficients are computed
-
- # --------------------------- Preprocessing parameters ---------------------------
- resize_long_side = 1280  # if None or 0, no resizing; 1280 recommended
-
- Path(generated_folder).mkdir(parents=True, exist_ok=True)
-
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # reproducibility
- seed = int(datetime.now().strftime("%Y%m%d"))
- torch.manual_seed(seed)
- np.random.seed(seed)
- random.seed(seed)
-
- torch.backends.cudnn.benchmark = False
-
- # --------------------------- WandB ---------------------------
- if use_wandb and accelerator.is_main_process:
-     wandb.init(project=project, config={
-         "batch_size": batch_size,
-         "base_learning_rate": base_learning_rate,
-         "num_epochs": num_epochs,
-         "optimizer_type": optimizer_type,
-         "model_resolution": model_resolution,
-         "high_resolution": high_resolution,
-         "gradient_accumulation_steps": gradient_accumulation_steps,
-         "full_training": full_training,  # CHANGED: log the mode
-         "kl_ratio": kl_ratio,  # CHANGED: log the KL share
-     })
-
- # --------------------------- VAE ---------------------------
- if model_resolution == high_resolution and not asymmetric:
-     vae = AutoencoderKL.from_pretrained(project).to(dtype)
- else:
-     vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
-
- # torch.compile (when available), kept simple with no extra logic
- if hasattr(torch, "compile"):
-     try:
-         vae = torch.compile(vae)
-     except Exception as e:
-         print(f"[WARN] torch.compile failed: {e}")
-
- # >>> Freeze/unfreeze strategy
- for p in vae.parameters():
-     p.requires_grad = False
-
- decoder = getattr(vae, "decoder", None)
- if decoder is None:
-     raise RuntimeError("vae.decoder not found; cannot apply the unfreezing strategy. Check the model structure.")
-
- unfrozen_param_names = []
-
- if not full_training:
-     # === Previous behavior: train only decoder.up_blocks and decoder.mid_block ===
-     if not hasattr(decoder, "up_blocks"):
-         raise RuntimeError("decoder.up_blocks not found; a list of decoder blocks is expected.")
-
-     n_up = len(decoder.up_blocks)
-     start_idx = 0
-     for idx in range(start_idx, n_up):
-         block = decoder.up_blocks[idx]
-         for name, p in block.named_parameters():
-             p.requires_grad = True
-             unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
-
-     if hasattr(decoder, "mid_block"):
-         for name, p in decoder.mid_block.named_parameters():
-             p.requires_grad = True
-             unfrozen_param_names.append(f"decoder.mid_block.{name}")
-     else:
-         print("[WARN] decoder.mid_block not found; mid_block stays frozen.")
-
-     # Train only the decoder
-     trainable_module = vae.decoder
- else:
-     # === CHANGED: full training: unfreeze ALL VAE layers (encoder, decoder, and post-projection) ===
-     for name, p in vae.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(name)
-     trainable_module = vae  # CHANGED: train the whole model
-
-     # CHANGED: activate the KL share in the normalizer
-     loss_ratios["kl"] = float(kl_ratio)
-
- print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
- for nm in unfrozen_param_names[:200]:
-     print(" ", nm)
-
- # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
- class PngFolderDataset(Dataset):
-     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
-         self.root_dir = root_dir
-         self.resolution = resolution
-         self.paths = []
-         # collect png files recursively
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
-                     self.paths.append(os.path.join(root, fname))
-         # optional limit
-         if limit:
-             self.paths = self.paths[:limit]
-         # verify images and keep only valid ones
-         valid = []
-         for p in self.paths:
-             try:
-                 with Image.open(p) as im:
-                     im.verify()  # fast check for truncated/corrupted images
-                 valid.append(p)
-             except (OSError, UnidentifiedImageError):
-                 # skip corrupted image
-                 continue
-         self.paths = valid
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid PNG images found under {root_dir}")
-         # final shuffle for randomness
-         random.shuffle(self.paths)
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         p = self.paths[idx % len(self.paths)]
-         # open and convert to RGB; ensure file is closed promptly
-         with Image.open(p) as img:
-             img = img.convert("RGB")
-         # shrink the long side down to resize_long_side (Lanczos)
-         if not resize_long_side or resize_long_side <= 0:
-             return img
-         w, h = img.size
-         long = max(w, h)
-         if long <= resize_long_side:
-             return img
-         scale = resize_long_side / float(long)
-         new_w = int(round(w * scale))
-         new_h = int(round(h * scale))
-         return img.resize((new_w, new_h), Image.LANCZOS)
-
- # --------------------------- Dataset and transforms ---------------------------
-
- def random_crop(img, sz):
-     w, h = img.size
-     if w < sz or h < sz:
-         img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
-     x = random.randint(0, max(1, img.width - sz))
-     y = random.randint(0, max(1, img.height - sz))
-     return img.crop((x, y, x + sz, y + sz))
-
- tfm = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- # build dataset using high_resolution crops
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
- if len(dataset) < batch_size:
-     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
-
- # collate_fn crops down to high_resolution
- def collate_fn(batch):
-     imgs = []
-     for img in batch:  # img is PIL.Image
-         img = random_crop(img, high_resolution)  # crop at high resolution
-         imgs.append(tfm(img))
-     return torch.stack(imgs)
-
- dataloader = DataLoader(
-     dataset,
-     batch_size=batch_size,
-     shuffle=True,
-     collate_fn=collate_fn,
-     num_workers=num_workers,
-     pin_memory=True,
-     drop_last=True
- )
-
- # --------------------------- Optimizer ---------------------------
-
- def get_param_groups(module, weight_decay=0.001):
-     no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
-     decay_params = []
-     no_decay_params = []
-     for n, p in module.named_parameters():
-         if not p.requires_grad:
-             continue
-         if any(nd in n for nd in no_decay):
-             no_decay_params.append(p)
-         else:
-             decay_params.append(p)
-     return [
-         {"params": decay_params, "weight_decay": weight_decay},
-         {"params": no_decay_params, "weight_decay": 0.0},
-     ]
-
- def create_optimizer(name, param_groups):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
-         )
-     raise ValueError(name)
-
- param_groups = get_param_groups(trainable_module, weight_decay=0.001)
- optimizer = create_optimizer(optimizer_type, param_groups)
-
- # --------------------------- LR schedule ---------------------------
-
- batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
- steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
- total_steps = steps_per_epoch * num_epochs
-
- def lr_lambda(step):
-     if not use_decay:
-         return 1.0
-     x = float(step) / float(max(1, total_steps))
-     warmup = float(warmup_percent)
-     min_ratio = float(min_learning_rate) / float(base_learning_rate)
-     if x < warmup:
-         return min_ratio + (1.0 - min_ratio) * (x / warmup)
-     decay_ratio = (x - warmup) / (1.0 - warmup)
-     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
-
- scheduler = LambdaLR(optimizer, lr_lambda)
-
- # Preparation
- dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
-
- # CHANGED: build trainable_params from the selected trainable_module
- trainable_params = [p for p in (trainable_module.parameters() if hasattr(trainable_module, "parameters") else []) if p.requires_grad]
-
- # --------------------------- LPIPS and helper functions ---------------------------
- _lpips_net = None
-
- def _get_lpips():
-     global _lpips_net
-     if _lpips_net is None:
-         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
-     return _lpips_net
-
- # Sobel kernels for the edge loss
- _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
- _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
-
- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
-     # x: [B,C,H,W] in [-1,1]
-     C = x.shape[1]
-     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     gx = F.conv2d(x, kx, padding=1, groups=C)
-     gy = F.conv2d(x, ky, padding=1, groups=C)
-     return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
- # Median-based loss normalization: we compute COEFFICIENTS
- class MedianLossNormalizer:
-     def __init__(self, desired_ratios: dict, window_steps: int):
-         # normalize the shares in case they do not sum to 1
-         s = sum(desired_ratios.values())
-         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
-         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
-         self.window = window_steps
-
-     def update_and_total(self, abs_losses: dict):
-         # Fill the buffers with the actual ABSOLUTE loss values
-         for k, v in abs_losses.items():
-             if k in self.buffers:
-                 self.buffers[k].append(float(v.detach().abs().cpu()))
-         # Medians (robust to outliers)
-         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
-         # Compute COEFFICIENTS as ratio_k / median_k, i.e. genuine coefficients, not values
-         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
-         # The final total sums over the keys present in abs_losses
-         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
-         return total, coeffs, meds
-
- # CHANGED: create the normalizer AFTER kl_ratio may have been activated above
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
-
- # --------------------------- Samples ---------------------------
- @torch.no_grad()
- def get_fixed_samples(n=3):
-     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
-     pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
-     tensors = []
-     for img in pil_imgs:
-         img = random_crop(img, high_resolution)  # high-res fixed samples
-         tensors.append(tfm(img))
-     return torch.stack(tensors).to(accelerator.device, dtype)
-
- fixed_samples = get_fixed_samples()
-
- @torch.no_grad()
- def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
-     # img_tensor: [C,H,W] in [-1,1]
-     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
-     return Image.fromarray(arr)
-
- @torch.no_grad()
- def generate_and_save_samples(step=None):
-     try:
-         temp_vae = accelerator.unwrap_model(vae).eval()
-         lpips_net = _get_lpips()
-         with torch.no_grad():
-             # ALWAYS prepare the low-res encoder input at model_resolution
-             orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
-             orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             # match the model dtype
-             model_dtype = next(temp_vae.parameters()).dtype
-             orig_low = orig_low.to(dtype=model_dtype)
-             # encode/decode
-             # CHANGED: for validation/samples, always use the mean (stable and deterministic)
-             enc = temp_vae.encode(orig_low)
-             latents_mean = enc.latent_dist.mean
-             rec = temp_vae.decode(latents_mean).sample
-
-             # Bring the reconstruction's spatial size to high-res (downsample for asymmetric VAEs)
-             if rec.shape[-2:] != orig_high.shape[-2:]:
-                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
-
-             # Save the FIRST sample: real and decoded, with no step number in the filename
-             first_real = _to_pil_uint8(orig_high[0])
-             first_dec = _to_pil_uint8(rec[0])
-             first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
-             first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
-
-             # Also save the current reconstructions without a step number (files are overwritten instead of piling up)
-             for i in range(rec.shape[0]):
-                 _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
-
-             # LPIPS on the full (high-res) image, for logging
-             lpips_scores = []
-             for i in range(rec.shape[0]):
-                 orig_full = orig_high[i:i+1].to(torch.float32)
-                 rec_full = rec[i:i+1].to(torch.float32)
-                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
-                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
-                 lpips_val = lpips_net(orig_full, rec_full).item()
-                 lpips_scores.append(lpips_val)
-             avg_lpips = float(np.mean(lpips_scores))
-
-             if use_wandb and accelerator.is_main_process:
-                 wandb.log({
-                     "lpips_mean": avg_lpips,
-                 }, step=step)
-     finally:
-         gc.collect()
-         torch.cuda.empty_cache()
-
- if accelerator.is_main_process and save_model:
-     print("Generating samples before training starts...")
-     generate_and_save_samples(0)
-
- accelerator.wait_for_everyone()
-
- # --------------------------- Training ---------------------------
-
- progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
- global_step = 0
- min_loss = float("inf")
- sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
-
- for epoch in range(num_epochs):
-     vae.train()
-     batch_losses = []
-     batch_grads = []
-     # Extra tracking for the individual losses
-     track_losses = {k: [] for k in loss_ratios.keys()}
-     for imgs in dataloader:
-         with accelerator.accumulate(vae):
-             # imgs: high-res tensor from dataloader ([-1,1]), move to device
-             imgs = imgs.to(accelerator.device)
-
-             # ALWAYS downsample the encoder input to model_resolution
-             if high_resolution != model_resolution:
-                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             else:
-                 imgs_low = imgs
-
-             # ensure dtype matches model params to avoid float/half mismatch
-             model_dtype = next(vae.parameters()).dtype
-             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
-
-             # Encode/decode
-             enc = vae.encode(imgs_low_model)
-
-             # CHANGED: when training the whole model, use the reparameterization sample();
-             # this matters for stochasticity and consistency with the KL term.
-             latents = enc.latent_dist.sample() if full_training else enc.latent_dist.mean
-
-             rec = vae.decode(latents).sample  # rec may be upscaled (asymmetric VAE)
-
-             # Bring the size to high-res
-             if rec.shape[-2:] != imgs.shape[-2:]:
-                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
-
-             # Compute the losses at high-res
-             rec_f32 = rec.to(torch.float32)
-             imgs_f32 = imgs.to(torch.float32)
-
-             # Individual losses (absolute values)
-             abs_losses = {
-                 "mae": F.l1_loss(rec_f32, imgs_f32),
-                 "mse": F.mse_loss(rec_f32, imgs_f32),
-                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
-                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
-             }
-
-             # CHANGED: the KL loss is added ONLY during full training.
-             # KL(q(z|x) || N(0,1)) = -0.5 * sum(1 + logσ^2 - μ^2 - σ^2).
-             if full_training:
-                 mean = enc.latent_dist.mean
-                 logvar = enc.latent_dist.logvar
-                 # stable averaging over batch and spatial dims
-                 # OLD (incorrect):
-                 # kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
-                 # NEW (correct):
-                 kl_per_sample = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=[1, 2, 3])
-                 kl = torch.mean(kl_per_sample)
-                 abs_losses["kl"] = kl
-             else:
-                 # the key exists in ratios, but under partial training its share is 0, so it has no effect
-                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
-
-             # Total with the median COEFFICIENTS
-             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
-
-             if torch.isnan(total_loss) or torch.isinf(total_loss):
-                 print("NaN/Inf loss – stopping")
-                 raise RuntimeError("NaN/Inf loss")
-
-             accelerator.backward(total_loss)
-
-             grad_norm = torch.tensor(0.0, device=accelerator.device)
-             if accelerator.sync_gradients:
-                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
-             optimizer.step()
-             scheduler.step()
-             optimizer.zero_grad(set_to_none=True)
-
-             global_step += 1
-             progress.update(1)
-
-             # --- Logging ---
-             if accelerator.is_main_process:
-                 try:
-                     current_lr = optimizer.param_groups[0]["lr"]
-                 except Exception:
-                     current_lr = scheduler.get_last_lr()[0]
-
-                 batch_losses.append(total_loss.detach().item())
-                 # CHANGED: extract a scalar correctly from the different possible types
-                 if isinstance(grad_norm, torch.Tensor):
-                     batch_grads.append(float(grad_norm.detach().cpu().item()))
-                 else:
-                     batch_grads.append(float(grad_norm))
-
-                 for k, v in abs_losses.items():
-                     track_losses[k].append(float(v.detach().item()))
-
-                 if use_wandb and accelerator.sync_gradients:
-                     log_dict = {
-                         "total_loss": float(total_loss.detach().item()),
-                         "learning_rate": current_lr,
-                         "epoch": epoch,
-                         "grad_norm": batch_grads[-1],
-                         "mode/full_training": int(full_training),  # CHANGED: for clarity in the logs
-                     }
-                     # add the individual losses
-                     for k, v in abs_losses.items():
-                         log_dict[f"loss_{k}"] = float(v.detach().item())
-                     # log the coefficients and medians
-                     for k in coeffs:
-                         log_dict[f"coeff_{k}"] = float(coeffs[k])
-                         log_dict[f"median_{k}"] = float(meds[k])
-                     wandb.log(log_dict, step=global_step)
-
-             # periodic samples and checkpoints
-             if global_step > 0 and global_step % sample_interval == 0:
-                 if accelerator.is_main_process:
-                     generate_and_save_samples(global_step)
-                 accelerator.wait_for_everyone()
-
-                 # Averages over the latest iterations
-                 n_micro = sample_interval * gradient_accumulation_steps
-                 if len(batch_losses) >= n_micro:
-                     avg_loss = float(np.mean(batch_losses[-n_micro:]))
-                 else:
-                     avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
-
-                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
-
-                 if accelerator.is_main_process:
-                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
-                     if save_model and avg_loss < min_loss * save_barrier:
-                         min_loss = avg_loss
-                         accelerator.unwrap_model(vae).save_pretrained(save_as)
-                     if use_wandb:
-                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
-
-     if accelerator.is_main_process:
-         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
-         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
-         if use_wandb:
-             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
-
- # --------------------------- Final save ---------------------------
- if accelerator.is_main_process:
-     print("Training finished – saving final model")
-     if save_model:
-         accelerator.unwrap_model(vae).save_pretrained(save_as)
-
- accelerator.free_memory()
- if torch.distributed.is_initialized():
-     torch.distributed.destroy_process_group()
- print("Done!")
train_sdxl_vae_my.py DELETED
@@ -1,507 +0,0 @@
- # -*- coding: utf-8 -*-
- import os
- import math
- import re
- import torch
- import numpy as np
- import random
- import gc
- from datetime import datetime
- from pathlib import Path
-
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
- from accelerate import Accelerator
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
- import bitsandbytes as bnb
- import wandb
- import lpips  # pip install lpips
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/png"
- project = "asymmetric_vae"
- batch_size = 2
- base_learning_rate = 1e-6
- min_learning_rate = 8e-7
- num_epochs = 8
- sample_interval_share = 10
- use_wandb = True
- save_model = True
- use_decay = True
- asymmetric = True
- optimizer_type = "adam8bit"
- dtype = torch.float32
- # model_resolution: what gets fed into the VAE (the low resolution)
- model_resolution = 512  # formerly `resolution`
- # high_resolution: the actual high-res crop used for metrics and saved samples
- high_resolution = 1024
- limit = 0
- save_barrier = 1.03
- warmup_percent = 0.01
- percentile_clipping = 95
- beta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- mixed_precision = "no"  # or "fp16"/"bf16" where supported
- gradient_accumulation_steps = 8
- generated_folder = "samples"
- save_as = "asymmetric_vae_new"
- perceptual_loss_weight = 0.03  # initial weight value (rewritten at every step)
- num_workers = 0
- device = None  # the accelerator will assign the device
-
- # --- Dynamic LPIPS normalization parameters
- lpips_ratio = 0.9  # share of LPIPS in the loss
-
- min_perceptual_weight = 0.1  # lower bound on the weight
- max_perceptual_weight = 99  # upper bound on the weight (guards against blow-ups)
-
- # --------------------------- Preprocessing parameters ---------------------------
- resize_long_side = 1280  # if None or 0, no resizing; 1024 recommended
-
- Path(generated_folder).mkdir(parents=True, exist_ok=True)
-
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # reproducibility
- seed = int(datetime.now().strftime("%Y%m%d"))
- torch.manual_seed(seed)
- np.random.seed(seed)
- random.seed(seed)
-
- torch.backends.cudnn.benchmark = True
-
- # --------------------------- WandB ---------------------------
- if use_wandb and accelerator.is_main_process:
-     wandb.init(project=project, config={
-         "batch_size": batch_size,
-         "base_learning_rate": base_learning_rate,
-         "num_epochs": num_epochs,
-         "optimizer_type": optimizer_type,
-         "model_resolution": model_resolution,
-         "high_resolution": high_resolution,
-         "gradient_accumulation_steps": gradient_accumulation_steps,
-     })
-
- # --------------------------- VAE ---------------------------
- if model_resolution == high_resolution and not asymmetric:
-     vae = AutoencoderKL.from_pretrained(project).to(dtype)
- else:
-     vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
-
- # >>> CHANGED: freeze all parameters, then unfreeze mid_block + up_blocks[-2:]
- for p in vae.parameters():
-     p.requires_grad = False
-
- decoder = getattr(vae, "decoder", None)
- if decoder is None:
-     raise RuntimeError("vae.decoder not found; cannot apply the unfreezing strategy. Check the model structure.")
-
- unfrozen_param_names = []
-
- if not hasattr(decoder, "up_blocks"):
-     raise RuntimeError("decoder.up_blocks not found; a list of decoder blocks is expected.")
-
- # >>> CHANGED: unfreeze the last 2 up_blocks (as requested) and the mid_block
- n_up = len(decoder.up_blocks)
- start_idx = 0  # max(0, n_up - 2)  # all
- for idx in range(start_idx, n_up):
-     block = decoder.up_blocks[idx]
-     for name, p in block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
-
- if hasattr(decoder, "mid_block"):
-     for name, p in decoder.mid_block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.mid_block.{name}")
- else:
-     print("[WARN] decoder.mid_block not found; mid_block stays frozen.")
-
- print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
- for nm in unfrozen_param_names[:200]:
-     print(" ", nm)
-
- # keep trainable_module (get_param_groups will respect p.requires_grad)
- trainable_module = vae.decoder
-
- # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
- class PngFolderDataset(Dataset):
-     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
-         # >>> CHANGED: default resolution argument is high-resolution (1024)
-         self.root_dir = root_dir
-         self.resolution = resolution
-         self.paths = []
-         # collect png files recursively
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
-                     self.paths.append(os.path.join(root, fname))
-         # optional limit
-         if limit:
-             self.paths = self.paths[:limit]
-         # verify images and keep only valid ones
-         valid = []
-         for p in self.paths:
-             try:
-                 with Image.open(p) as im:
-                     im.verify()  # fast check for truncated/corrupted images
-                 valid.append(p)
-             except (OSError, UnidentifiedImageError):
-                 # skip corrupted image
-                 continue
-         self.paths = valid
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid PNG images found under {root_dir}")
-         # final shuffle for randomness
-         random.shuffle(self.paths)
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         p = self.paths[idx % len(self.paths)]
-         # open and convert to RGB; ensure file is closed promptly
-         with Image.open(p) as img:
-             img = img.convert("RGB")
-         # return PIL image (collate will transform)
-         if not resize_long_side or resize_long_side <= 0:
-             return img
-         w, h = img.size
-         long = max(w, h)
-         if long <= resize_long_side:
-             return img
-         scale = resize_long_side / float(long)
-         new_w = int(round(w * scale))
-         new_h = int(round(h * scale))
-         return img.resize((new_w, new_h), Image.LANCZOS)
-
- # --------------------------- Dataset and transforms ---------------------------
-
- def random_crop(img, sz):
-     w, h = img.size
-     if w < sz or h < sz:
-         img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
-     x = random.randint(0, max(1, img.width - sz))
-     y = random.randint(0, max(1, img.height - sz))
-     return img.crop((x, y, x + sz, y + sz))
-
- tfm = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- # build dataset using high_resolution crops
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)  # >>> CHANGED
- if len(dataset) < batch_size:
-     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
-
- # collate_fn crops down to high_resolution
- def collate_fn(batch):
-     imgs = []
-     for img in batch:  # img is PIL.Image
-         img = random_crop(img, high_resolution)  # >>> CHANGED: crop high-res
-         imgs.append(tfm(img))
-     return torch.stack(imgs)
-
- dataloader = DataLoader(
-     dataset,
-     batch_size=batch_size,
-     shuffle=True,
-     collate_fn=collate_fn,
-     num_workers=num_workers,
-     pin_memory=True,
-     drop_last=True
- )
-
- # --------------------------- Optimizer ---------------------------
- def get_param_groups(module, weight_decay=0.001):
-     no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
-     decay_params = []
-     no_decay_params = []
-     for n, p in module.named_parameters():
-         if not p.requires_grad:
-             continue
-         if any(nd in n for nd in no_decay):
-             no_decay_params.append(p)
-         else:
-             decay_params.append(p)
-     return [
-         {"params": decay_params, "weight_decay": weight_decay},
-         {"params": no_decay_params, "weight_decay": 0.0},
-     ]
-
- def create_optimizer(name, param_groups):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
-         )
-     raise ValueError(name)
-
- param_groups = get_param_groups(trainable_module, weight_decay=0.001)
- optimizer = create_optimizer(optimizer_type, param_groups)
-
- # --------------------------- Accelerate preparation (all together) ---------------------------
- batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
- steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
- total_steps = steps_per_epoch * num_epochs
-
- def lr_lambda(step):
-     if not use_decay:
-         return 1.0
-     x = float(step) / float(max(1, total_steps))
-     warmup = float(warmup_percent)
-     min_ratio = float(min_learning_rate) / float(base_learning_rate)
-     if x < warmup:
-         return min_ratio + (1.0 - min_ratio) * (x / warmup)
-     decay_ratio = (x - warmup) / (1.0 - warmup)
-     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
-
- scheduler = LambdaLR(optimizer, lr_lambda)
-
- # Preparation
- dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
-
- trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
-
- # --------------------------- Samples and LPIPS helper ---------------------------
- @torch.no_grad()
- def get_fixed_samples(n=3):
-     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
-     pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
-     tensors = []
-     for img in pil_imgs:
-         img = random_crop(img, high_resolution)  # >>> CHANGED: high-res fixed samples
-         tensors.append(tfm(img))
-     return torch.stack(tensors).to(accelerator.device, dtype)
-
- fixed_samples = get_fixed_samples()
-
- _lpips_net = None
- def _get_lpips():
-     global _lpips_net
-     if _lpips_net is None:
-         # lpips uses its internal vgg, but we use it as-is.
-         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
295
-
296
- @torch.no_grad()
297
- def generate_and_save_samples(step=None):
298
- try:
299
- temp_vae = accelerator.unwrap_model(vae).eval()
300
- lpips_net = _get_lpips()
301
- with torch.no_grad():
302
- # >>> CHANGED: use high-res fixed_samples, downsample to model_res for encoding
303
- orig_high = fixed_samples # already on device
304
- # make low-res input for model
305
- if model_resolution==high_resolution:
306
- orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
307
- else:
308
- orig_low =orig_high
309
-
310
- # ensure dtype matches model params to avoid dtype mismatch
311
- model_dtype = next(temp_vae.parameters()).dtype
312
- orig_low = orig_low.to(dtype=model_dtype)
313
-
314
- latent_dist = temp_vae.encode(orig_low).latent_dist
315
- latents = latent_dist.mean
316
- rec = temp_vae.decode(latents).sample # expected to be upscaled to high_res
317
-
318
- # make sure rec is float32 in range [0,1] for saving
319
- # if rec spatial size differs from orig_high, resize rec to orig_high
320
- if rec.shape[-2:] != orig_high.shape[-2:]:
321
- rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
322
-
323
- rec_img = ((rec.float() / 2.0 + 0.5).clamp(0, 1) * 255).cpu().numpy()
324
- for i in range(rec_img.shape[0]):
325
- arr = rec_img[i].transpose(1, 2, 0).astype(np.uint8)
326
- Image.fromarray(arr).save(f"{generated_folder}/sample_{step if step is not None else 'init'}_{i}.jpg", quality=95)
327
-
328
- # LPIPS на полном изображении (high-res)
329
- lpips_scores = []
330
- for i in range(rec.shape[0]):
331
- orig_full = orig_high[i:i+1] # [B, C, H, W], in [-1,1]
332
- rec_full = rec[i:i+1]
333
- # ensure same spatial size/dtype
334
- if rec_full.shape[-2:] != orig_full.shape[-2:]:
335
- rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
336
- rec_full = rec_full.to(torch.float32)
337
- orig_full = orig_full.to(torch.float32)
338
- lpips_val = lpips_net(orig_full, rec_full).item()
339
- lpips_scores.append(lpips_val)
340
- avg_lpips = float(np.mean(lpips_scores))
341
- if use_wandb and accelerator.is_main_process:
342
- wandb.log({
343
- "generated_images": [wandb.Image(Image.fromarray(rec_img[i].transpose(1,2,0).astype(np.uint8))) for i in range(rec_img.shape[0])],
344
- "lpips_mean": avg_lpips
345
- }, step=step)
346
- finally:
347
- gc.collect()
348
- torch.cuda.empty_cache()
349
-
350
- if accelerator.is_main_process and save_model:
351
- print("Генерация сэмплов до старта обучения...")
352
- generate_and_save_samples(0)
353
-
354
- accelerator.wait_for_everyone()
355
-
356
- # --------------------------- Тренировка ---------------------------
357
-
358
- progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
359
- global_step = 0
360
- min_loss = float("inf")
361
- sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
362
-
363
- for epoch in range(num_epochs):
364
- vae.train()
365
- batch_losses = []
366
- batch_losses_mae = []
367
- batch_losses_lpips = []
368
- batch_losses_perc = []
369
- batch_grads = []
370
- for imgs in dataloader:
371
- with accelerator.accumulate(vae):
372
- # imgs: high-res tensor from dataloader ([-1,1]), move to device
373
- imgs = imgs.to(accelerator.device)
374
-
375
- # >>> CHANGED: create low-res input for model by downsampling high-res crop
376
- if model_resolution==high_resolution:
377
- imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
378
- else:
379
- imgs_low = imgs
380
-
381
- # ensure dtype matches model params to avoid float/half mismatch
382
- model_dtype = next(vae.parameters()).dtype
383
- if imgs_low.dtype != model_dtype:
384
- imgs_low_model = imgs_low.to(dtype=model_dtype)
385
- else:
386
- imgs_low_model = imgs_low
387
-
388
- # Encode/decode on low-res input
389
- latent_dist = vae.encode(imgs_low_model).latent_dist
390
- latents = latent_dist.mean
391
- rec = vae.decode(latents).sample # rec is expected to be high-res (upscaled)
392
-
393
- # If rec isn't the same spatial size as original high-res input, resize to high-res
394
- if rec.shape[-2:] != imgs.shape[-2:]:
395
- rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
396
-
397
- # Now compute losses **on high-res** (rec vs imgs)
398
- rec_f32 = rec.to(torch.float32)
399
- imgs_f32 = imgs.to(torch.float32)
400
-
401
- # MAE
402
- mae_loss = F.l1_loss(rec_f32, imgs_f32)
403
-
404
- # LPIPS (ensure float32)
405
- lpips_loss = _get_lpips()(rec_f32, imgs_f32).mean()
406
-
407
- # dynamic perceptual weighting (same as before)
408
- if float(mae_loss.detach().cpu().item()) > 1e-12:
409
- desired_multiplier = lpips_ratio / max(1.0 - lpips_ratio, 1e-12)
410
- new_weight = (mae_loss.item() / float(lpips_loss.detach().cpu().item())) * desired_multiplier
411
- else:
412
- new_weight = perceptual_loss_weight
413
-
414
- perceptual_loss_weight = float(np.clip(new_weight, min_perceptual_weight, max_perceptual_weight))
415
- batch_losses_perc.append(perceptual_loss_weight)
416
- if len(batch_losses_perc) >= sample_interval:
417
- avg_perc = float(np.mean(batch_losses_perc[-sample_interval:]))
418
- else:
419
- avg_perc = float(np.mean(batch_losses_perc[-sample_interval:]))
420
-
421
- total_loss = mae_loss + avg_perc * lpips_loss
422
-
423
- if torch.isnan(total_loss) or torch.isinf(total_loss):
424
- print("NaN/Inf loss – stopping")
425
- raise RuntimeError("NaN/Inf loss")
426
-
427
- accelerator.backward(total_loss)
428
-
429
- grad_norm = torch.tensor(0.0, device=accelerator.device)
430
- if accelerator.sync_gradients:
431
- grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
432
- optimizer.step()
433
- scheduler.step()
434
- optimizer.zero_grad(set_to_none=True)
435
-
436
- global_step += 1
437
- progress.update(1)
438
-
439
- # --- Логирование ---
440
- if accelerator.is_main_process:
441
- try:
442
- current_lr = optimizer.param_groups[0]["lr"]
443
- except Exception:
444
- current_lr = scheduler.get_last_lr()[0]
445
-
446
- batch_losses.append(total_loss.detach().item())
447
- batch_losses_mae.append(mae_loss.detach().item())
448
- batch_losses_lpips.append(lpips_loss.detach().item())
449
- batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
450
-
451
- if use_wandb and accelerator.sync_gradients:
452
- wandb.log({
453
- "mae_loss": mae_loss.detach().item(),
454
- "lpips_loss": lpips_loss.detach().item(),
455
- "perceptual_loss_weight": avg_perc,
456
- "total_loss": total_loss.detach().item(),
457
- "learning_rate": current_lr,
458
- "epoch": epoch,
459
- "grad_norm": batch_grads[-1],
460
- }, step=global_step)
461
-
462
- # периодические сэмплы и чекпоинты
463
- if global_step > 0 and global_step % sample_interval == 0:
464
- # делаем генерацию и лог только в main process (генерация использует fixed_samples high-res)
465
- if accelerator.is_main_process:
466
- generate_and_save_samples(global_step)
467
-
468
- accelerator.wait_for_everyone()
469
-
470
- # сколько микро-батчей нужно взять для усреднения
471
- n_micro = sample_interval * gradient_accumulation_steps
472
- # защищаем от выхода за пределы
473
- if len(batch_losses) >= n_micro:
474
- avg_loss = float(np.mean(batch_losses[-n_micro:]))
475
- avg_loss_mae = float(np.mean(batch_losses_mae[-n_micro:]))
476
- avg_loss_lpips = float(np.mean(batch_losses_lpips[-n_micro:]))
477
- else:
478
- avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
479
- avg_loss_mae = float(np.mean(batch_losses_mae)) if batch_losses_mae else float("nan")
480
- avg_loss_lpips = float(np.mean(batch_losses_lpips)) if batch_losses_lpips else float("nan")
481
-
482
- avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
483
-
484
- if accelerator.is_main_process:
485
- print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
486
- if save_model and avg_loss < min_loss * save_barrier:
487
- min_loss = avg_loss
488
- accelerator.unwrap_model(vae).save_pretrained(save_as)
489
- if use_wandb:
490
- wandb.log({"interm_loss": avg_loss,"interm_loss_mae": avg_loss_mae,"interm_loss_lpips": avg_loss_lpips, "interm_grad": avg_grad}, step=global_step)
491
-
492
- if accelerator.is_main_process:
493
- epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
494
- print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
495
- if use_wandb:
496
- wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
497
-
498
- # --------------------------- Финальное сохранение ---------------------------
499
- if accelerator.is_main_process:
500
- print("Training finished – saving final model")
501
- if save_model:
502
- accelerator.unwrap_model(vae).save_pretrained(save_as)
503
-
504
- accelerator.free_memory()
505
- if torch.distributed.is_initialized():
506
- torch.distributed.destroy_process_group()
507
- print("Готово!")
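
For reference, the lr_lambda used by these scripts is a linear warmup into a cosine decay. In closed form, with x = step / total_steps, w = warmup_percent and r = min_learning_rate / base_learning_rate, the multiplier applied to base_learning_rate is:

    multiplier(x) = r + (1 - r) * x / w                                     for x < w
    multiplier(x) = r + 0.5 * (1 - r) * (1 + cos(pi * (x - w) / (1 - w)))   for x >= w

so the rate ramps from min_learning_rate up to base_learning_rate over the first w of training, then decays back to min_learning_rate along a half-cosine.
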
train_sdxl_vae_qwen.py DELETED
@@ -1,526 +0,0 @@
- # -*- coding: utf-8 -*-
- import os
- import math
- import re
- import torch
- import numpy as np
- import random
- import gc
- from datetime import datetime
- from pathlib import Path
-
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
- # QWEN: import the class
- from diffusers import AutoencoderKLQwenImage
-
- from accelerate import Accelerator
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
- import bitsandbytes as bnb
- import wandb
- import lpips  # pip install lpips
- from collections import deque
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/png"
- project = "qwen_vae"
- batch_size = 3
- base_learning_rate = 5e-5
- min_learning_rate = 9e-7
- num_epochs = 16
- sample_interval_share = 10
- use_wandb = True
- save_model = True
- use_decay = True
- optimizer_type = "adam8bit"
- dtype = torch.float32
-
- model_resolution = 512
- high_resolution = 512
- limit = 0
- save_barrier = 1.03
- warmup_percent = 0.01
- percentile_clipping = 95
- beta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- mixed_precision = "no"
- gradient_accumulation_steps = 5
- generated_folder = "samples"
- save_as = "qwen_vae_nightly"
- num_workers = 0
- device = None
-
- # --- Training modes ---
- # QWEN: train only the decoder
- train_decoder_only = True
- full_training = False  # if True, train the whole VAE and add a KL term (below)
- kl_ratio = 0.05
-
- # Loss shares
- loss_ratios = {
-     "lpips": 0.80,
-     "edge": 0.05,
-     "mse": 0.10,
-     "mae": 0.05,
-     "kl": 0.00,  # enabled when full_training=True
- }
- median_coeff_steps = 256
-
- resize_long_side = 1280  # resize the long side of the source images
-
- # QWEN: model loading config
- vae_kind = "qwen"  # "qwen" or "kl" (plain)
- vae_model_id = "Qwen/Qwen-Image"
- vae_subfolder = "vae"
-
- Path(generated_folder).mkdir(parents=True, exist_ok=True)
-
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # reproducibility
- seed = int(datetime.now().strftime("%Y%m%d"))
- torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
- torch.backends.cudnn.benchmark = False
-
- # --------------------------- WandB ---------------------------
- if use_wandb and accelerator.is_main_process:
-     wandb.init(project=project, config={
-         "batch_size": batch_size,
-         "base_learning_rate": base_learning_rate,
-         "num_epochs": num_epochs,
-         "optimizer_type": optimizer_type,
-         "model_resolution": model_resolution,
-         "high_resolution": high_resolution,
-         "gradient_accumulation_steps": gradient_accumulation_steps,
-         "train_decoder_only": train_decoder_only,
-         "full_training": full_training,
-         "kl_ratio": kl_ratio,
-         "vae_kind": vae_kind,
-         "vae_model_id": vae_model_id,
-     })
-
- # --------------------------- VAE ---------------------------
- def is_qwen_vae(vae) -> bool:
-     return isinstance(vae, AutoencoderKLQwenImage) or ("Qwen" in vae.__class__.__name__)
-
- # loading
- if vae_kind == "qwen":
-     vae = AutoencoderKLQwenImage.from_pretrained(vae_model_id, subfolder=vae_subfolder)
- else:
-     # old behaviour (example)
-     if model_resolution == high_resolution:
-         vae = AutoencoderKL.from_pretrained(project)
-     else:
-         vae = AsymmetricAutoencoderKL.from_pretrained(project)
-
- vae = vae.to(dtype)
-
- # torch.compile (optional)
- if hasattr(torch, "compile"):
-     try:
-         vae = torch.compile(vae)
-     except Exception as e:
-         print(f"[WARN] torch.compile failed: {e}")
-
- # --------------------------- Freeze/Unfreeze ---------------------------
- for p in vae.parameters():
-     p.requires_grad = False
-
- unfrozen_param_names = []
-
- if full_training and not train_decoder_only:
-     # train the whole model
-     for name, p in vae.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(name)
-     loss_ratios["kl"] = float(kl_ratio)
-     trainable_module = vae
- else:
-     # QWEN: train only the decoder (and post_quant_conv, which is part of the decoder path)
-     # generic: everything that starts with "decoder." or "post_quant_conv"
-     for name, p in vae.named_parameters():
-         if name.startswith("decoder.") or name.startswith("post_quant_conv"):
-             p.requires_grad = True
-             unfrozen_param_names.append(name)
-     trainable_module = vae.decoder if hasattr(vae, "decoder") else vae
-
- print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
- for nm in unfrozen_param_names[:200]:
-     print(" ", nm)
-
- # --------------------------- Dataset ---------------------------
- class PngFolderDataset(Dataset):
-     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
-         self.root_dir = root_dir
-         self.resolution = resolution
-         self.paths = []
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
-                     self.paths.append(os.path.join(root, fname))
-         if limit:
-             self.paths = self.paths[:limit]
-         valid = []
-         for p in self.paths:
-             try:
-                 with Image.open(p) as im:
-                     im.verify()
-                 valid.append(p)
-             except (OSError, UnidentifiedImageError):
-                 continue
-         self.paths = valid
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid PNG images found under {root_dir}")
-         random.shuffle(self.paths)
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         p = self.paths[idx % len(self.paths)]
-         with Image.open(p) as img:
-             img = img.convert("RGB")
-         if not resize_long_side or resize_long_side <= 0:
-             return img
-         w, h = img.size
-         long = max(w, h)
-         if long <= resize_long_side:
-             return img
-         scale = resize_long_side / float(long)
-         new_w = int(round(w * scale))
-         new_h = int(round(h * scale))
-         return img.resize((new_w, new_h), Image.LANCZOS)
-
- def random_crop(img, sz):
-     w, h = img.size
-     if w < sz or h < sz:
-         img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
-     x = random.randint(0, max(1, img.width - sz))
-     y = random.randint(0, max(1, img.height - sz))
-     return img.crop((x, y, x + sz, y + sz))
-
- tfm = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
- if len(dataset) < batch_size:
-     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
-
- def collate_fn(batch):
-     imgs = []
-     for img in batch:
-         img = random_crop(img, high_resolution)
-         imgs.append(tfm(img))
-     return torch.stack(imgs)
-
- dataloader = DataLoader(
-     dataset,
-     batch_size=batch_size,
-     shuffle=True,
-     collate_fn=collate_fn,
-     num_workers=num_workers,
-     pin_memory=True,
-     drop_last=True
- )
-
- # --------------------------- Optimizer ---------------------------
- def get_param_groups(module, weight_decay=0.001):
-     no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
-     decay_params, no_decay_params = [], []
-     for n, p in vae.named_parameters():  # over the whole vae, filtered by requires_grad
-         if not p.requires_grad:
-             continue
-         if any(nd in n for nd in no_decay):
-             no_decay_params.append(p)
-         else:
-             decay_params.append(p)
-     return [
-         {"params": decay_params, "weight_decay": weight_decay},
-         {"params": no_decay_params, "weight_decay": 0.0},
-     ]
-
- def create_optimizer(name, param_groups):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
-     raise ValueError(name)
-
- param_groups = get_param_groups(trainable_module, weight_decay=0.001)
- optimizer = create_optimizer(optimizer_type, param_groups)
-
- # --------------------------- LR schedule ---------------------------
- batches_per_epoch = len(dataloader)
- steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
- total_steps = steps_per_epoch * num_epochs
-
- def lr_lambda(step):
-     if not use_decay:
-         return 1.0
-     x = float(step) / float(max(1, total_steps))
-     warmup = float(warmup_percent)
-     min_ratio = float(min_learning_rate) / float(base_learning_rate)
-     if x < warmup:
-         return min_ratio + (1.0 - min_ratio) * (x / warmup)
-     decay_ratio = (x - warmup) / (1.0 - warmup)
-     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
-
- scheduler = LambdaLR(optimizer, lr_lambda)
-
- # Prepare
- dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
- trainable_params = [p for p in vae.parameters() if p.requires_grad]
-
- # --------------------------- LPIPS and helpers ---------------------------
- _lpips_net = None
- def _get_lpips():
-     global _lpips_net
-     if _lpips_net is None:
-         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
-     return _lpips_net
-
- _sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
- _sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
-     C = x.shape[1]
-     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     gx = F.conv2d(x, kx, padding=1, groups=C)
-     gy = F.conv2d(x, ky, padding=1, groups=C)
-     return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
- class MedianLossNormalizer:
-     def __init__(self, desired_ratios: dict, window_steps: int):
-         s = sum(desired_ratios.values())
-         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
-         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
-         self.window = window_steps
-
-     def update_and_total(self, abs_losses: dict):
-         for k, v in abs_losses.items():
-             if k in self.buffers:
-                 self.buffers[k].append(float(v.detach().abs().cpu()))
-         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
-         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
-         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
-         return total, coeffs, meds
-
- if full_training and not train_decoder_only:
-     loss_ratios["kl"] = float(kl_ratio)
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
-
- # --------------------------- Samples ---------------------------
- @torch.no_grad()
- def get_fixed_samples(n=3):
-     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
-     pil_imgs = [dataset[i] for i in idx]
-     tensors = []
-     for img in pil_imgs:
-         img = random_crop(img, high_resolution)
-         tensors.append(tfm(img))
-     return torch.stack(tensors).to(accelerator.device, dtype)
-
- fixed_samples = get_fixed_samples()
-
- @torch.no_grad()
- def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
-     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
-     return Image.fromarray(arr)
-
- @torch.no_grad()
- def generate_and_save_samples(step=None):
-     try:
-         temp_vae = accelerator.unwrap_model(vae).eval()
-         lpips_net = _get_lpips()
-         with torch.no_grad():
-             orig_high = fixed_samples
-             orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             model_dtype = next(temp_vae.parameters()).dtype
-             orig_low = orig_low.to(dtype=model_dtype)
-
-             # QWEN: add T=1 for encode/decode, strip it again before comparison
-             if is_qwen_vae(temp_vae):
-                 x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
-                 enc = temp_vae.encode(x_in)
-                 latents_mean = enc.latent_dist.mean
-                 dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
-                 rec = dec.squeeze(2)  # [B,3,H,W]
-             else:
-                 enc = temp_vae.encode(orig_low)
-                 latents_mean = enc.latent_dist.mean
-                 rec = temp_vae.decode(latents_mean).sample
-
-             if rec.shape[-2:] != orig_high.shape[-2:]:
-                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
-
-             first_real = _to_pil_uint8(orig_high[0])
-             first_dec = _to_pil_uint8(rec[0])
-             first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
-             first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
-
-             for i in range(rec.shape[0]):
-                 _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
-
-             lpips_scores = []
-             for i in range(rec.shape[0]):
-                 orig_full = orig_high[i:i+1].to(torch.float32)
-                 rec_full = rec[i:i+1].to(torch.float32)
-                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
-                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
-                 lpips_val = lpips_net(orig_full, rec_full).item()
-                 lpips_scores.append(lpips_val)
-             avg_lpips = float(np.mean(lpips_scores))
-
-             if use_wandb and accelerator.is_main_process:
-                 wandb.log({"lpips_mean": avg_lpips}, step=step)
-     finally:
-         gc.collect()
-         torch.cuda.empty_cache()
-
- if accelerator.is_main_process and save_model:
-     print("Generating samples before training starts...")
-     generate_and_save_samples(0)
-
- accelerator.wait_for_everyone()
-
- # --------------------------- Training ---------------------------
- progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
- global_step = 0
- min_loss = float("inf")
- sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
-
- for epoch in range(num_epochs):
-     vae.train()
-     batch_losses, batch_grads = [], []
-     track_losses = {k: [] for k in loss_ratios.keys()}
-
-     for imgs in dataloader:
-         with accelerator.accumulate(vae):
-             imgs = imgs.to(accelerator.device)
-
-             if high_resolution != model_resolution:
-                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             else:
-                 imgs_low = imgs
-
-             model_dtype = next(vae.parameters()).dtype
-             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
-
-             # QWEN: encode/decode with T=1
-             if is_qwen_vae(vae):
-                 x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
-                 enc = vae.encode(x_in)
-                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
-                 dec = vae.decode(latents).sample  # [B,3,1,H,W]
-                 rec = dec.squeeze(2)  # [B,3,H,W]
-             else:
-                 enc = vae.encode(imgs_low_model)
-                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
-                 rec = vae.decode(latents).sample
-
-             if rec.shape[-2:] != imgs.shape[-2:]:
-                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
-
-             rec_f32 = rec.to(torch.float32)
-             imgs_f32 = imgs.to(torch.float32)
-
-             abs_losses = {
-                 "mae": F.l1_loss(rec_f32, imgs_f32),
-                 "mse": F.mse_loss(rec_f32, imgs_f32),
-                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
-                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
-             }
-
-             if full_training and not train_decoder_only:
-                 mean = enc.latent_dist.mean
-                 logvar = enc.latent_dist.logvar
-                 kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
-                 abs_losses["kl"] = kl
-             else:
-                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
-
-             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
-
-             if torch.isnan(total_loss) or torch.isinf(total_loss):
-                 raise RuntimeError("NaN/Inf loss")
-
-             accelerator.backward(total_loss)
-
-             grad_norm = torch.tensor(0.0, device=accelerator.device)
-             if accelerator.sync_gradients:
-                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
-                 optimizer.step()
-                 scheduler.step()
-                 optimizer.zero_grad(set_to_none=True)
-             global_step += 1
-             progress.update(1)
-
-             if accelerator.is_main_process:
-                 try:
-                     current_lr = optimizer.param_groups[0]["lr"]
-                 except Exception:
-                     current_lr = scheduler.get_last_lr()[0]
-
-                 batch_losses.append(total_loss.detach().item())
-                 batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
-                 for k, v in abs_losses.items():
-                     track_losses[k].append(float(v.detach().item()))
-
-                 if use_wandb and accelerator.sync_gradients:
-                     log_dict = {
-                         "total_loss": float(total_loss.detach().item()),
-                         "learning_rate": current_lr,
-                         "epoch": epoch,
-                         "grad_norm": batch_grads[-1],
-                         "mode/train_decoder_only": int(train_decoder_only),
-                         "mode/full_training": int(full_training),
-                     }
-                     for k, v in abs_losses.items():
-                         log_dict[f"loss_{k}"] = float(v.detach().item())
-                     for k in coeffs:
-                         log_dict[f"coeff_{k}"] = float(coeffs[k])
-                         log_dict[f"median_{k}"] = float(meds[k])
-                     wandb.log(log_dict, step=global_step)
-
-             if global_step > 0 and global_step % sample_interval == 0:
-                 if accelerator.is_main_process:
-                     generate_and_save_samples(global_step)
-                 accelerator.wait_for_everyone()
-
-                 n_micro = sample_interval * gradient_accumulation_steps
-                 # the slice handles both the warm-up and steady-state cases
-                 avg_loss = float(np.mean(batch_losses[-n_micro:])) if batch_losses else float("nan")
-                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
-
-                 if accelerator.is_main_process:
-                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
-                     if save_model and avg_loss < min_loss * save_barrier:
-                         min_loss = avg_loss
-                         accelerator.unwrap_model(vae).save_pretrained(save_as)
-                     if use_wandb:
-                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
-
-     if accelerator.is_main_process:
-         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
-         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
-         if use_wandb:
-             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
-
- # --------------------------- Final save ---------------------------
- if accelerator.is_main_process:
-     print("Training finished – saving final model")
-     if save_model:
-         accelerator.unwrap_model(vae).save_pretrained(save_as)
-
- accelerator.free_memory()
- if torch.distributed.is_initialized():
-     torch.distributed.destroy_process_group()
- print("Done!")
train_sdxl_vae_simple.py DELETED
@@ -1,547 +0,0 @@
- # -*- coding: utf-8 -*-
- import os
- import math
- import re
- import torch
- import numpy as np
- import random
- import gc
- from datetime import datetime
- from pathlib import Path
-
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
- from accelerate import Accelerator
- from PIL import Image, UnidentifiedImageError
- from tqdm import tqdm
- import bitsandbytes as bnb
- import wandb
- import lpips  # pip install lpips
- from collections import deque
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/png"
- project = "simple_vae"
- batch_size = 3
- base_learning_rate = 5e-5
- min_learning_rate = 9e-7
- num_epochs = 16
- sample_interval_share = 10
- use_wandb = True
- save_model = True
- use_decay = True
- asymmetric = False
- optimizer_type = "adam8bit"
- dtype = torch.float32
- # model_resolution: what is fed to the VAE (the low resolution)
- model_resolution = 512  # formerly `resolution`
- # high_resolution: the actual "high-res" crop on which metrics are computed and samples are saved
- high_resolution = 512
- limit = 0
- save_barrier = 1.03
- warmup_percent = 0.01
- percentile_clipping = 95
- beta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- mixed_precision = "no"  # or "fp16"/"bf16" where supported
- gradient_accumulation_steps = 5
- generated_folder = "samples"
- save_as = "simple_vae_nightly"
- num_workers = 0
- device = None  # the accelerator will set the device
-
- # --- Loss shares and the median-normalization window (COEFFICIENTS, not values) ---
- # Final shares in the total loss (sum = 1.0)
- loss_ratios = {
-     "lpips": 0.85,
-     "edge": 0.05,
-     "mse": 0.05,
-     "mae": 0.05,
- }
- median_coeff_steps = 256  # window (in steps) over which the median coefficients are computed
-
- # --------------------------- Preprocessing parameters ---------------------------
- resize_long_side = 1280  # if None or 0, no resize; 1280 recommended
-
- Path(generated_folder).mkdir(parents=True, exist_ok=True)
-
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # reproducibility
- seed = int(datetime.now().strftime("%Y%m%d"))
- torch.manual_seed(seed)
- np.random.seed(seed)
- random.seed(seed)
-
- torch.backends.cudnn.benchmark = True
-
- # --------------------------- WandB ---------------------------
- if use_wandb and accelerator.is_main_process:
-     wandb.init(project=project, config={
-         "batch_size": batch_size,
-         "base_learning_rate": base_learning_rate,
-         "num_epochs": num_epochs,
-         "optimizer_type": optimizer_type,
-         "model_resolution": model_resolution,
-         "high_resolution": high_resolution,
-         "gradient_accumulation_steps": gradient_accumulation_steps,
-     })
-
- # --------------------------- VAE ---------------------------
- if model_resolution == high_resolution and not asymmetric:
-     vae = AutoencoderKL.from_pretrained(project).to(dtype)
- else:
-     vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
-
- # torch.compile (if available): simple, no extra logic
- if hasattr(torch, "compile"):
-     try:
-         vae = torch.compile(vae)
-     except Exception as e:
-         print(f"[WARN] torch.compile failed: {e}")
-
- # >>> Freeze all parameters, then selectively unfreeze
- for p in vae.parameters():
-     p.requires_grad = False
-
- decoder = getattr(vae, "decoder", None)
- if decoder is None:
-     raise RuntimeError("vae.decoder not found - cannot apply the unfreeze strategy. Check the model structure.")
-
- unfrozen_param_names = []
-
- if not hasattr(decoder, "up_blocks"):
-     raise RuntimeError("decoder.up_blocks not found - expected a list of decoder blocks.")
-
- # >>> Unfreeze all up_blocks and the mid_block (the start_idx=0 variant)
- n_up = len(decoder.up_blocks)
- start_idx = 0
- for idx in range(start_idx, n_up):
-     block = decoder.up_blocks[idx]
-     for name, p in block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
-
- if hasattr(decoder, "mid_block"):
-     for name, p in decoder.mid_block.named_parameters():
-         p.requires_grad = True
-         unfrozen_param_names.append(f"decoder.mid_block.{name}")
- else:
-     print("[WARN] decoder.mid_block not found - mid_block left frozen.")
-
- print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
- for nm in unfrozen_param_names[:200]:
-     print(" ", nm)
-
- # keep the trainable_module (get_param_groups honors p.requires_grad)
- trainable_module = vae.decoder
-
- # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
- class PngFolderDataset(Dataset):
-     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
-         self.root_dir = root_dir
-         self.resolution = resolution
-         self.paths = []
-         # collect png files recursively
-         for root, _, files in os.walk(root_dir):
-             for fname in files:
-                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
-                     self.paths.append(os.path.join(root, fname))
-         # optional limit
-         if limit:
-             self.paths = self.paths[:limit]
-         # verify images and keep only valid ones
-         valid = []
-         for p in self.paths:
-             try:
-                 with Image.open(p) as im:
-                     im.verify()  # fast check for truncated/corrupted images
-                 valid.append(p)
-             except (OSError, UnidentifiedImageError):
-                 # skip corrupted image
-                 continue
-         self.paths = valid
-         if len(self.paths) == 0:
-             raise RuntimeError(f"No valid PNG images found under {root_dir}")
-         # final shuffle for randomness
-         random.shuffle(self.paths)
-
-     def __len__(self):
-         return len(self.paths)
-
-     def __getitem__(self, idx):
-         p = self.paths[idx % len(self.paths)]
-         # open and convert to RGB; ensure file is closed promptly
-         with Image.open(p) as img:
-             img = img.convert("RGB")
-         # shrink the long side down to resize_long_side (Lanczos)
-         if not resize_long_side or resize_long_side <= 0:
-             return img
-         w, h = img.size
-         long = max(w, h)
-         if long <= resize_long_side:
-             return img
-         scale = resize_long_side / float(long)
-         new_w = int(round(w * scale))
-         new_h = int(round(h * scale))
-         return img.resize((new_w, new_h), Image.LANCZOS)
-
- # --------------------------- Dataset and transforms ---------------------------
-
- def random_crop(img, sz):
-     w, h = img.size
-     if w < sz or h < sz:
-         img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
-     x = random.randint(0, max(1, img.width - sz))
-     y = random.randint(0, max(1, img.height - sz))
-     return img.crop((x, y, x + sz, y + sz))
-
- tfm = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- # build dataset using high_resolution crops
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
- if len(dataset) < batch_size:
-     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
-
- # collate_fn crops to high_resolution
- def collate_fn(batch):
-     imgs = []
-     for img in batch:  # img is PIL.Image
-         img = random_crop(img, high_resolution)  # crop high-res
-         imgs.append(tfm(img))
-     return torch.stack(imgs)
-
- dataloader = DataLoader(
-     dataset,
-     batch_size=batch_size,
-     shuffle=True,
-     collate_fn=collate_fn,
-     num_workers=num_workers,
-     pin_memory=True,
-     drop_last=True
- )
-
- # --------------------------- Optimizer ---------------------------
- def get_param_groups(module, weight_decay=0.001):
-     no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
-     decay_params = []
-     no_decay_params = []
-     for n, p in module.named_parameters():
-         if not p.requires_grad:
-             continue
-         if any(nd in n for nd in no_decay):
-             no_decay_params.append(p)
-         else:
-             decay_params.append(p)
-     return [
-         {"params": decay_params, "weight_decay": weight_decay},
-         {"params": no_decay_params, "weight_decay": 0.0},
-     ]
-
- def create_optimizer(name, param_groups):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
-         )
-     raise ValueError(name)
-
- param_groups = get_param_groups(trainable_module, weight_decay=0.001)
- optimizer = create_optimizer(optimizer_type, param_groups)
-
- # --------------------------- Accelerate preparation ---------------------------
- batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
- steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
- total_steps = steps_per_epoch * num_epochs
-
- def lr_lambda(step):
-     if not use_decay:
-         return 1.0
-     x = float(step) / float(max(1, total_steps))
-     warmup = float(warmup_percent)
-     min_ratio = float(min_learning_rate) / float(base_learning_rate)
-     if x < warmup:
-         return min_ratio + (1.0 - min_ratio) * (x / warmup)
-     decay_ratio = (x - warmup) / (1.0 - warmup)
-     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
-
- scheduler = LambdaLR(optimizer, lr_lambda)
-
- # Prepare
- dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
-
- trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
-
- # --------------------------- LPIPS and helper functions ---------------------------
- _lpips_net = None
-
- def _get_lpips():
-     global _lpips_net
-     if _lpips_net is None:
-         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
-     return _lpips_net
-
- # Sobel kernels for the edge loss
- _sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
- _sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
-
- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
-     # x: [B,C,H,W] in [-1,1]
-     C = x.shape[1]
-     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
-     gx = F.conv2d(x, kx, padding=1, groups=C)
-     gy = F.conv2d(x, ky, padding=1, groups=C)
-     return torch.sqrt(gx * gx + gy * gy + 1e-12)
-
- # Median loss normalization: we compute COEFFICIENTS
- class MedianLossNormalizer:
-     def __init__(self, desired_ratios: dict, window_steps: int):
-         # normalize the shares in case they don't sum to 1
-         s = sum(desired_ratios.values())
-         self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
-         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
-         self.window = window_steps
-
-     def update_and_total(self, abs_losses: dict):
-         # fill the buffers with the actual ABSOLUTE loss values
-         for k, v in abs_losses.items():
-             if k in self.buffers:
-                 self.buffers[k].append(float(v.detach().cpu()))
-         # medians (robust to outliers)
-         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
-         # compute the COEFFICIENTS as ratio_k / median_k, i.e. coefficients rather than values
-         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
-         # key point: with these coefficients, sum(coeff_k * median_k) = sum(ratio_k) = 1, so the scale stays stable
-         total = sum(coeffs[k] * abs_losses[k] for k in coeffs)
-         return total, coeffs, meds
-
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
-
- # --------------------------- Samples ---------------------------
- @torch.no_grad()
- def get_fixed_samples(n=3):
-     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
-     pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
-     tensors = []
-     for img in pil_imgs:
-         img = random_crop(img, high_resolution)  # high-res fixed samples
-         tensors.append(tfm(img))
-     return torch.stack(tensors).to(accelerator.device, dtype)
-
- fixed_samples = get_fixed_samples()
-
- @torch.no_grad()
- def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
-     # img_tensor: [C,H,W] in [-1,1]
-     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
-     return Image.fromarray(arr)
-
- @torch.no_grad()
- def generate_and_save_samples(step=None):
-     try:
-         temp_vae = accelerator.unwrap_model(vae).eval()
-         lpips_net = _get_lpips()
-         with torch.no_grad():
-             # ALWAYS prepare the low-res encoder input at model_resolution
-             orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
-             orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             # match the model dtype
-             model_dtype = next(temp_vae.parameters()).dtype
-             orig_low = orig_low.to(dtype=model_dtype)
-             # encode/decode
-             latents = temp_vae.encode(orig_low).latent_dist.mean
-             rec = temp_vae.decode(latents).sample
-
-             # bring the reconstruction's spatial size to high-res (downsample for asymmetric VAEs)
-             if rec.shape[-2:] != orig_high.shape[-2:]:
-                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
-
-             # save the FIRST sample: real and decoded, without the step number in the filename
-             first_real = _to_pil_uint8(orig_high[0])
-             first_dec = _to_pil_uint8(rec[0])
-             first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
-             first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
-
-             # also save the current reconstructions without a step number (files get overwritten instead of piling up)
-             for i in range(rec.shape[0]):
-                 _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
-
-             # LPIPS on the full (high-res) image, for the log
-             lpips_scores = []
-             for i in range(rec.shape[0]):
-                 orig_full = orig_high[i:i+1].to(torch.float32)
-                 rec_full = rec[i:i+1].to(torch.float32)
-                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
-                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
-                 lpips_val = lpips_net(orig_full, rec_full).item()
-                 lpips_scores.append(lpips_val)
-             avg_lpips = float(np.mean(lpips_scores))
-
-             if use_wandb and accelerator.is_main_process:
-                 wandb.log({
-                     "lpips_mean": avg_lpips,
-                 }, step=step)
-     finally:
-         gc.collect()
-         torch.cuda.empty_cache()
-
- if accelerator.is_main_process and save_model:
-     print("Generating samples before training starts...")
-     generate_and_save_samples(0)
-
- accelerator.wait_for_everyone()
-
- # --------------------------- Training ---------------------------
-
- progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
- global_step = 0
- min_loss = float("inf")
- sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
-
- for epoch in range(num_epochs):
-     vae.train()
-     batch_losses = []
-     batch_grads = []
-     # extra per-loss tracking
-     track_losses = {k: [] for k in loss_ratios.keys()}
-     for imgs in dataloader:
-         with accelerator.accumulate(vae):
-             # imgs: high-res tensor from dataloader ([-1,1]), move to device
-             imgs = imgs.to(accelerator.device)
-
-             # ALWAYS downsample the input to model_resolution for the encoder
-             # The stubborn machine keeps trying to do everything its own way
-             if high_resolution != model_resolution:
-                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
-             else:
-                 imgs_low = imgs
-
-             # ensure dtype matches model params to avoid float/half mismatch
-             model_dtype = next(vae.parameters()).dtype
-             if imgs_low.dtype != model_dtype:
-                 imgs_low_model = imgs_low.to(dtype=model_dtype)
-             else:
-                 imgs_low_model = imgs_low
-
-             # Encode/decode
-             latents = vae.encode(imgs_low_model).latent_dist.mean
-             rec = vae.decode(latents).sample  # rec may be upscaled (asymmetric VAE)
-
-             # bring the size to high-res
-             if rec.shape[-2:] != imgs.shape[-2:]:
-                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
-
-             # losses are computed at high-res
-             rec_f32 = rec.to(torch.float32)
-             imgs_f32 = imgs.to(torch.float32)
-
-             # individual losses
-             abs_losses = {
-                 "mae": F.l1_loss(rec_f32, imgs_f32),
-                 "mse": F.mse_loss(rec_f32, imgs_f32),
-                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
-                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
-             }
-
-             # total with median COEFFICIENTS
-             # No need to shout once you've managed to grasp my idea
-             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
-
-             if torch.isnan(total_loss) or torch.isinf(total_loss):
-                 print("NaN/Inf loss – stopping")
-                 raise RuntimeError("NaN/Inf loss")
-
-             accelerator.backward(total_loss)
-
-             grad_norm = torch.tensor(0.0, device=accelerator.device)
-             if accelerator.sync_gradients:
-                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
-                 optimizer.step()
-                 scheduler.step()
-                 optimizer.zero_grad(set_to_none=True)
-
-             global_step += 1
-             progress.update(1)
-
-             # --- Logging ---
-             if accelerator.is_main_process:
-                 try:
-                     current_lr = optimizer.param_groups[0]["lr"]
-                 except Exception:
-                     current_lr = scheduler.get_last_lr()[0]
-
-                 batch_losses.append(total_loss.detach().item())
-                 batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
-                 for k, v in abs_losses.items():
-                     track_losses[k].append(float(v.detach().item()))
-
-                 if use_wandb and accelerator.sync_gradients:
-                     log_dict = {
-                         "total_loss": float(total_loss.detach().item()),
-                         "learning_rate": current_lr,
-                         "epoch": epoch,
-                         "grad_norm": batch_grads[-1],
-                     }
-                     # add the individual losses
-                     for k, v in abs_losses.items():
-                         log_dict[f"loss_{k}"] = float(v.detach().item())
-                     # log the coefficients and medians
-                     for k in coeffs:
-                         log_dict[f"coeff_{k}"] = float(coeffs[k])
-                         log_dict[f"median_{k}"] = float(meds[k])
-                     wandb.log(log_dict, step=global_step)
-
-             # periodic samples and checkpoints
-             if global_step > 0 and global_step % sample_interval == 0:
-                 if accelerator.is_main_process:
-                     generate_and_save_samples(global_step)
-                 accelerator.wait_for_everyone()
-
-                 # averages over the most recent iterations
-                 n_micro = sample_interval * gradient_accumulation_steps
-                 if len(batch_losses) >= n_micro:
-                     avg_loss = float(np.mean(batch_losses[-n_micro:]))
-                 else:
-                     avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
-
-                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
-
-                 if accelerator.is_main_process:
-                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
-                     if save_model and avg_loss < min_loss * save_barrier:
-                         min_loss = avg_loss
-                         accelerator.unwrap_model(vae).save_pretrained(save_as)
-                     if use_wandb:
-                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
-
-     if accelerator.is_main_process:
-         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
-         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
-         if use_wandb:
-             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
-
- # --------------------------- Final save ---------------------------
- if accelerator.is_main_process:
-     print("Training finished – saving final model")
-     if save_model:
-         accelerator.unwrap_model(vae).save_pretrained(save_as)
-
- accelerator.free_memory()
- if torch.distributed.is_initialized():
-     torch.distributed.destroy_process_group()
- print("Done!")
vae/config.json CHANGED
@@ -1,7 +1,7 @@
  {
    "_class_name": "AutoencoderKL",
-   "_diffusers_version": "0.34.0",
-   "_name_or_path": "sdxl_vae",
+   "_diffusers_version": "0.35.1",
+   "_name_or_path": "AiArtLab/sdxl_vae",
    "act_fn": "silu",
    "block_out_channels": [
      128,
vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:03f2412467f6bedce9efeddba5860b5ec0d3267931d14c500d4bd7a878e14cbd
- size 334643268
+ oid sha256:f9f3bf86e95df913a45a4a238709c47f24530c07d10e0f923b0dae2f679799ea
+ size 167335342
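
The new weights file is almost exactly half the old size (167,335,342 vs 334,643,268 bytes), which is consistent with the checkpoint being re-saved in fp16 rather than fp32. A hedged sketch of how such a conversion is typically done with diffusers; the output path is illustrative:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("AiArtLab/sdxl_vae", subfolder="vae")
    vae.to(torch.float16).save_pretrained("sdxl_vae_fp16")  # roughly halves the .safetensors size
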
vae_nightly/config.json DELETED
@@ -1,38 +0,0 @@
- {
-   "_class_name": "AutoencoderKL",
-   "_diffusers_version": "0.35.0.dev0",
-   "_name_or_path": "vae",
-   "act_fn": "silu",
-   "block_out_channels": [
-     128,
-     256,
-     512,
-     512
-   ],
-   "down_block_types": [
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D",
-     "DownEncoderBlock2D"
-   ],
-   "force_upcast": false,
-   "in_channels": 3,
-   "latent_channels": 4,
-   "latents_mean": null,
-   "latents_std": null,
-   "layers_per_block": 2,
-   "mid_block_add_attention": true,
-   "norm_num_groups": 32,
-   "out_channels": 3,
-   "sample_size": 512,
-   "scaling_factor": 0.13025,
-   "shift_factor": null,
-   "up_block_types": [
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D",
-     "UpDecoderBlock2D"
-   ],
-   "use_post_quant_conv": true,
-   "use_quant_conv": true
- }
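
Per this deleted config, the nightly VAE was a standard SDXL-style autoencoder: 4 latent channels and an 8x spatial downsampling factor (four encoder blocks, three of which downsample). A quick sanity check of those numbers, assuming a local copy of the folder as it existed before this commit:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("vae_nightly")  # path as it was before deletion
    x = torch.randn(1, 3, 512, 512)
    z = vae.encode(x).latent_dist.mean
    print(z.shape)  # torch.Size([1, 4, 64, 64]): 512/8 spatially, latent_channels=4
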
vae_nightly/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:588db8438a9dea0c4c68dfd4cbdc7747b1ed3601f2a71f46d1608fae9bdb96a3
- size 334643268