recoilme committed
Commit 74bfd8a · Parent: 7434657
eval_alchemist2.py CHANGED
@@ -27,25 +27,27 @@ from scipy.stats import skew, kurtosis
 # ========================== Config ==========================
 DEVICE = "cuda"
 DTYPE = torch.float16
-IMAGE_FOLDER = "/home/recoilme/dataset/alchemist"
+IMAGE_FOLDER = "/workspace/alchemist"
 MIN_SIZE = 1280
 CROP_SIZE = 512
 BATCH_SIZE = 10
-MAX_IMAGES = 500
+MAX_IMAGES = 0
 NUM_WORKERS = 4
 SAMPLES_DIR = "vaetest"
 
 VAE_LIST = [
     # ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
     # ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
-    ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
-    ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
+    #("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
+    #("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
     #("SimpleVAE1", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly"),
-    #("SimpleVAE2", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly2"),
+    #("SimpleVAE2", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly2"),
+    #("SimpleVAE ", AutoencoderKL, "AiArtLab/simplevae", None),
     #("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
-    #("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+    ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+    ("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
     # ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
-    ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
+    #("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
 ]
 
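Each VAE_LIST entry is a (label, class, repo_or_path, subfolder) tuple. The code that consumes the list sits outside this hunk; below is a minimal sketch of how such entries are typically loaded with diffusers (the loop and variable names are illustrative, not taken from eval_alchemist2.py):

    # Hypothetical consumer of VAE_LIST — a sketch, not code from this repo.
    import torch
    from diffusers import AutoencoderKL

    VAE_LIST = [
        ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
    ]

    for name, cls, path, subfolder in VAE_LIST:
        # subfolder=None loads from the repo root; a string selects a subdirectory
        vae = cls.from_pretrained(path, subfolder=subfolder, torch_dtype=torch.float16)
        vae = vae.to("cuda").eval()
        print(f"{name}: {sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M params")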
samples/sample_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: cb43df876fea0ab69a3fa63399c378aad4dda308a1534071796834acc26c71a6
  • Pointer size: 130 Bytes
  • Size of remote file: 84.9 kB

Git LFS Details (after)

  • SHA256: 30179cebe92d534a6b7a0adec99dd2e30cf019ca343c520491744ded6dd11a75
  • Pointer size: 130 Bytes
  • Size of remote file: 68.6 kB

samples/sample_1.jpg CHANGED

Git LFS Details (before)

  • SHA256: fc0b8542e55bc97fb988441631c9e80543aef8ce0796c6416280282d73da427f
  • Pointer size: 130 Bytes
  • Size of remote file: 75.7 kB

Git LFS Details (after)

  • SHA256: 9b01273705935081122818ef111ae92c02274dc1941cd22379f00ec8c2b64741
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB

samples/sample_2.jpg CHANGED

Git LFS Details (before)

  • SHA256: 6d7969e2ba962645308392a623d1bc8b8573472aae631a68ac2996c31f2dd8af
  • Pointer size: 130 Bytes
  • Size of remote file: 71.2 kB

Git LFS Details (after)

  • SHA256: aa41aaa491109ab475473a13dcb1e5222e3f557ed9916d3d511b6f1d260363ee
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

samples/sample_decoded.jpg CHANGED

Git LFS Details (before)

  • SHA256: cb43df876fea0ab69a3fa63399c378aad4dda308a1534071796834acc26c71a6
  • Pointer size: 130 Bytes
  • Size of remote file: 84.9 kB

Git LFS Details (after)

  • SHA256: 30179cebe92d534a6b7a0adec99dd2e30cf019ca343c520491744ded6dd11a75
  • Pointer size: 130 Bytes
  • Size of remote file: 68.6 kB

samples/sample_real.jpg CHANGED

Git LFS Details (before)

  • SHA256: b187738cf82a8633e1409e6ed3db35fb5930681957ed8d69ae8cce6da881371f
  • Pointer size: 130 Bytes
  • Size of remote file: 89.9 kB

Git LFS Details (after)

  • SHA256: 14111dc19d653993b54635ffccdbd4b8e3d2cdf7d3419a1b2c9064e2051813b0
  • Pointer size: 130 Bytes
  • Size of remote file: 70.7 kB
simple_vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f5f0a20e403669e880b510514ee575a2a9cb74a1b36ab0e31fc68ef66c2173d7
+oid sha256:8ba1d500c4bd376a7c8662a35fa1857c7e577da0635414b524180852143ef2f6
 size 335311892
simple_vae_nightly/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7b705da7f401289eefa22570514d7c1b9b2f9fd32a71159e2d3d5888f74e41cd
+oid sha256:b39620d0953839362425f03674e6c3e37f03d20be3fbd7f281baea4dfc336a40
 size 335311892
train_sdxl_vae_full.py CHANGED
@@ -26,9 +26,9 @@ from collections import deque
 ds_path = "/workspace/png"
 project = "simple_vae"
 batch_size = 3
-base_learning_rate = 5e-5
-min_learning_rate = 9e-7
-num_epochs = 16
+base_learning_rate = 2e-6
+min_learning_rate = 8e-7
+num_epochs = 8
 sample_interval_share = 10
 use_wandb = True
 save_model = True
@@ -58,20 +58,20 @@ device = None  # accelerator will set the device
 # CHANGED: added a flag for full VAE training (not just the decoder).
 # If False — previous behavior: train only decoder.* (up_blocks + mid_block).
 # If True — unfreeze the WHOLE model and add a KL loss for the encoder.
-full_training = True
+full_training = False
 
 # CHANGED: added a weight (as a share in the normalizer) for KL; used only when full_training=True.
-kl_ratio = 0.05  # simple KL share in the overall mix (KISS). Ignored when full_training=False.
+kl_ratio = 0.00  # simple KL share in the overall mix (KISS). Ignored when full_training=False.
 
 # --- Loss shares and the median-normalization window (COEFFICIENTS, not values) ---
 # Final shares in the total loss (sum = 1.0 after normalization).
 loss_ratios = {
-    "lpips": 0.80,
-    "edge": 0.05,
-    "mse": 0.05,
-    "mae": 0.05,
+    "lpips": 0.60,
+    "edge": 0.10,
+    "mse": 0.15,
+    "mae": 0.15,
     # CHANGED: the "kl" key is pre-added (0.0 by default). Activated below when full_training is enabled.
-    "kl": 0.05,
+    "kl": 0.00,
 }
 median_coeff_steps = 256  # window (in steps) for computing the median coefficients
 
@@ -490,7 +490,11 @@ for epoch in range(num_epochs):
     mean = enc.latent_dist.mean
     logvar = enc.latent_dist.logvar
     # stable averaging over batch and spatial dims
-    kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+    # OLD (incorrect):
+    #kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+    # NEW (correct):
+    kl_per_sample = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=[1, 2, 3])
+    kl = torch.mean(kl_per_sample)
     abs_losses["kl"] = kl
 else:
     # the key is present in ratios, but under partial training its share is 0 and has no effect
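The hunk above corrects the reduction of the KL term. For a diagonal Gaussian posterior against the unit Gaussian prior, the closed-form KL per sample is

    D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,I)\bigr) = -\tfrac{1}{2}\sum_{j}\bigl(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\bigr)

summed over the latent dimensions j of each sample and then averaged over the batch. The old torch.mean over every element divided this sum by the number of latent elements, silently shrinking the effective KL weight by that factor; torch.sum(..., dim=[1, 2, 3]) followed by a batch mean restores the standard scale.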
train_sdxl_vae_wan.py ADDED
@@ -0,0 +1,565 @@
+# -*- coding: utf-8 -*-
+import os
+import math
+import re
+import torch
+import numpy as np
+import random
+import gc
+from datetime import datetime
+from pathlib import Path
+
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torch.optim.lr_scheduler import LambdaLR
+from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+# QWEN: class import
+from diffusers import AutoencoderKLQwenImage
+from diffusers import AutoencoderKLWan
+
+from accelerate import Accelerator
+from PIL import Image, UnidentifiedImageError
+from tqdm import tqdm
+import bitsandbytes as bnb
+import wandb
+import lpips  # pip install lpips
+from collections import deque
+
+# --------------------------- Parameters ---------------------------
+ds_path = "/workspace/png"
+project = "wan16x_vae"
+batch_size = 4
+base_learning_rate = 6e-6
+min_learning_rate = 9e-7
+num_epochs = 8
+sample_interval_share = 10
+use_wandb = True
+save_model = True
+use_decay = True
+optimizer_type = "adam8bit"
+dtype = torch.float32
+
+model_resolution = 512
+high_resolution = 512
+limit = 0
+save_barrier = 1.03
+warmup_percent = 0.01
+percentile_clipping = 95
+beta2 = 0.97
+eps = 1e-6
+clip_grad_norm = 1.0
+mixed_precision = "no"
+gradient_accumulation_steps = 4
+generated_folder = "samples"
+save_as = "wan16x_vae_nightly"
+num_workers = 0
+device = None
+
+# --- Training modes ---
+# QWEN: train only the decoder
+train_decoder_only = True
+full_training = False  # if True — train the whole VAE and add KL (below)
+kl_ratio = 0.05
+
+# Loss shares
+loss_ratios = {
+    "lpips": 0.75,
+    "edge": 0.05,
+    "mse": 0.10,
+    "mae": 0.10,
+    "kl": 0.00,  # enabled when full_training=True
+}
+median_coeff_steps = 256
+
+resize_long_side = 1280  # resize the long side of source images
+
+# QWEN: model-loading config
+vae_kind = "wan"  # "qwen" or "kl" (regular)
+
+Path(generated_folder).mkdir(parents=True, exist_ok=True)
+
+accelerator = Accelerator(
+    mixed_precision=mixed_precision,
+    gradient_accumulation_steps=gradient_accumulation_steps
+)
+device = accelerator.device
+
+# reproducibility
+seed = int(datetime.now().strftime("%Y%m%d"))
+torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+torch.backends.cudnn.benchmark = False
+
+# --------------------------- WandB ---------------------------
+if use_wandb and accelerator.is_main_process:
+    wandb.init(project=project, config={
+        "batch_size": batch_size,
+        "base_learning_rate": base_learning_rate,
+        "num_epochs": num_epochs,
+        "optimizer_type": optimizer_type,
+        "model_resolution": model_resolution,
+        "high_resolution": high_resolution,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
+        "train_decoder_only": train_decoder_only,
+        "full_training": full_training,
+        "kl_ratio": kl_ratio,
+        "vae_kind": vae_kind,
+    })
+
+# --------------------------- VAE ---------------------------
+def get_core_model(model):
+    m = model
+    # if the model is already wrapped by torch.compile
+    if hasattr(m, "_orig_mod"):
+        m = m._orig_mod
+    return m
+
+def is_video_vae(model) -> bool:
+    # WAN/Qwen are video VAEs
+    if vae_kind in ("wan", "qwen"):
+        return True
+    # structural fallback (if ever needed): a 3D conv_in has a 5-D weight
+    try:
+        core = get_core_model(model)
+        enc = getattr(core, "encoder", None)
+        conv_in = getattr(enc, "conv_in", None)
+        w = getattr(conv_in, "weight", None)
+        if isinstance(w, torch.nn.Parameter):
+            return w.ndim == 5
+    except Exception:
+        pass
+    return False
+
+# loading
+if vae_kind == "qwen":
+    vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
+else:
+    if vae_kind == "wan":
+        vae = AutoencoderKLWan.from_pretrained(project)
+    else:
+        # previous behavior (example)
+        if model_resolution == high_resolution:
+            vae = AutoencoderKL.from_pretrained(project)
+        else:
+            vae = AsymmetricAutoencoderKL.from_pretrained(project)
+
+vae = vae.to(dtype)
+
+# torch.compile (optional)
+if hasattr(torch, "compile"):
+    try:
+        vae = torch.compile(vae)
+    except Exception as e:
+        print(f"[WARN] torch.compile failed: {e}")
+
+# --------------------------- Freeze/Unfreeze ---------------------------
+core = get_core_model(vae)
+
+for p in core.parameters():
+    p.requires_grad = False
+
+unfrozen_param_names = []
+
+if full_training and not train_decoder_only:
+    for name, p in core.named_parameters():
+        p.requires_grad = True
+        unfrozen_param_names.append(name)
+    loss_ratios["kl"] = float(kl_ratio)
+    trainable_module = core
+else:
+    # train only the decoder + post_quant_conv on the model "core"
+    if hasattr(core, "decoder"):
+        for name, p in core.decoder.named_parameters():
+            p.requires_grad = True
+            unfrozen_param_names.append(f"decoder.{name}")
+    if hasattr(core, "post_quant_conv"):
+        for name, p in core.post_quant_conv.named_parameters():
+            p.requires_grad = True
+            unfrozen_param_names.append(f"post_quant_conv.{name}")
+    trainable_module = core.decoder if hasattr(core, "decoder") else core
+
+print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
+for nm in unfrozen_param_names[:200]:
+    print(" ", nm)
+
+# --------------------------- Dataset ---------------------------
+class PngFolderDataset(Dataset):
+    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
+        self.root_dir = root_dir
+        self.resolution = resolution
+        self.paths = []
+        for root, _, files in os.walk(root_dir):
+            for fname in files:
+                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
+                    self.paths.append(os.path.join(root, fname))
+        if limit:
+            self.paths = self.paths[:limit]
+        valid = []
+        for p in self.paths:
+            try:
+                with Image.open(p) as im:
+                    im.verify()
+                valid.append(p)
+            except (OSError, UnidentifiedImageError):
+                continue
+        self.paths = valid
+        if len(self.paths) == 0:
+            raise RuntimeError(f"No valid PNG images found under {root_dir}")
+        random.shuffle(self.paths)
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        p = self.paths[idx % len(self.paths)]
+        with Image.open(p) as img:
+            img = img.convert("RGB")
+        if not resize_long_side or resize_long_side <= 0:
+            return img
+        w, h = img.size
+        long = max(w, h)
+        if long <= resize_long_side:
+            return img
+        scale = resize_long_side / float(long)
+        new_w = int(round(w * scale))
+        new_h = int(round(h * scale))
+        return img.resize((new_w, new_h), Image.LANCZOS)
+
+def random_crop(img, sz):
+    w, h = img.size
+    if w < sz or h < sz:
+        img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
+    # clamp the crop origin so the window stays inside the image
+    x = random.randint(0, max(0, img.width - sz))
+    y = random.randint(0, max(0, img.height - sz))
+    return img.crop((x, y, x + sz, y + sz))
+
+tfm = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+])
+
+dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
+if len(dataset) < batch_size:
+    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
+
+def collate_fn(batch):
+    imgs = []
+    for img in batch:
+        img = random_crop(img, high_resolution)
+        imgs.append(tfm(img))
+    return torch.stack(imgs)
+
+dataloader = DataLoader(
+    dataset,
+    batch_size=batch_size,
+    shuffle=True,
+    collate_fn=collate_fn,
+    num_workers=num_workers,
+    pin_memory=True,
+    drop_last=True
+)
+
+# --------------------------- Optimizer ---------------------------
+def get_param_groups(module, weight_decay=0.001):
+    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
+    decay_params, no_decay_params = [], []
+    for n, p in vae.named_parameters():  # global over vae, filtered by requires_grad
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in no_decay):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+    return [
+        {"params": decay_params, "weight_decay": weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+# NOTE: this redefinition shadows the version above and is the one actually used;
+# unlike the first, it respects the `module` argument.
+def get_param_groups(module, weight_decay=0.001):
+    no_decay_tokens = ("bias", "norm", "rms", "layernorm")
+    decay_params, no_decay_params = [], []
+    for n, p in module.named_parameters():
+        if not p.requires_grad:
+            continue
+        n_l = n.lower()
+        if any(t in n_l for t in no_decay_tokens):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+    return [
+        {"params": decay_params, "weight_decay": weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+def create_optimizer(name, param_groups):
+    if name == "adam8bit":
+        return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
+    raise ValueError(name)
+
+param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
+optimizer = create_optimizer(optimizer_type, param_groups)
+
+# --------------------------- LR schedule ---------------------------
+batches_per_epoch = len(dataloader)
+steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
+total_steps = steps_per_epoch * num_epochs
+
+def lr_lambda(step):
+    if not use_decay:
+        return 1.0
+    x = float(step) / float(max(1, total_steps))
+    warmup = float(warmup_percent)
+    min_ratio = float(min_learning_rate) / float(base_learning_rate)
+    if x < warmup:
+        return min_ratio + (1.0 - min_ratio) * (x / warmup)
+    decay_ratio = (x - warmup) / (1.0 - warmup)
+    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
+
+scheduler = LambdaLR(optimizer, lr_lambda)
+
+# Preparation
+dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
+trainable_params = [p for p in vae.parameters() if p.requires_grad]
+
+# --------------------------- LPIPS and helpers ---------------------------
+_lpips_net = None
+def _get_lpips():
+    global _lpips_net
+    if _lpips_net is None:
+        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).to(accelerator.device).eval()
+    return _lpips_net
+
+_sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
+_sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
+def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+    C = x.shape[1]
+    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+    gx = F.conv2d(x, kx, padding=1, groups=C)
+    gy = F.conv2d(x, ky, padding=1, groups=C)
+    return torch.sqrt(gx * gx + gy * gy + 1e-12)
+
+class MedianLossNormalizer:
+    def __init__(self, desired_ratios: dict, window_steps: int):
+        s = sum(desired_ratios.values())
+        self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
+        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+        self.window = window_steps
+
+    def update_and_total(self, abs_losses: dict):
+        for k, v in abs_losses.items():
+            if k in self.buffers:
+                self.buffers[k].append(float(v.detach().abs().cpu()))
+        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+        total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
+        return total, coeffs, meds
+
+if full_training and not train_decoder_only:
+    loss_ratios["kl"] = float(kl_ratio)
+normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
+# --------------------------- Samples ---------------------------
+@torch.no_grad()
+def get_fixed_samples(n=3):
+    idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+    pil_imgs = [dataset[i] for i in idx]
+    tensors = []
+    for img in pil_imgs:
+        img = random_crop(img, high_resolution)
+        tensors.append(tfm(img))
+    return torch.stack(tensors).to(accelerator.device, dtype)
+
+fixed_samples = get_fixed_samples()
+
+@torch.no_grad()
+def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+    arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+    return Image.fromarray(arr)
+
+@torch.no_grad()
+def generate_and_save_samples(step=None):
+    try:
+        temp_vae = accelerator.unwrap_model(vae).eval()
+        lpips_net = _get_lpips()
+        with torch.no_grad():
+            orig_high = fixed_samples
+            orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+            model_dtype = next(temp_vae.parameters()).dtype
+            orig_low = orig_low.to(dtype=model_dtype)
+
+            # QWEN: add T=1 for encode/decode and drop it again for comparison
+            if is_video_vae(temp_vae):
+                x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
+                enc = temp_vae.encode(x_in)
+                latents_mean = enc.latent_dist.mean
+                dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
+                rec = dec.squeeze(2)  # [B,3,H,W]
+            else:
+                enc = temp_vae.encode(orig_low)
+                latents_mean = enc.latent_dist.mean
+                rec = temp_vae.decode(latents_mean).sample
+
+            if rec.shape[-2:] != orig_high.shape[-2:]:
+                rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+            first_real = _to_pil_uint8(orig_high[0])
+            first_dec = _to_pil_uint8(rec[0])
+            first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
+            first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
+
+            for i in range(rec.shape[0]):
+                _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
+
+            lpips_scores = []
+            for i in range(rec.shape[0]):
+                orig_full = orig_high[i:i+1].to(torch.float32)
+                rec_full = rec[i:i+1].to(torch.float32)
+                if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                lpips_val = lpips_net(orig_full, rec_full).item()
+                lpips_scores.append(lpips_val)
+            avg_lpips = float(np.mean(lpips_scores))
+
+            if use_wandb and accelerator.is_main_process:
+                wandb.log({"lpips_mean": avg_lpips}, step=step)
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+if accelerator.is_main_process and save_model:
+    print("Generating samples before training starts...")
+    generate_and_save_samples(0)
+
+accelerator.wait_for_everyone()
+
+# --------------------------- Training ---------------------------
+progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+global_step = 0
+min_loss = float("inf")
+sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
+for epoch in range(num_epochs):
+    vae.train()
+    batch_losses, batch_grads = [], []
+    track_losses = {k: [] for k in loss_ratios.keys()}
+
+    for imgs in dataloader:
+        with accelerator.accumulate(vae):
+            imgs = imgs.to(accelerator.device)
+
+            if high_resolution != model_resolution:
+                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+            else:
+                imgs_low = imgs
+
+            model_dtype = next(vae.parameters()).dtype
+            imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
+
+            # QWEN: encode/decode with T=1
+            if is_video_vae(vae):
+                x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
+                enc = vae.encode(x_in)
+                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                dec = vae.decode(latents).sample  # [B,3,1,H,W]
+                rec = dec.squeeze(2)  # [B,3,H,W]
+            else:
+                enc = vae.encode(imgs_low_model)
+                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                rec = vae.decode(latents).sample
+
+            if rec.shape[-2:] != imgs.shape[-2:]:
+                rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+            rec_f32 = rec.to(torch.float32)
+            imgs_f32 = imgs.to(torch.float32)
+
+            abs_losses = {
+                "mae": F.l1_loss(rec_f32, imgs_f32),
+                "mse": F.mse_loss(rec_f32, imgs_f32),
+                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+            }
+
+            if full_training and not train_decoder_only:
+                mean = enc.latent_dist.mean
+                logvar = enc.latent_dist.logvar
+                # NOTE: per-element mean; train_sdxl_vae_full.py in this commit switches to a per-sample sum
+                kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+                abs_losses["kl"] = kl
+            else:
+                abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+
+            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+            if torch.isnan(total_loss) or torch.isinf(total_loss):
+                raise RuntimeError("NaN/Inf loss")
+
+            accelerator.backward(total_loss)
+
+            grad_norm = torch.tensor(0.0, device=accelerator.device)
+            if accelerator.sync_gradients:
+                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad(set_to_none=True)
+            global_step += 1
+            progress.update(1)
+
+            if accelerator.is_main_process:
+                try:
+                    current_lr = optimizer.param_groups[0]["lr"]
+                except Exception:
+                    current_lr = scheduler.get_last_lr()[0]
+
+                batch_losses.append(total_loss.detach().item())
+                batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
+                for k, v in abs_losses.items():
+                    track_losses[k].append(float(v.detach().item()))
+
+                if use_wandb and accelerator.sync_gradients:
+                    log_dict = {
+                        "total_loss": float(total_loss.detach().item()),
+                        "learning_rate": current_lr,
+                        "epoch": epoch,
+                        "grad_norm": batch_grads[-1],
+                    }
+                    for k, v in abs_losses.items():
+                        log_dict[f"loss_{k}"] = float(v.detach().item())
+                    for k in coeffs:
+                        log_dict[f"coeff_{k}"] = float(coeffs[k])
+                        log_dict[f"median_{k}"] = float(meds[k])
+                    wandb.log(log_dict, step=global_step)
+
+            if global_step > 0 and global_step % sample_interval == 0:
+                if accelerator.is_main_process:
+                    generate_and_save_samples(global_step)
+                accelerator.wait_for_everyone()
+
+                n_micro = sample_interval * gradient_accumulation_steps
+                avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
+                avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= n_micro else float(np.mean(batch_grads)) if batch_grads else 0.0
+
+                if accelerator.is_main_process:
+                    print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                    if save_model and avg_loss < min_loss * save_barrier:
+                        min_loss = avg_loss
+                        accelerator.unwrap_model(vae).save_pretrained(save_as)
+                    if use_wandb:
+                        wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+    if accelerator.is_main_process:
+        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+        if use_wandb:
+            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+# --------------------------- Final save ---------------------------
+if accelerator.is_main_process:
+    print("Training finished – saving final model")
+    if save_model:
+        accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+accelerator.free_memory()
+if torch.distributed.is_initialized():
+    torch.distributed.destroy_process_group()
+print("Done!")
vaetest/001_all.png DELETED

Git LFS Details

  • SHA256: 7b7a8098d61a1525db5ce3eaa5cd50e132a5f846a6053f789edd4801e37b0d18
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
vaetest/001_decoded_FLUX.1_schnell_vae.png DELETED

Git LFS Details

  • SHA256: 21b88d5045d1b9c0a3785d5b96a6dcd225ea92921143fd4a2fe5daabf060ccae
  • Pointer size: 131 Bytes
  • Size of remote file: 494 kB
vaetest/001_decoded_simple_vae.png DELETED

Git LFS Details

  • SHA256: 816836033774e8ad18e4853763cf2f040db44ab08de261ef3b1be95931d7f28d
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_decoded_simple_vae2.png DELETED

Git LFS Details

  • SHA256: 903009ea3ea4344918cf79b06ffc8ba55402275a1fa41ed7d118086c49ec9dd4
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_decoded_simple_vae_nightly.png DELETED

Git LFS Details

  • SHA256: 1246db8b7d3e6a36199dbd83b8532a28917df0528f016429cf499179d8d2bcf4
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_orig.png DELETED

Git LFS Details

  • SHA256: 12c632e0aecc1925185142be560a65e204e12e7167dbcc1a49e3017b371638fe
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
vaetest/002_all.png DELETED

Git LFS Details

  • SHA256: 058a39dde15443d7547a4944df594d285b465c5e8817225c669ea04adf4d6c01
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
vaetest/002_decoded_FLUX.1_schnell_vae.png DELETED

Git LFS Details

  • SHA256: f5bed96d137dbaa377cebaaceb9919a2d8d51a1793759b643e2d772e4fe33785
  • Pointer size: 131 Bytes
  • Size of remote file: 380 kB
vaetest/002_decoded_simple_vae.png DELETED

Git LFS Details

  • SHA256: 808a6138bc48cd85f752d2bcabd1e7795c89df6794aab686a32e5c9ac2f7214f
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_decoded_simple_vae2.png DELETED

Git LFS Details

  • SHA256: 4f67867e41530bcb498a46b963f65db5c603bd70db042b294450d337a9ddb651
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_decoded_simple_vae_nightly.png DELETED

Git LFS Details

  • SHA256: 7905cccc22fc84bddcf712d9615a069b2dd3999a17b37b747b08d2db9c8719b7
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_orig.png DELETED

Git LFS Details

  • SHA256: 177599cb0d77d66058bb53146156de2d9654ac98255ae2005f7aadbebaee0fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB