Spaces:
Running
on
Zero
Running
on
Zero
XXXXRT666
commited on
Commit
·
190f22a
1
Parent(s):
46a0eb0
Fix Mask
Browse files- AR/models/structs.py +2 -1
AR/models/structs.py
CHANGED
@@ -78,6 +78,7 @@ class T2SSession:
|
|
78 |
pos = int(self.x_lens[bs].item())
|
79 |
mask = torch.zeros(pos + y_len, pos + y_len).bool()
|
80 |
mask[:, :pos].fill_(True)
|
81 |
-
|
|
|
82 |
attn_mask.append(mask)
|
83 |
self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
|
|
|
78 |
pos = int(self.x_lens[bs].item())
|
79 |
mask = torch.zeros(pos + y_len, pos + y_len).bool()
|
80 |
mask[:, :pos].fill_(True)
|
81 |
+
if y_len > 0:
|
82 |
+
mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
|
83 |
attn_mask.append(mask)
|
84 |
self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
|