Fixed early-stopping in get_mamba_response based on space/dot tokens (now encodes the delimiter strings ' ' and '.' to obtain their token ids instead of using hardcoded token ids).
Browse files
chess-gpt-eval/mamba_module.py
CHANGED
|
@@ -81,6 +81,8 @@ class MambaPlayer:
|
|
| 81 |
self.vocab_size = vocab_size
|
| 82 |
self.encode = encode
|
| 83 |
self.decode = decode
|
|
|
|
|
|
|
| 84 |
self.model = model
|
| 85 |
self.ctx = ctx
|
| 86 |
self.device = device
|
|
@@ -107,8 +109,9 @@ class MambaPlayer:
|
|
| 107 |
|
| 108 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 109 |
next_token_id = torch.multinomial(probs, num_samples=1)
|
| 110 |
-
if
|
| 111 |
-
|
|
|
|
| 112 |
else:
|
| 113 |
have_non_space = True
|
| 114 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
|
|
|
| 81 |
self.vocab_size = vocab_size
|
| 82 |
self.encode = encode
|
| 83 |
self.decode = decode
|
| 84 |
+
self.space_tok = encode(' ')[0]
|
| 85 |
+
self.dot_tok = encode('.')[0]
|
| 86 |
self.model = model
|
| 87 |
self.ctx = ctx
|
| 88 |
self.device = device
|
|
|
|
| 109 |
|
| 110 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 111 |
next_token_id = torch.multinomial(probs, num_samples=1)
|
| 112 |
+
if next_token_id == self.space_tok or next_token_id==self.dot_tok:
|
| 113 |
+
if have_non_space:
|
| 114 |
+
break
|
| 115 |
else:
|
| 116 |
have_non_space = True
|
| 117 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|