Maxlegrec committed
Commit 4cbf8b0 · verified · 1 Parent(s): c245e4d

Update model architecture: d_ff=1024, new weights from merged7.pt

Files changed (1)
modeling_chessbot.py +6 -4
modeling_chessbot.py CHANGED
@@ -418,14 +418,16 @@ class AbsolutePositionalEncoder(nn.Module):
     def __init__(self, d_model):
         super(AbsolutePositionalEncoder, self).__init__()
         self.d_model = d_model
-        self.register_buffer('position', torch.arange(64).unsqueeze(1))
+        # Don't register as buffers since these are computed values
+        position = torch.arange(64).unsqueeze(1).float()
 
         positional_encoding = torch.zeros(1, 64, d_model)
         _2i = torch.arange(0, d_model, step=2).float()
-        positional_encoding[:, :, 0::2] = torch.sin(self.position / (10000 ** (_2i / d_model)))
-        positional_encoding[:, :, 1::2] = torch.cos(self.position / (10000 ** (_2i / d_model)))
+        positional_encoding[:, :, 0::2] = torch.sin(position / (10000 ** (_2i / d_model)))
+        positional_encoding[:, :, 1::2] = torch.cos(position / (10000 ** (_2i / d_model)))
 
-        self.register_buffer('positional_encoding', positional_encoding)
+        # Register as non-persistent buffer (won't be saved/loaded)
+        self.register_buffer('positional_encoding', positional_encoding, persistent=False)
 
     def forward(self, x):
         batch_size, _, _ = x.size()
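For context on the persistent=False change: a non-persistent buffer is still available on the module at runtime, but it is excluded from state_dict(), so a checkpoint saved without the positional_encoding key can still be loaded with strict key checking, which is presumably relevant when loading the new weights from merged7.pt. Below is a minimal sketch of that behavior, using a stand-alone Demo module that is illustrative only and not part of the repository.

# Illustrative sketch (Demo is a hypothetical module, not from modeling_chessbot.py):
# a non-persistent buffer is computed in __init__ and never serialized.
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, d_model=8):
        super().__init__()
        # Same sinusoidal encoding as in the diff above, built from computed values
        position = torch.arange(64).unsqueeze(1).float()
        pe = torch.zeros(1, 64, d_model)
        _2i = torch.arange(0, d_model, step=2).float()
        pe[:, :, 0::2] = torch.sin(position / (10000 ** (_2i / d_model)))
        pe[:, :, 1::2] = torch.cos(position / (10000 ** (_2i / d_model)))
        # Non-persistent buffer: moves with .to()/.cuda() but is skipped by state_dict()
        self.register_buffer('positional_encoding', pe, persistent=False)

print('positional_encoding' in Demo().state_dict())  # False: the buffer is not saved/loaded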