File size: 818 Bytes
813ea14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Train UNet3D to predict the Gaussian noise mixed into each video clip
# (objective: MSE between predicted and true noise), conditioned on the
# batch's "text" tensor and the timestep normalized to [0, 1).
NUM_EPOCHS = 250
NUM_TIMESTEPS = 1000
LEARNING_RATE = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    model.train()
    # Sum losses on-device so logging does not force a host sync every step;
    # the per-epoch report is the mean over all batches, not just the last one.
    epoch_loss = torch.zeros((), device=device)
    num_batches = 0
    for batch in tqdm(dataloader):
        video = batch["video"].to(device)
        text = batch["text"].to(device)

        # One integer timestep per clip, sampled directly on the target device.
        t = torch.randint(0, NUM_TIMESTEPS, (video.shape[0], 1), device=device)
        noise = torch.randn_like(video)
        # Signal-retention factor, linear in t: 1.0 at t=0 down to
        # 1/NUM_TIMESTEPS at t=NUM_TIMESTEPS-1. Reshaped to (B, 1, ..., 1) so it
        # broadcasts over every non-batch dimension of `video`.
        alpha_t = (1 - t / NUM_TIMESTEPS).view(-1, *([1] * (video.dim() - 1)))
        noisy_video = torch.sqrt(alpha_t) * video + torch.sqrt(1 - alpha_t) * noise

        pred_noise = model(noisy_video, t / NUM_TIMESTEPS, text)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        num_batches += 1

    if num_batches == 0:
        # Without this guard the print below would fail with an opaque NameError.
        raise RuntimeError("dataloader yielded no batches; cannot train")
    print(f"Epoch {epoch}, Loss: {(epoch_loss / num_batches).item():.4f}")