Update modeling_deberta.py
Browse files- modeling_deberta.py +18 -14
modeling_deberta.py
CHANGED
|
@@ -1251,20 +1251,24 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
|
|
| 1251 |
],
|
| 1252 |
dim=-1
|
| 1253 |
)
|
| 1254 |
-
|
| 1255 |
-
|
| 1256 |
-
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1268 |
|
| 1269 |
outputs = super().forward(
|
| 1270 |
input_ids,
|
|
|
|
| 1251 |
],
|
| 1252 |
dim=-1
|
| 1253 |
)
|
| 1254 |
+
|
| 1255 |
+
if attention_mask is not None:
|
| 1256 |
+
attention_mask = torch.cat(
|
| 1257 |
+
[
|
| 1258 |
+
attention_mask,
|
| 1259 |
+
torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
|
| 1260 |
+
],
|
| 1261 |
+
dim=-1
|
| 1262 |
+
)
|
| 1263 |
+
|
| 1264 |
+
if position_ids is not None:
|
| 1265 |
+
position_ids = torch.cat(
|
| 1266 |
+
[
|
| 1267 |
+
position_ids,
|
| 1268 |
+
torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
|
| 1269 |
+
],
|
| 1270 |
+
dim=-1
|
| 1271 |
+
)
|
| 1272 |
|
| 1273 |
outputs = super().forward(
|
| 1274 |
input_ids,
|