sgoel30 commited on
Commit
d510b4a
·
verified ·
1 Parent(s): 51b5f3f

Update src/lm/memdlm/diffusion_module.py

Browse files
Files changed (1) hide show
  1. src/lm/memdlm/diffusion_module.py +0 -1
src/lm/memdlm/diffusion_module.py CHANGED
@@ -78,7 +78,6 @@ class MembraneDiffusion(pl.LightningModule):
78
  u = torch.rand_like(x0, dtype=torch.float)
79
  t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
80
  x_t1 = x0.masked_fill(t1_mask, self.mask_id)
81
- x_t1 = x_t1.masked_fill(t1_mask, self.mask_id)
82
  return x_t1, t1_mask
83
 
84
  def get_weight(self, t, weight_type):
 
78
  u = torch.rand_like(x0, dtype=torch.float)
79
  t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
80
  x_t1 = x0.masked_fill(t1_mask, self.mask_id)
 
81
  return x_t1, t1_mask
82
 
83
  def get_weight(self, t, weight_type):