fix: masked fill and imports
- src/__init__.py +0 -0
- src/modules/__init__.py +0 -0
- src/modules/multihead_attention.py +1 -1
- src/utils/__init__.py +0 -0
src/__init__.py
ADDED
File without changes
src/modules/__init__.py
ADDED
File without changes
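The empty __init__.py files are presumably the "imports" half of the commit title: they mark src, src.modules, and src.utils as regular Python packages so that absolute imports of the attention module resolve. A minimal usage sketch (the import path is inferred from the file layout shown here, not confirmed elsewhere):

# Assumes the repository root is on sys.path / PYTHONPATH.
from src.modules.multihead_attention import MultiheadAttention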
src/modules/multihead_attention.py
CHANGED
@@ -124,7 +124,7 @@ class MultiheadAttention(nn.Module):
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.float().masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                 float('-inf'),
             ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
src/utils/__init__.py
ADDED
File without changes
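Why the .bool() cast: newer PyTorch releases require the mask argument of masked_fill to be a boolean tensor, and passing a uint8 (byte) key padding mask either warns or errors outright depending on the version. A minimal, self-contained sketch of the patched masking step, with made-up shapes and a synthetic byte mask purely for illustration (none of these values come from the repository):

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6

# Illustrative inputs: FP16 attention scores and a uint8 key padding mask,
# with 1 marking padded source positions (the dtype older code often used).
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len).half()
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.uint8)
key_padding_mask[:, -1] = 1  # pretend the last source position is padding

# Same pattern as the patched line: broadcast the (bsz, src_len) mask over
# heads and target positions, cast it to bool so masked_fill accepts it,
# and do the fill in float32 before casting back for FP16 support.
attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
attn_weights = attn_weights.float().masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
    float('-inf'),
).type_as(attn_weights)
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)

print(attn_weights[0, :, -1])  # masked source column is -inf before any softmax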