maomao517's picture
Upload folder using huggingface_hub
d10d213 verified
from copy import deepcopy
import pytest
import torch
import torch.nn as nn
from diffusers import FlowMatchEulerDiscreteScheduler
from lbm.models.embedders import ConditionerWrapper
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class TestLBM:
@pytest.fixture()
def denoiser(self):
return DiffusersUNet2DCondWrapper(
in_channels=4, # VAE channels
out_channels=4, # VAE channels
up_block_types=["CrossAttnUpBlock2D"],
down_block_types=[
"CrossAttnDownBlock2D",
],
cross_attention_dim=[320],
block_out_channels=[320],
transformer_layers_per_block=[1],
attention_head_dim=[5],
norm_num_groups=32,
)
@pytest.fixture()
def conditioner(self):
return ConditionerWrapper([])
@pytest.fixture()
def vae(self):
return AutoencoderKLDiffusers(AutoencoderKLDiffusersConfig())
@pytest.fixture()
def sampling_noise_scheduler(self):
return FlowMatchEulerDiscreteScheduler()
@pytest.fixture()
def training_noise_scheduler(self):
return FlowMatchEulerDiscreteScheduler()
@pytest.fixture()
def model_config(self):
return LBMConfig(
source_key="source_image",
target_key="target_image",
)
@pytest.fixture()
def model_input(self):
return {
"source_image": torch.randn(2, 3, 256, 256).to(DEVICE),
"target_image": torch.randn(2, 3, 256, 256).to(DEVICE),
}
@pytest.fixture()
def model(
self,
model_config,
denoiser,
vae,
sampling_noise_scheduler,
training_noise_scheduler,
conditioner,
):
return LBMModel(
config=model_config,
denoiser=denoiser,
vae=vae,
sampling_noise_scheduler=sampling_noise_scheduler,
training_noise_scheduler=training_noise_scheduler,
conditioner=conditioner,
).to(DEVICE)
@torch.no_grad()
def test_model_forward(self, model, model_input):
model_output = model(
model_input,
)
assert model_output["loss"] > 0.0
def test_optimizers(self, model, model_input):
optimizer = torch.optim.Adam(model.denoiser.parameters(), lr=1e-4)
model.train()
model_init = deepcopy(model)
optimizer.zero_grad()
loss = model(model_input)["loss"]
loss.backward()
optimizer.step()
assert not torch.equal(
torch.cat([p.flatten() for p in model.denoiser.parameters()]),
torch.cat([p.flatten() for p in model_init.denoiser.parameters()]),
)