feat: implemented positional interpolation
Browse files- modeling_bert.py +2 -1
modeling_bert.py
CHANGED
|
@@ -787,7 +787,8 @@ class JinaBertEncoder(nn.Module):
|
|
| 787 |
# Device catch-up
|
| 788 |
self.alibi = self.alibi.to(hidden_states.device)
|
| 789 |
|
| 790 |
-
|
|
|
|
| 791 |
if self.gradient_checkpointing and self.training:
|
| 792 |
if use_cache:
|
| 793 |
logger.warning_once(
|
|
|
|
| 787 |
# Device catch-up
|
| 788 |
self.alibi = self.alibi.to(hidden_states.device)
|
| 789 |
|
| 790 |
+
unpadded_seqlens = torch.sum(attention_mask, dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
|
| 791 |
+
alibi_bias = self.alibi[:, :, :seqlen, :seqlen] * 512 / unpadded_seqlens
|
| 792 |
if self.gradient_checkpointing and self.training:
|
| 793 |
if use_cache:
|
| 794 |
logger.warning_once(
|