Update modeling_recllama.py
modeling_recllama.py  CHANGED  (+5 -2)
@@ -141,11 +141,12 @@ class RecDynamicCache(DynamicCache):
 
 class RecLlamaForCausalLM(LlamaForCausalLM):
     config_class = RecLlamaConfig
-    def __init__(self, config: RecLlamaConfig):
+    def __init__(self, config: RecLlamaConfig, num_steps=None):
         super().__init__(config)
         self.prelude_layers = config.prelude_layers
         self.recurrent_layers = config.recurrent_layers
         self.coda_layers = config.coda_layers
+        self.num_steps = num_steps
 
         for i in range(len(self.model.layers)):
             self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
@@ -374,10 +375,12 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         position_embeddings,
         num_steps=None,
     ):
-        if num_steps is None:
+        if num_steps is None and self.num_steps is None:
             num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
         elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
             num_steps_no_grad, num_steps_with_grad = num_steps
+        elif self.num_steps is not None:
+            num_steps_no_grad, num_steps_with_grad = self.num_steps, self.num_steps
         else:
             num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
 
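Taken together, the two hunks let a fixed recurrence depth be pinned at construction time instead of drawn from randomized_iteration_sampler() on every forward pass. Below is a minimal usage sketch, assuming the classes are importable from modeling_recllama and that RecLlamaConfig accepts the layer-count fields read in __init__; the module path, config kwargs, and all numeric values are placeholders, not part of the commit.

# Sketch only: module path, config kwargs, and layer counts are assumptions.
import torch
from modeling_recllama import RecLlamaConfig, RecLlamaForCausalLM

config = RecLlamaConfig(
    prelude_layers=2,      # hypothetical values; __init__ reads these fields
    recurrent_layers=4,
    coda_layers=2,
)

# With num_steps pinned here, the new `elif self.num_steps is not None`
# branch is taken whenever no per-call pair is given, so both
# num_steps_no_grad and num_steps_with_grad equal 8.
model = RecLlamaForCausalLM(config, num_steps=8)

# A per-call (no_grad, with_grad) pair still wins via the
# `len(num_steps) > 1` branch; a per-call scalar, however, is now
# overridden by self.num_steps. Whether the top-level forward threads
# num_steps through to this helper is an assumption, hence commented out:
# out = model(input_ids, num_steps=(torch.tensor(4), torch.tensor(4)))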