# CoTyle/piFlow/demo/train_piflow_qwen_toymodel.py
# Copyright (c) 2025 Hansheng Chen
import os
import argparse
import numpy as np
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from lakonlab.models.diffusions.piflow_policies import GMFlowPolicy
from lakonlab.models.architecture import (
QwenImageTransformer2DModel, PretrainedVAEQwenImage, PretrainedQwenImageTextEncoder)
from lakonlab.models import GaussianFlow

EPS = 1e-4  # lower bound for sampled intermediate times, keeps rollouts away from t = 0
TOTAL_SUBSTEPS = 128  # number of Euler substeps for integrating the policy over the full [0, 1] interval
PRINT_EVERY = 10
SAVE_EVERY = 500


def parse_args():
    parser = argparse.ArgumentParser(
        description='A minimal 1-NFE pi-Flow imitation distillation trainer that overfits the teacher (Qwen-Image) '
                    'behavior on a fixed initial noise using a static GMFlow policy.')
    parser.add_argument(
        '--prompt',
        type=str,
        default='Photo of a coffee shop entrance featuring a chalkboard sign reading "π-Qwen Coffee 😊 $2 per cup," with a neon '
                'light beside it displaying "π-通义千问". Next to it hangs a poster showing a beautiful Chinese woman, '
                'and beneath the poster is written "e≈2.71828-18284-59045-23536-02874-71352".',
        help='text prompt')
    parser.add_argument(
        '--cfg', type=float, default=4.0, help='teacher classifier-free guidance scale')
    parser.add_argument(
        '--seed', type=int, default=42, help='random seed')
    parser.add_argument(
        '-k', type=int, default=32, help='number of Gaussian components')
    parser.add_argument(
        '--num-iters', type=int, default=5000, help='number of training iterations')
    parser.add_argument(
        '--lr', type=float, default=5e-3, help='learning rate')
    parser.add_argument(
        '--out', type=str, default='viz/piflow_qwen_toymodel/output.png', help='output file path')
    parser.add_argument(
        '--h', type=int, default=768, help='image height')
    parser.add_argument(
        '--w', type=int, default=1360, help='image width')
    parser.add_argument(
        '--num-intermediates', type=int, default=2,
        help='number of intermediate time samples per training iteration')
    args = parser.parse_args()
    return args
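

# Training scheme (see main below): the student outputs a Gaussian-mixture (GM) policy at t = 1;
# the policy is rolled out (with gradients detached) to randomly sampled intermediate times, and
# the student's velocity prediction at each intermediate state is regressed onto the frozen
# teacher's CFG-guided velocity u at the same state. This is pi-Flow imitation distillation on a
# single fixed noise, so the student deliberately overfits one trajectory.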


class StaticGMM(nn.Module):
    """A toy model that outputs a static GM, ignoring the inputs x_t_src and t_src.

    In practice, a real model should take x_t and t as input and output a dynamic GM that
    varies with x_t_src and t_src.
    """

    def __init__(self, init_u, num_gaussians=8):
        super().__init__()
        self.latent_size = init_u.shape[1:]
        self.num_gaussians = num_gaussians
        # Component means are initialized as random perturbations around the teacher's initial velocity init_u.
        self.means = nn.Parameter(
            init_u.repeat(1, num_gaussians, 1, 1, 1)
            + torch.randn(1, num_gaussians, *self.latent_size, device=init_u.device) * 0.5)
        # A single shared log standard deviation, broadcast over components and spatial locations.
        self.logstds = nn.Parameter(torch.full((1, 1, 1, 1, 1), fill_value=np.log(0.05)))
        # Per-pixel mixture logits over the num_gaussians components.
        self.logweight_logits = nn.Parameter(torch.zeros(1, num_gaussians, 1, *self.latent_size[1:]))

    def forward(self, x_t_src, t_src):
        assert (t_src == 1).all(), 'This toy model only supports 1-NFE sampling, thus t_src == 1.'
        assert x_t_src.size(0) == 1, 'This toy model only supports batch size 1.'
        assert x_t_src.shape[1:] == self.latent_size, \
            f'Expected input shape (1, {self.latent_size}), got {x_t_src.shape}.'
        # This toy model assumes the input is fixed, so it ignores x_t_src and t_src and returns the static GM.
        return dict(
            means=self.means,
            logstds=self.logstds,
            logweights=self.logweight_logits.log_softmax(dim=1))
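
# Shapes of the GM parameters consumed by GMFlowPolicy (with latent_size = (C, H, W) and K components):
#   means:      (1, K, C, H, W)
#   logstds:    (1, 1, 1, 1, 1), broadcast across components and locations
#   logweights: (1, K, 1, H, W), log-softmax-normalized over the component dimension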


def policy_rollout(
        x_t_start: torch.Tensor,  # (B, C, *, H, W)
        raw_t_start: torch.Tensor,  # (B, )
        raw_t_end: torch.Tensor,  # (B, )
        policy,
        warp_t_fun):
    ndim = x_t_start.dim()
    raw_t_start = raw_t_start.reshape(*(ndim * [1]))
    raw_t_end = raw_t_end.reshape(*(ndim * [1]))
    delta_raw_t = raw_t_start - raw_t_end
    # Use a number of Euler substeps proportional to the raw-time interval (at least one).
    num_substeps = (delta_raw_t * TOTAL_SUBSTEPS).round().to(torch.long).clamp(min=1)
    substep_size = delta_raw_t / num_substeps
    raw_t = raw_t_start
    sigma_t = warp_t_fun(raw_t)
    x_t = x_t_start
    for substep_id in range(num_substeps.item()):
        u = policy.pi(x_t, sigma_t)  # policy velocity at the current state and warped time
        raw_t_minus = (raw_t - substep_size).clamp(min=0)
        sigma_t_minus = warp_t_fun(raw_t_minus)
        x_t_minus = x_t + u * (sigma_t_minus - sigma_t)  # Euler step in warped time
        x_t = x_t_minus
        sigma_t = sigma_t_minus
        raw_t = raw_t_minus
    x_t_end = x_t
    sigma_t_end = sigma_t
    return x_t_end, sigma_t_end.flatten()
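
# Each substep above applies the flow update
#     x_{t-} = x_t + u(x_t, sigma_t) * (sigma_{t-} - sigma_t),
# where sigma = warp_t_fun(raw_t) maps uniform raw time onto the teacher's shifted timestep
# schedule. Rolling out to raw_t_end = 0 therefore integrates the frozen policy into a clean
# latent sample.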


def main():
    args = parse_args()
    prompt = args.prompt
    num_gaussians = args.k
    num_iters = args.num_iters
    lr = args.lr
    out_path = args.out
    guidance_scale = args.cfg
    num_intermediates = args.num_intermediates

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    out_path_noext, out_ext = os.path.splitext(out_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'

    # Encode the prompt once; for CFG, stack the empty-prompt and prompt embeddings along the batch dim.
    text_encoder = PretrainedQwenImageTextEncoder(
        from_pretrained='Qwen/Qwen-Image',
        torch_dtype=dtype,
        max_sequence_length=512,
        pad_seq_len=512,
    ).to(device)
    prompt_embed_kwargs = text_encoder(prompt)
    if guidance_scale > 1.0:
        empty_prompt_embed_kwargs = text_encoder('')
        for k in prompt_embed_kwargs:
            prompt_embed_kwargs[k] = torch.cat([
                empty_prompt_embed_kwargs[k],
                prompt_embed_kwargs[k]], dim=0)
    del text_encoder  # free the text encoder; its outputs are cached for the whole run
    torch.cuda.empty_cache()

    vae = PretrainedVAEQwenImage(
        from_pretrained='Qwen/Qwen-Image',
        subfolder='vae',
        torch_dtype=dtype).to(device)
    vae_scale_factor = 8
    vae_latent_size = (16, args.h // vae_scale_factor, args.w // vae_scale_factor)

    teacher = GaussianFlow(
        denoising=QwenImageTransformer2DModel(
            patch_size=2,
            freeze=True,
            pretrained='huggingface://Qwen/Qwen-Image/transformer/diffusion_pytorch_model.safetensors.index.json',
            in_channels=64,
            out_channels=64,
            num_layers=60,
            attention_head_dim=128,
            num_attention_heads=24,
            joint_attention_dim=3584,
            axes_dims_rope=(16, 56, 56),
            torch_dtype=dtype),
        num_timesteps=1,
        denoising_mean_mode='U',
        timestep_sampler=dict(
            type='ContinuousTimeStepSampler',
            shift=3.2,
            logit_normal_enable=False)).eval().to(device)
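
    # The frozen teacher is queried only through GaussianFlow.forward(return_u=True, ...), which
    # returns its CFG-guided flow velocity u; these velocities serve as the distillation targets.
    # Its timestep sampler also provides warp_t, the shift-3.2 time warp used by policy_rollout.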

    # Draw the fixed initial noise at t = 1.
    torch.manual_seed(args.seed)
    x_t_src = torch.randn((1, *vae_latent_size), device=device)
    t_src = torch.ones(1, device=device)

    # Initialize the student's GM around the teacher's velocity prediction u at the initial noise.
    u = teacher.forward(
        return_u=True, x_t=x_t_src, t=t_src, guidance_scale=guidance_scale, **prompt_embed_kwargs)
    student = StaticGMM(
        init_u=u,
        num_gaussians=num_gaussians).to(device)

    # Start training.
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    loss_list = []
    for i in range(1, num_iters + 1):
        optimizer.zero_grad()
        denoising_output = student(x_t_src, t_src)
        policy = GMFlowPolicy(denoising_output, x_t_src, t_src)
        # Rollouts use a detached copy, so gradients flow only through the final velocity prediction.
        detached_policy = policy.detach()

        loss = 0
        intermediate_t_samples = torch.rand(num_intermediates, device=device).clamp(min=EPS)
        for raw_t in intermediate_t_samples:
            # Roll the detached policy from t = 1 down to the sampled intermediate time.
            x_t, t = policy_rollout(
                x_t_start=x_t_src,
                raw_t_start=t_src,
                raw_t_end=raw_t,
                policy=detached_policy,
                warp_t_fun=teacher.timestep_sampler.warp_t)
            # Match the student's velocity to the teacher's velocity at the intermediate state.
            pred_u = policy.pi(x_t, t)
            teacher_u = teacher.forward(
                return_u=True, x_t=x_t, t=t, guidance_scale=guidance_scale, **prompt_embed_kwargs)
            loss += F.mse_loss(pred_u, teacher_u) / num_intermediates
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        if i % PRINT_EVERY == 0 or i == num_iters:
            print(f'Iter {i:04d}/{num_iters:04d}, loss: {np.mean(loss_list):.6f}')
            loss_list = []

        if i % SAVE_EVERY == 0 or i == num_iters:
            # Decode a full 1-NFE rollout (t = 1 -> 0) into an image for inspection.
            with torch.no_grad():
                x_0, _ = policy_rollout(
                    x_t_start=x_t_src,
                    raw_t_start=t_src,
                    raw_t_end=torch.zeros(1, device=device),
                    policy=policy,
                    warp_t_fun=teacher.timestep_sampler.warp_t)
                image = ((vae.decode(x_0.to(getattr(torch, dtype))) / 2 + 0.5).clamp(0, 1) * 255).round().to(
                    dtype=torch.uint8, device='cpu').squeeze(0).permute(1, 2, 0).numpy()
                Image.fromarray(image).save(f'{out_path_noext}.iter{i:04d}{out_ext}')
                print(f'Image saved to {out_path_noext}.iter{i:04d}{out_ext}')


if __name__ == '__main__':
    main()
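
# Example invocation (all flags are optional; defaults are defined in parse_args above):
#   python train_piflow_qwen_toymodel.py --cfg 4.0 --seed 42 -k 32 --num-iters 5000 \
#       --out viz/piflow_qwen_toymodel/output.png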