Arthur-LAGACHERIE committed on
Commit
9dea54a
·
verified ·
1 Parent(s): f550107

Update modeling_recllama.py

Browse files
Files changed (1) hide show
  1. modeling_recllama.py +5 -2
modeling_recllama.py CHANGED
@@ -141,11 +141,12 @@ class RecDynamicCache(DynamicCache):
141
 
142
  class RecLlamaForCausalLM(LlamaForCausalLM):
143
  config_class = RecLlamaConfig
144
- def __init__(self, config: RecLlamaConfig):
145
  super().__init__(config)
146
  self.prelude_layers = config.prelude_layers
147
  self.recurrent_layers = config.recurrent_layers
148
  self.coda_layers = config.coda_layers
 
149
 
150
  for i in range(len(self.model.layers)):
151
  self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
@@ -374,10 +375,12 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
374
  position_embeddings,
375
  num_steps=None,
376
  ):
377
- if num_steps is None:
378
  num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
379
  elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
380
  num_steps_no_grad, num_steps_with_grad = num_steps
 
 
381
  else:
382
  num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
383
 
 
141
 
142
  class RecLlamaForCausalLM(LlamaForCausalLM):
143
  config_class = RecLlamaConfig
144
+ def __init__(self, config: RecLlamaConfig, num_steps=None):
145
  super().__init__(config)
146
  self.prelude_layers = config.prelude_layers
147
  self.recurrent_layers = config.recurrent_layers
148
  self.coda_layers = config.coda_layers
149
+ self.num_steps = num_steps
150
 
151
  for i in range(len(self.model.layers)):
152
  self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
 
375
  position_embeddings,
376
  num_steps=None,
377
  ):
378
+ if num_steps is None and self.num_steps is None:
379
  num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
380
  elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
381
  num_steps_no_grad, num_steps_with_grad = num_steps
382
+ elif self.num_steps is not None:
383
+ num_steps_no_grad, num_steps_with_grad = self.num_steps, self.num_steps
384
  else:
385
  num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
386