amosyou committed on
Commit
999d005
·
1 Parent(s): 929c70d

fix: masked fill and imports

Browse files
src/__init__.py ADDED
File without changes
src/modules/__init__.py ADDED
File without changes
src/modules/multihead_attention.py CHANGED
@@ -124,7 +124,7 @@ class MultiheadAttention(nn.Module):
124
  # don't attend to padding symbols
125
  attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
126
  attn_weights = attn_weights.float().masked_fill(
127
- key_padding_mask.unsqueeze(1).unsqueeze(2),
128
  float('-inf'),
129
  ).type_as(attn_weights) # FP16 support: cast to float and back
130
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
124
  # don't attend to padding symbols
125
  attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
126
  attn_weights = attn_weights.float().masked_fill(
127
+ key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
128
  float('-inf'),
129
  ).type_as(attn_weights) # FP16 support: cast to float and back
130
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
src/utils/__init__.py ADDED
File without changes