SkillForge45 commited on
Commit
813ea14
·
verified ·
1 Parent(s): bb09e9c

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +24 -0
train.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2
+ model = UNet3D().to(device)
3
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
4
+
5
+ for epoch in range(250):
6
+ for batch in tqdm(dataloader):
7
+ video = batch["video"].to(device)
8
+ text = batch["text"].to(device)
9
+
10
+
11
+ t = torch.randint(0, 1000, (video.shape[0], 1)).to(device)
12
+ noise = torch.randn_like(video)
13
+ alpha_t = (1 - t/1000).view(-1, 1, 1, 1, 1)
14
+ noisy_video = torch.sqrt(alpha_t) * video + torch.sqrt(1 - alpha_t) * noise
15
+
16
+
17
+ pred_noise = model(noisy_video, t/1000, text)
18
+ loss = F.mse_loss(pred_noise, noise)
19
+
20
+ optimizer.zero_grad()
21
+ loss.backward()
22
+ optimizer.step()
23
+
24
+ print(f"Epoch {epoch}, Loss: {loss.item():.4f}")