XXXXRT666 commited on
Commit
190f22a
·
1 Parent(s): 46a0eb0
Files changed (1) hide show
  1. AR/models/structs.py +2 -1
AR/models/structs.py CHANGED
@@ -78,6 +78,7 @@ class T2SSession:
78
  pos = int(self.x_lens[bs].item())
79
  mask = torch.zeros(pos + y_len, pos + y_len).bool()
80
  mask[:, :pos].fill_(True)
81
- mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
 
82
  attn_mask.append(mask)
83
  self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
 
78
  pos = int(self.x_lens[bs].item())
79
  mask = torch.zeros(pos + y_len, pos + y_len).bool()
80
  mask[:, :pos].fill_(True)
81
+ if y_len > 0:
82
+ mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
83
  attn_mask.append(mask)
84
  self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)