# Copyright (c) 2025 Hansheng Chen

import os
import argparse

import numpy as np
import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import torch.nn as nn
import torch.nn.functional as F

from PIL import Image

from lakonlab.models.diffusions.piflow_policies import GMFlowPolicy
from lakonlab.models.architecture import (
    QwenImageTransformer2DModel, PretrainedVAEQwenImage, PretrainedQwenImageTextEncoder)
from lakonlab.models import GaussianFlow

EPS = 1e-4
TOTAL_SUBSTEPS = 128
PRINT_EVERY = 10
SAVE_EVERY = 500


def parse_args():
    parser = argparse.ArgumentParser(
        description='A minimal 1-NFE pi-Flow imitation distillation trainer that overfits the teacher (Qwen-Image) '
                    'behavior on a fixed initial noise using a static GMFlow policy.')
    parser.add_argument(
        '--prompt', type=str,
        default='Photo of a coffee shop entrance featuring a chalkboard sign reading "Ļ€-Qwen Coffee 😊 $2 per cup," with a neon '
                'light beside it displaying "Ļ€-通义千问". Next to it hangs a poster showing a beautiful Chinese woman, '
                'and beneath the poster is written "eā‰ˆ2.71828-18284-59045-23536-02874-71352".',
        help='text prompt')
    parser.add_argument(
        '--cfg', type=float, default=4.0, help='teacher classifier-free guidance scale')
    parser.add_argument(
        '--seed', type=int, default=42, help='random seed')
    parser.add_argument(
        '-k', type=int, default=32, help='number of Gaussian components')
    parser.add_argument(
        '--num-iters', type=int, default=5000, help='number of iterations')
    parser.add_argument(
        '--lr', type=float, default=5e-3, help='learning rate')
    parser.add_argument(
        '--out', type=str, default='viz/piflow_qwen_toymodel/output.png', help='output file path')
    parser.add_argument(
        '--h', type=int, default=768, help='image height')
    parser.add_argument(
        '--w', type=int, default=1360, help='image width')
    parser.add_argument(
        '--num-intermediates', type=int, default=2, help='number of intermediate samples')
    args = parser.parse_args()
    return args


class StaticGMM(nn.Module):
    """A toy model that outputs a static GM, ignoring the inputs x_t_src and t_src.

    In practice, a real model should take x_t and t as input and output a
    dynamic GM that varies with x_t_src and t_src.
    """

    def __init__(self, init_u, num_gaussians=8):
        super().__init__()
        self.latent_size = init_u.shape[1:]
        self.num_gaussians = num_gaussians
        self.means = nn.Parameter(
            init_u.repeat(1, num_gaussians, 1, 1, 1)
            + torch.randn(1, num_gaussians, *self.latent_size, device=init_u.device) * 0.5)
        self.logstds = nn.Parameter(torch.full((1, 1, 1, 1, 1), fill_value=np.log(0.05)))
        self.logweight_logits = nn.Parameter(torch.zeros(1, num_gaussians, 1, *self.latent_size[1:]))

    def forward(self, x_t_src, t_src):
        assert (t_src == 1).all(), 'This toy model only supports 1-NFE sampling, thus t_src == 1.'
        assert x_t_src.size(0) == 1, 'This toy model only supports batch size 1.'
        assert x_t_src.shape[1:] == self.latent_size, \
            f'Expected input shape (1, *{tuple(self.latent_size)}), got {tuple(x_t_src.shape)}.'
        # This toy model assumes the input is fixed, so we ignore x_t_src and
        # t_src and return the static GM.
        return dict(
            means=self.means,
            logstds=self.logstds,
            logweights=self.logweight_logits.log_softmax(dim=1))
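# For intuition, below is a minimal sketch (not used in training) of how one
# could draw a sample from the per-pixel GM that StaticGMM returns. It assumes
# `logweights` (1, K, 1, H, W) select one of the K components of `means`
# (1, K, C, H, W) independently at each spatial location, with a shared scalar
# std; `sample_static_gmm` is a hypothetical helper, not part of lakonlab.
def sample_static_gmm(gm):
    means = gm['means']                                        # (1, K, C, H, W)
    stds = gm['logstds'].exp()                                 # scalar std, broadcastable
    logweights = gm['logweights']                              # (1, K, 1, H, W)
    # Per-pixel categorical draw over the K components via the Gumbel-max trick.
    gumbel = -torch.log(-torch.log(torch.rand_like(logweights).clamp(min=EPS)))
    idx = (logweights + gumbel).argmax(dim=1, keepdim=True)    # (1, 1, 1, H, W)
    idx = idx.expand(-1, -1, means.size(2), -1, -1)            # (1, 1, C, H, W)
    sample = torch.gather(means, dim=1, index=idx).squeeze(1)  # (1, C, H, W)
    return sample + torch.randn_like(sample) * stds.squeeze(1)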
def policy_rollout(
        x_t_start: torch.Tensor,  # (B, C, *, H, W)
        raw_t_start: torch.Tensor,  # (B, )
        raw_t_end: torch.Tensor,  # (B, )
        policy,
        warp_t_fun):
    ndim = x_t_start.dim()
    raw_t_start = raw_t_start.reshape(*(ndim * [1]))
    raw_t_end = raw_t_end.reshape(*(ndim * [1]))
    delta_raw_t = raw_t_start - raw_t_end
    num_substeps = (delta_raw_t * TOTAL_SUBSTEPS).round().to(torch.long).clamp(min=1)
    substep_size = delta_raw_t / num_substeps

    raw_t = raw_t_start
    sigma_t = warp_t_fun(raw_t)
    x_t = x_t_start
    for substep_id in range(num_substeps.item()):
        # Euler step under the warped time: x_{t-} = x_t + u * (sigma_{t-} - sigma_t).
        u = policy.pi(x_t, sigma_t)
        raw_t_minus = (raw_t - substep_size).clamp(min=0)
        sigma_t_minus = warp_t_fun(raw_t_minus)
        x_t_minus = x_t + u * (sigma_t_minus - sigma_t)
        x_t = x_t_minus
        sigma_t = sigma_t_minus
        raw_t = raw_t_minus

    x_t_end = x_t
    sigma_t_end = sigma_t
    return x_t_end, sigma_t_end.flatten()
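# `warp_t_fun` above maps the uniform "raw" time in [0, 1] onto the teacher's
# noise level. As a minimal sketch of what such a warp could look like,
# assuming the common shift-based flow-matching schedule (the teacher's
# ContinuousTimeStepSampler is configured with shift=3.2 below); the actual
# lakonlab implementation may differ:
def shifted_warp_t(raw_t, shift=3.2):
    return shift * raw_t / (1.0 + (shift - 1.0) * raw_t)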
def main():
    args = parse_args()
    prompt = args.prompt
    num_gaussians = args.k
    num_iters = args.num_iters
    lr = args.lr
    out_path = args.out
    guidance_scale = args.cfg
    num_intermediates = args.num_intermediates

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    out_path_noext, out_ext = os.path.splitext(out_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'

    text_encoder = PretrainedQwenImageTextEncoder(
        from_pretrained='Qwen/Qwen-Image',
        torch_dtype=dtype,
        max_sequence_length=512,
        pad_seq_len=512,
    ).to(device)
    prompt_embed_kwargs = text_encoder(prompt)
    if guidance_scale > 1.0:
        empty_prompt_embed_kwargs = text_encoder('')
        for k in prompt_embed_kwargs:
            prompt_embed_kwargs[k] = torch.cat([
                empty_prompt_embed_kwargs[k], prompt_embed_kwargs[k]], dim=0)
    del text_encoder
    torch.cuda.empty_cache()

    vae = PretrainedVAEQwenImage(
        from_pretrained='Qwen/Qwen-Image', subfolder='vae', torch_dtype=dtype).to(device)
    vae_scale_factor = 8
    vae_latent_size = (16, args.h // vae_scale_factor, args.w // vae_scale_factor)

    teacher = GaussianFlow(
        denoising=QwenImageTransformer2DModel(
            patch_size=2,
            freeze=True,
            pretrained='huggingface://Qwen/Qwen-Image/transformer/diffusion_pytorch_model.safetensors.index.json',
            in_channels=64,
            out_channels=64,
            num_layers=60,
            attention_head_dim=128,
            num_attention_heads=24,
            joint_attention_dim=3584,
            axes_dims_rope=(16, 56, 56),
            torch_dtype=dtype),
        num_timesteps=1,
        denoising_mean_mode='U',
        timestep_sampler=dict(
            type='ContinuousTimeStepSampler',
            shift=3.2,
            logit_normal_enable=False)).eval().to(device)

    # Get the fixed initial noise.
    torch.manual_seed(args.seed)
    x_t_src = torch.randn((1, *vae_latent_size), device=device)
    t_src = torch.ones(1, device=device)

    # Initialize the student using the teacher's u prediction at (x_t_src, t_src).
    u = teacher.forward(
        return_u=True,
        x_t=x_t_src,
        t=t_src,
        guidance_scale=guidance_scale,
        **prompt_embed_kwargs)
    student = StaticGMM(
        init_u=u,
        num_gaussians=num_gaussians).to(device)

    # Start training.
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    loss_list = []
    for i in range(1, num_iters + 1):
        optimizer.zero_grad()
        denoising_output = student(x_t_src, t_src)
        policy = GMFlowPolicy(denoising_output, x_t_src, t_src)
        detached_policy = policy.detach()

        loss = 0
        intermediate_t_samples = torch.rand(num_intermediates, device=device).clamp(min=EPS)
        for raw_t in intermediate_t_samples:
            # Roll the detached policy out to an intermediate time, then match
            # the (differentiable) policy's velocity to the teacher's there.
            x_t, t = policy_rollout(
                x_t_start=x_t_src,
                raw_t_start=t_src,
                raw_t_end=raw_t,
                policy=detached_policy,
                warp_t_fun=teacher.timestep_sampler.warp_t)
            pred_u = policy.pi(x_t, t)
            teacher_u = teacher.forward(
                return_u=True,
                x_t=x_t,
                t=t,
                guidance_scale=guidance_scale,
                **prompt_embed_kwargs)
            loss += F.mse_loss(pred_u, teacher_u) / num_intermediates

        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        if i % PRINT_EVERY == 0 or i == num_iters:
            print(f'Iter {i:04d}/{num_iters:04d}, loss: {np.mean(loss_list):.6f}')
            loss_list = []

        if i % SAVE_EVERY == 0 or i == num_iters:
            with torch.no_grad():
                x_0, _ = policy_rollout(
                    x_t_start=x_t_src,
                    raw_t_start=t_src,
                    raw_t_end=torch.zeros(1, device=device),
                    policy=policy,
                    warp_t_fun=teacher.timestep_sampler.warp_t)
                image = ((vae.decode(x_0.to(getattr(torch, dtype))) / 2 + 0.5).clamp(0, 1) * 255).round().to(
                    dtype=torch.uint8, device='cpu').squeeze(0).permute(1, 2, 0).numpy()
                Image.fromarray(image).save(f'{out_path_noext}.iter{i:04d}{out_ext}')
                print(f'Image saved to {out_path_noext}.iter{i:04d}{out_ext}')


if __name__ == '__main__':
    main()
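# Example invocation (the script filename below is an assumption; use the
# actual filename in your checkout):
#
#   python piflow_qwen_toymodel.py --cfg 4.0 -k 32 --num-iters 5000 \
#       --out viz/piflow_qwen_toymodel/output.png
#
# Snapshots are written next to --out as output.iterXXXX.png every SAVE_EVERY
# iterations and at the final iteration.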