# Tests for the LBM model: forward pass and optimizer-step behavior.
from copy import deepcopy

import pytest
import torch
import torch.nn as nn
from diffusers import FlowMatchEulerDiscreteScheduler

from lbm.models.embedders import ConditionerWrapper
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig

# Run the tests on GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class TestLBM:
    """Integration tests for :class:`LBMModel`.

    Builds a minimal model (single-block UNet denoiser, default VAE, empty
    conditioner, flow-matching schedulers) via pytest fixtures, then checks
    that the training forward pass yields a positive loss and that one
    optimizer step actually updates the denoiser weights.
    """

    @pytest.fixture
    def denoiser(self):
        # Smallest usable UNet: one cross-attention down/up block pair
        # operating in the 4-channel VAE latent space.
        return DiffusersUNet2DCondWrapper(
            in_channels=4,  # VAE latent channels
            out_channels=4,  # VAE latent channels
            up_block_types=["CrossAttnUpBlock2D"],
            down_block_types=[
                "CrossAttnDownBlock2D",
            ],
            cross_attention_dim=[320],
            block_out_channels=[320],
            transformer_layers_per_block=[1],
            attention_head_dim=[5],
            norm_num_groups=32,
        )

    @pytest.fixture
    def conditioner(self):
        # No conditioning modules are needed for these tests.
        return ConditionerWrapper([])

    @pytest.fixture
    def vae(self):
        return AutoencoderKLDiffusers(AutoencoderKLDiffusersConfig())

    @pytest.fixture
    def sampling_noise_scheduler(self):
        return FlowMatchEulerDiscreteScheduler()

    @pytest.fixture
    def training_noise_scheduler(self):
        return FlowMatchEulerDiscreteScheduler()

    @pytest.fixture
    def model_config(self):
        # Keys must match the entries produced by the `model_input` fixture.
        return LBMConfig(
            source_key="source_image",
            target_key="target_image",
        )

    @pytest.fixture
    def model_input(self):
        # Batch of 2 random RGB 256x256 images for each of the two domains.
        return {
            "source_image": torch.randn(2, 3, 256, 256).to(DEVICE),
            "target_image": torch.randn(2, 3, 256, 256).to(DEVICE),
        }

    @pytest.fixture
    def model(
        self,
        model_config,
        denoiser,
        vae,
        sampling_noise_scheduler,
        training_noise_scheduler,
        conditioner,
    ):
        # Assemble the full model from the component fixtures above.
        return LBMModel(
            config=model_config,
            denoiser=denoiser,
            vae=vae,
            sampling_noise_scheduler=sampling_noise_scheduler,
            training_noise_scheduler=training_noise_scheduler,
            conditioner=conditioner,
        ).to(DEVICE)

    def test_model_forward(self, model, model_input):
        """The training forward pass must return a strictly positive loss."""
        model_output = model(model_input)
        assert model_output["loss"] > 0.0

    def test_optimizers(self, model, model_input):
        """One Adam step on the denoiser must change its parameters."""
        optimizer = torch.optim.Adam(model.denoiser.parameters(), lr=1e-4)
        model.train()
        # Snapshot the weights before the update so we can compare after.
        model_init = deepcopy(model)
        optimizer.zero_grad()
        loss = model(model_input)["loss"]
        loss.backward()
        optimizer.step()
        assert not torch.equal(
            torch.cat([p.flatten() for p in model.denoiser.parameters()]),
            torch.cat([p.flatten() for p in model_init.denoiser.parameters()]),
        )